diff --git a/.gitignore b/.gitignore index 7fa359a..37c40aa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,39 +1,114 @@ -__pycache__/ -*.py[cod] *$py.class +*.cover +*.egg *.egg-info -.pytest_cache -.ipynb_checkpoints - -thumbs.db -.DS_Store -.idea +*.egg-info/ +*.flac +*.gz *.log -*rtx* -*.pdf +*.manifest *.mkv +*.mo +*.mp3 *.mp4 -*a40* -*durip* +*.pdf *.png -sim_lr.ipynb -*.mp3 -*.gz -*.flac -*.th -*.pth +*.pot *.pt -local_* +*.pth +*.py,cover +*.py[cod] +*.sage.py +*.so +*.spec +*.th +*a40* +*durip* +*rtx* +.cache +.coverage +.coverage.* +.dmypy.json +.DS_Store +.eggs/ +.env +.hypothesis/ +.idea +.installed.cfg +.ipynb_checkpoints +.mypy_cache/ +.nox/ +.pdm-build/ +.pdm-python +.pdm.toml +.pybuilder/ +.pypirc +.pyre/ +.pytest_cache +.pytest_cache/ +.Python +.pytype/ +.ropeproject +.ruff_cache/ +.scrapy +.spyderproject +.spyproject +.tox/ +.venv +.webassets-cache +/site +amt/ +bad_files/ +build/ +celerybeat-schedule +celerybeat.pid +cover/ +coverage.xml +cython_debug/ +db.sqlite3 +db.sqlite3-journal +demo/generated_tts +develop-eggs/ +dist/ +dmypy.json +docs/_build/ +downloads/ +eggs/ +env.bak/ +env/ +ENV/ +file_log.txt +file_log_debug*.txt +htmlcov/ hub/ +instance/ +ipython_config.py +lib/ +lib64/ +local_* +local_settings.py +MANIFEST +nosetests.xml +parts/ per_sample_res/ -src/ -res/seed_tts_eval/ +pip-delete-this-directory.txt +pip-log.txt +profile_default/ res/lspc_eval -file_log.txt -file_log_debug*.txt -bad_files/ -amt/ +res/seed_tts_eval/ sam1.wav sam2.wav sam3.wav -demo/generated_tts \ No newline at end of file +sdist/ +share/python-wheels/ +sim_lr.ipynb +src/ +target/ +thumbs.db +var/ +venv.bak/ +venv/ +wheels/ +__pycache__/ +__pypackages__/ +generated_tts/ \ No newline at end of file diff --git a/MODEL-LICENSE b/LICENSE-MODEL similarity index 100% rename from MODEL-LICENSE rename to LICENSE-MODEL diff --git a/README.md b/README.md index 3b9096c..234222a 100644 --- a/README.md +++ b/README.md @@ -1,79 +1,71 @@ -# VoiceStar: Robust, Duration-controllable TTS that can Extrapolate +# VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate -## TODO -- [x] Gradio demo ETA: 6 April 2025 -- [ ] Research Paper: 7 April 2025 - 14 April 2025 +VoiceStar is a robust, duration-controllable TTS model with support for test-time extrapolation, meaning it can generate speech longer than the duration it was trained on. + +## Features + +- **Duration control**: Specify the duration of the generated speech. +- **Zero-shot voice cloning**: Clone any voice with a short reference audio clip ([demo video](https://x.com/PuyuanPeng/status/1908822618167300419)). + +Coming soon: research paper (ETA: 7 April 2025 - 14 April 2025) + +## Quick Start + +### Install -## 1. 
Env setup
-### Download model
-```bash
-# under VoiceStar root dir
-wget -O ./pretrained/encodec_6f79c6a8.th https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th?download=true
-wget -O ./pretrained/VoiceStar_840M_30s.pth https://huggingface.co/pyp1/VoiceStar/resolve/main/VoiceStar_840M_30s.pth?download=true
-wget -O ./pretrained/VoiceStar_840M_40s.pth https://huggingface.co/pyp1/VoiceStar/resolve/main/VoiceStar_840M_40s.pth?download=true
-```
-### Inference only:
 ```bash
-conda create -n voicestar python=3.10
-conda activate voicestar # this seems to lead to much worse results in terms of wer and spksim (comparing e9_rerun and e9_rerun_newba_upgraded)
-pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
-pip install numpy, tqdm, fire
-pip install phonemizer==3.2.1
-apt-get install espeak-ng # backend for the phonemizer
-pip install torchmetrics
-pip install einops
-pip install omegaconf==2.3.0
-pip install openai-whisper
-pip install gradio
+pip install voicestar
 ```
-* avoid warnings likes
-[WARNING] words_mismatch.py:88 || words count mismatch on 200.0% of the lines (2/1)
-```python
-# go to ~/miniconda3/envs/voicestar/lib/python3.10/site-packages/phonemizer/backend/espeak/words_mismatch.py
-# pass the warning like this
-    def _resume(self, nmismatch: int, nlines: int):
-        """Logs a high level undetailed warning"""
-        pass
-        # if nmismatch:
-        #     self._logger.warning(
-        #         'words count mismatch on %s%% of the lines (%s/%s)',
-        #         round(nmismatch / nlines, 2) * 100, nmismatch, nlines)
-```
+Make sure you also have `espeak-ng` installed.
+
+**Note:** If you run into issues installing VoiceStar with `uv`, try installing it with `pip` instead.
+
+### Usage
+
+Basic usage:
 
-### Training and data processing
-*additional packages*:
 ```bash
-pip install huggingface_hub
-pip install datasets
-pip install tensorboard
-pip install wandb
-pip install matplotlib
-pip install ffmpeg-python
-pip install scipy
-pip install soundfile
+voicestar --reference-speech "./demo/5895_34622_000026_000002.wav" --target-text "I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long." --target-duration 8
 ```
-## 2. example
-### command line example
-check signature of `run_inference` func in `inference_commandline.py` for adjustable hyperparameters
+Please refer to the CLI and Python API documentation below for more advanced usage.
+
+_Note: CUDA, CPU, and MPS (Apple Silicon) are all supported._
+
+## Training
+
+Please refer to the [training docs](docs/training.md) for more information.
+
+## Inference
+
+### CLI
+
 ```bash
-# under root dir
-conda activate voicestar
-python inference_commandline.py \
-    --reference_speech "./demo/5895_34622_000026_000002.wav" \
-    --target_text "I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long." \
-    --target_duration 8
+voicestar --reference-speech "./demo/5895_34622_000026_000002.wav" --target-text "I cannot believe that the same model can also do text to speech synthesis too!"
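+# Illustrative sketch, using only flags already shown in the basic-usage example above:
+# add --target-duration (in seconds) to control how long the generated speech is, e.g.
+# voicestar --reference-speech "./demo/5895_34622_000026_000002.wav" --target-text "I cannot believe that the same model can also do text to speech synthesis too!" --target-duration 8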
``` -### Gradio +View all available options: + ```bash -conda activate voicestar -python inference_gradio.py +voicestar --help ``` +### Python API + +```python +from voicestar import VoiceStar + +# Initialize the model +model = VoiceStar() + +# Generate speech from text +audio = model.generate("I cannot believe that the same model can also do text to speech synthesis too!") +audio.save("output.wav") +``` ## License -Code license: MIT -Model Weights License: CC-BY-4.0 (as Emilia dataset we used is under this license) \ No newline at end of file +The code in this repo is licensed under the MIT license. The pretrained model weights available on Hugging Face are licensed under the CC-BY-4.0 license. + +This repository may contain third-party software which may be licensed under different licenses. \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000..c5d09e7 --- /dev/null +++ b/app.py @@ -0,0 +1,96 @@ +import gradio as gr +import torch +import os +from voicestar import VoiceStar +from voicestar.utils import seed_everything +from txtsplit import txtsplit +import numpy as np + +ABOUT = """ +# VoiceStar TTS + +Gradio demo for [VoiceStar](https://github.com/jasonppy/VoiceStar): robust, duration-controllable TTS that can extrapolate. +""" + +# Initialize model once outside the function for better performance +model = VoiceStar() + + +def generate_audio( + reference_speech, + text, + duration=0.0, + top_k=10, + temperature=1.0, + repeat_prompt=1, + seed=1, + progress=gr.Progress(), +): + # Set seed for reproducibility if provided + if seed > 0: + seed_everything(seed) + + # Update model parameters if needed + model.api.top_k = top_k + model.api.temperature = temperature + model.repeat_prompt = repeat_prompt + + # Generate speech + target_duration = None if duration <= 0 else duration + texts = txtsplit(text) + + audios = [] + for t in progress.tqdm(texts, desc=f"Generating audio in {len(texts)} chunks"): + audio = model.generate( + reference_speech=reference_speech, text=t, target_duration=target_duration + ) + audios.append(audio.waveform.squeeze().numpy()) + + audio = np.concatenate(audios) + + # Return audio for gradio + return (16000, audio) + + +with gr.Blocks() as demo: + gr.Markdown(ABOUT) + inp_ref = gr.Audio(label="Reference Audio", type="filepath") + inp_text = gr.Textbox(label="Text to synthesize") + + with gr.Accordion("Advanced Settings", open=False): + inp_reference_text = gr.Textbox( + label="Reference Text", + info="Enter a transcription of the reference audio. 
This is optional - if not provided, the model will transcribe the audio automatically.", + ) + inp_duration = gr.Number( + label="Duration", + info="Set to 0 to automatically estimate duration", + value=0.0, + ) + inp_top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=10) + inp_temp = gr.Slider( + label="Temperature", minimum=0.0, maximum=2.0, step=0.01, value=1.0 + ) + inp_repeat_prompt = gr.Slider( + label="Repeat prompt", minimum=1, maximum=10, step=1, value=1 + ) + inp_seed = gr.Number(label="Seed", info="Set to 0 to use random seed", value=1) + + btn_generate = gr.Button("Generate", variant="primary") + out_audio = gr.Audio(label="Generated Audio") + + btn_generate.click( + fn=generate_audio, + inputs=[ + inp_ref, + inp_text, + inp_duration, + inp_top_k, + inp_temp, + inp_repeat_prompt, + inp_seed, + ], + outputs=[out_audio], + ) + +demo.queue().launch() diff --git a/config.py b/config.py deleted file mode 100644 index 50d241b..0000000 --- a/config.py +++ /dev/null @@ -1,254 +0,0 @@ -import argparse - -def int_or_str(value): - """Custom function to allow both int and str types.""" - try: - return int(value) # Try converting to integer - except ValueError: - return value # If conversion fails, return as string - -def MyParser(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - # general training - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--debug", type=int, default=0) - parser.add_argument("--multinodes", type=int, default=0) - parser.add_argument("--dist_url", default="env://", type=str) - parser.add_argument("--dist_backend", default="nccl", type=str) - parser.add_argument("--precision", type=str, default="float16", help="we might need float32 for NAR model") - parser.add_argument("--num_workers", type=int, default=8, help="per gpu") - parser.add_argument("--resume", action="store_true", default=False) - parser.add_argument("--tb_write_every_n_steps", type=int, default=100) - parser.add_argument("--print_every_n_steps", type=int, default=250) - parser.add_argument("--val_every_n_steps", type=int, default=500) - parser.add_argument("--inference_every_n_steps", type=int, default=3000, help="will only get to inference when model is saved, and therefore this needs to be multiple of val_every_n_steps") - parser.add_argument("--save_every_n_steps", type=int, default=10000000, help="save the model every n steps, will save the model as bundle_step$step.pth") - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--batch_size", type=int, default=100, help="this is the effective batch size per gpu, no matter whether using gradient_accumulation_steps") - parser.add_argument("--weight_decay", type=float, default=1e-2) - parser.add_argument("--warmup_fraction", type=float, default=0.1, help="use linear warmup, the proportion of the training steps that are used for warming up") - parser.add_argument("--num_epochs", type=int, default=10) - parser.add_argument("--num_steps", type=int, default=None, help="if not None, will ignore n_epochs and use num_steps as the total number of amount of training, can try e.g. 400000 i.e. 
400k steps") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) - parser.add_argument("--gradient_clip_val", type=float, default=1.0, help="the value for torch.nn.utils.clip_grad_norm_()") - parser.add_argument("--early_stop_step", type=int, default=3200, help="stop training after this many steps of non-improvement") - parser.add_argument("--early_stop_threshold", type=float, default=-1.0, help="early stop after the improvement is below this threshold for certain number of steps") - - - # path - parser.add_argument("--exp_dir", type=str, default='/saltpool0/scratch/pyp/VoiceEditor/', help="will be combined with dataset name") - parser.add_argument("--dataset", type=str, help="e.g. 'libritts', 'librilight', 'spotify', they are folder name in the data dir also") - parser.add_argument("--dataset_dir", type=str, help="need to be compatible with corresponding dataset py file") - parser.add_argument("--compact_folder_name", type=str, default=None, help="if not None, will use compact_combined_dataset.py, and this is the folder name of the compact dataset") - parser.add_argument("--inference_dataset_dir", type=str, default="/data/scratch/pyp/datasets/librilight/preprocessed", help="need to be compatible with corresponding dataset py file") - - parser.add_argument("--training_stage", type=int, default=1, help="if 1, train VoiceEditor_one, if 2 train VoiceEditor_seven") - parser.add_argument("--local_wandb", type=int, default=0, help="if 1, will use local wandb, otherwise use the global one") - parser.add_argument("--wandb_entity", type=str, default="puyuanpeng", help="the entity (usually your username) for wandb") - # data - parser.add_argument("--librilight_ratio", type=float, default=1, help='the portion of lightlight compared to gigaspeech, 1 means equal, 2 means librilight data is twice as much as gigaspeech') - parser.add_argument("--plus_librilight_root", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the root folder to librilight. Note that will need to merge the vocab.txt based on gigaspeech's, in order to be able to load a pretrained model") - parser.add_argument("--plus_librilight_phn_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the phoneme folder name of librilight") - parser.add_argument("--plus_librilight_encodec_folder_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the encodec folder name of librilight") - parser.add_argument("--plus_librilight_manifest_name", type=str, default=None, help="if not None, will combine gigaspeech and librilight, this is the manifest folder name of librilight") - parser.add_argument("--skip_us", type=int, default=0, help="skip the giga utterances that contains 'j uː ɛ s' because of the tokenization issue") - parser.add_argument("--pseudo_epoch_size", type=int, default=37901, help="only use for Eden scheduler. 
37901 is the epoch size in the default optim setting, this is probably too big") - parser.add_argument("--switch_order", type=int, default=0, help="this is only for hificodec, where we switch the order of 2 and 3nd codebook") - parser.add_argument("--phn_folder_name", type=str, default="phoneme", help="for libritts I also have arpa phns, in which case should be phonemes_arpa") - parser.add_argument("--encodec_folder_name", type=str, default="mimi_8cb", help="folder where encodec codes are stored") - parser.add_argument("--manifest_name", type=str, default="manifest_final", help="if using hificodec, it should be hificodec_menifest, if using encodec, it is the default") - parser.add_argument("--pad_x", type=int, default=1, help="whether or not always pad x to have text_max_length. select 1 to get the maximal memory consumption, but the actual case should be smaller, better to have it being 0") - parser.add_argument("--max_num_tokens", type=int, default=18750, help="max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu") - parser.add_argument("--val_max_num_tokens", type=int, default=6000, help="FOR validation, this basically is for music-gen because of high mem consumption. max number of encodec tokens per gpu, this is only used when using dynamic batching, will ignore batch size. Note that batch size is the final effective batch size (sum of batch on each gpu), but max_num_tokens is per gpu") - parser.add_argument("--num_buckets", type=int, default=10) - parser.add_argument("--dynamic_batching", type=int, default=1) - parser.add_argument("--audio_max_length", type=float, default=120, help="in second, crop the audio is length is longer than this") - parser.add_argument("--audio_min_length", type=float, default=2, help="in second, drop the audio if length is shorter than this") - parser.add_argument("--text_max_length", type=int, default=1000, help='if too long, we crop') - parser.add_argument("--text_min_length", type=float, default=10, help="if too short, will drop") - parser.add_argument("--encodec_sr", type=float, default=50, help="for 24kHz mimi model, it produces 12.5 codes for 1 sec of audio") - parser.add_argument('--mask_len_min', type=int, default=20, help='Minimum mask length') - parser.add_argument('--mask_len_max', type=int, default=400, help='Maximum mask length') - parser.add_argument('--extra_mask_len_min', type=int, default=2, help='Minimum extra mask length') - parser.add_argument('--extra_mask_len_max', type=int, default=20, help='Maximum extra mask length') - parser.add_argument('--final_audio_token_len', type=int, default=772, help="this is only for stage 1 training, since we add eog, start_of_continue, and a random amount of extra mask, --audio_max_length won't be the final max length, the self.args.final_audio_token_len = self.args.audio_max_length*self.args.encodec_sr+self.args.extra_mask_len_max+2 ") - - # model - parser.add_argument("--ttsonly", default=0, type=int, help="if 1, only train tts model, no CM3") - parser.add_argument("--load_existing_text_embedding", type=int, default=0, help="if 1, when load model and the text vocab doesn't match, will load the existing weights while the new weights will be initialized randomly") - parser.add_argument("--fly", type=int, default=0, help="if 1, encode chunked audio on the fly") - parser.add_argument("--encodec_ckpt", type=str, 
default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th") - parser.add_argument("--downsample_rate", type=int, default=320, help="the downsample rate for the encodec model, 16000/320 = 50Hz") - parser.add_argument("--segtts_mask", type=int, default=0, help="if 1, use segtts_mask model, where we have a prefix and segment utterance into two and shifted separately for modeling, and use make use of mask:0, by insert two mask:0 in the middle of the two segments") - parser.add_argument("--segtts", type=int, default=0, help="if 1, use segtts model, where we have a prefix and segment utterance into two and shifted separately for modeling") - parser.add_argument("--edge", type=int, default=0, help="if 1, use edge prediction for the first codebook") - parser.add_argument("--duration_loss_weight", type=float, default=1.0, help="weight on the duration loss") - parser.add_argument("--drop_long", type=int, default=1, help="if this is true, will drop example whose encodec sequence or phone sequence is too long, rather than cropping as we did before, to avoid hellucination") - parser.add_argument("--eos", type=int, default=2051, help="this is to be used with reduced_eog, where we end the utterance with eos, and end the generated segment with eog, also when this is used, the n_special should be 4") - parser.add_argument("--reduced_eog", type=int, default=1, help="for the non-final segments, do not insert eog at the end, this could hopefully solve the early stopping issue when doing tts") - - parser.add_argument("--valle_orig", type=int, default=0, help="the original valle model, trained for TTS") - parser.add_argument("--valle_max_prompt_len", type=float, default=6, help='in sec.') - # randomly choose a portion as tts examples during training - parser.add_argument("--tts_portion", type=float, default=0, help="randomly choose a portion of the training examples as tts examples, where no mask and rearrangement is used") - - # put special tokens first to handle different vocab_size - parser.add_argument("--special_first", type=int, default=0, help="if 1, need to have special tokens to be the first few tokens, e.g. 0, 1, 2, which means we need to adjust the preprocessing and postprocessing of the encodec codes. note that we hard coded to have 3 special tokens") - parser.add_argument("--n_special", type=int, default=4, help="empty, eog, pad, eos") - - # weight codebook differently - parser.add_argument("--codebook_weight", type=str, default=None, help="e.g. ['5','1','0.5','0.1']") - - # args for MusicGen - parser.add_argument("--mask_span_weight", default=1.0, type=float, help="the weight on the tokens in masked span") - parser.add_argument("--unmask_span_weight", default=1.0, type=float, help="the weight on unmasked span") - parser.add_argument("--start_end_weight", default=None, type=str, help="weight the start x tokens and end x tokens differently, e.g. (10,2.0), means x == 10, weight==2.0") - # for now not consider the two weights above, only consider eog_weight, which is defined below somewhere, as the above two are not super principled - - parser.add_argument("--musicgen", type=int, default=0, help="whether or not use this model, will also have an impact on the output shape of the dataset") - parser.add_argument("--enc_dec", default=0, type=int, help="use enc-dec architecture, text is from the enc, only for musicgen") - parser.add_argument("--dec", default=0, type=int, help="use dec only architecture, text is from the enc, only for musicgen. 
Exclusive with --enc_dec") - parser.add_argument("--empty_token", default=2048, type=int, help="indicating the no token at the position for the codebook") - # args for the optimizer and scheduler from Feiteng - # original setup for the 3 params are 5000 4 and 1000 - # but that's because set_epoch is run on num_gradient_accumulation_step*step (with 4 being the accumulation step) - # so I scaled down them a little bit - # will try scaling them back if this doesn't work - parser.add_argument("--optimizer_name", type=str, default="AdamW", help="can also use ScaledAdam, in which case we'll also use the Eden scheduler") - parser.add_argument("--reduce_lr_start_step", type=int, default=3000, help='after which significantly reduce the lr. a param for the eden optimizer') - parser.add_argument("--reduce_lr_start_epoch", type=int, default=4) - parser.add_argument("--clipping_update_period", type=int, default=600) - - - # below are args for valle - # below are args for valle - parser.add_argument("--valle", type=int, default=0, help="if 1, use valle model (cm3)") - parser.add_argument("--decoder_dim", type=int, default=1024) - parser.add_argument("--norm_first", action="store_true", default=True) - parser.add_argument("--add_prenet", action="store_true", default=False) - parser.add_argument("--prefix_mode", type=int, default=5, help="this is for NAR, we only do 5, which is CM3") - parser.add_argument("--share_embedding", action="store_true", default=False) - parser.add_argument("--nar_scale_factor", type=float, default=1.0) - parser.add_argument("--prepend_bos", action="store_true", default=False) - parser.add_argument("--sync_nar", type=int, default=0, help="whether to choose the same NAR model to run for training_stage==2 across different process (this is only for DDP)") - # above are args for valle - # above are args for valle - - - - # add parallel_pattern - parser.add_argument("--parallel_pattern", type=int, default=0, help="if 1, use parallel pattern, we also use LFSC codec") - parser.add_argument("--full_prediction", type=int, default=0, help='this is for ve1, if 1, use full autoregressive mask, and calculate loss over all tokens, except for mask_tokens') - parser.add_argument("--multicm3", type=int, default=0, help='cm3 model but allows multiple mask spans') - parser.add_argument("--max_mask_portion",type=float,default=0.7,help="should mask a utterance for more than this portion") - parser.add_argument("--max_n_spans", type=int, default=8, help='maximal number of spans, only use when using multicm3, this is used to decide number of mask_embedding, and max clamp value if use Poisson distribution, if use uniform distribution to sample number of spans if will be uniform(1,max_n_spans)') - parser.add_argument("--shuffle_mask_embedding", type=int, default=0, help="whether shuffle the mask embedding, so that mask:0 is not the most well trained, default is not shuffling. The default has it's benefit, as it make sure that mask:0 always appear the first") - parser.add_argument("--mask_sample_dist", type=str, default="uniform", help="uniform or poissonx, e.g. 
poisson1, meaning the parameter lambda is 1, it will most likely sample 1 masks") - parser.add_argument("--min_gap", type=int, default=10, help="after sampled starts, delete later one if it closer to the former start than the min_gap") - - parser.add_argument('--cm3', type=int, default=0, help="use cm3 style for ve1, the input from dataloader is going to be just raw data, all masking and rearrangement will happen whin the model") - - parser.add_argument('--sep_special_token', type=int, default=0, help="remove text/audio pad token, set audio_mask_token and start of continue to be separately learned embeddings. Therefore, for ve1 self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size + 2, for ve7, self.n_text_tokens == self.args.text_vocab_size, self.n_audio_tokens == self.args.audio_vocab_size") - parser.add_argument('--one_causal', type=int, default=0, help="whether model VE_one generation as autoregressive gen or non-autoregressive gen") - parser.add_argument('--n_codebooks', type=int, default=8) - parser.add_argument('--weight_sharing', type=int, default=0, help="sharing weights between VE_seven predict layer and embedding layer") - parser.add_argument('--text_vocab_size', type=int, default=86, help='Size of text vocabulary') - parser.add_argument('--text_pad_token', type=int, default=86, help='padding of the text tokens, not attended') - # parser.add_argument('--audio_vocab_size', type=int, default=1024, help='Size of audio vocabulary') - parser.add_argument('--audio_vocab_size', type=str, default='2048', help="Size of audio vocabulary, can be specified as '[128,512,1024,2048]'") - parser.add_argument('--audio_mask_token', type=int, default=1024, help='Audio mask token, this the the extra mask used in the masked region for AR, for NAR, the entire masked region will be filled with it') - parser.add_argument('--bog', type=int, default=1025, help='Begin of generation token') - parser.add_argument('--eog', type=int, default=2049, help='End of generation token') - parser.add_argument('--start_of_continue', type=int, default=1027, help='this token follows the masked region, proceeds the first unmasked token, to indicate that gt tokens starts') - parser.add_argument('--audio_pad_token', type=int, default=2050, help='padding of the encodec codes, not attended') - parser.add_argument('--d_model', type=int, default=1024, help='Model dimension') - parser.add_argument('--audio_embedding_dim', type=int, default=128, help='dimension for encodec continues embedding (before being quantized)') - parser.add_argument('--text_embedding_dropout', type=float, default=0.1, help='Dropout for text embedding') - parser.add_argument('--audio_embedding_dropout', type=float, default=0, help='Dropout for audio embedding') - parser.add_argument('--text_positional_embedding_dropout', type=float, default=0.1, help='Dropout for text positional embedding') - parser.add_argument('--audio_positional_embedding_dropout', type=float, default=0.1, help='Dropout for audio positional embedding') - parser.add_argument('--trm_dropout', type=float, default=0.1, help='Dropout for transformer') - parser.add_argument('--nhead', type=int, default=16, help='Number of attention heads') - parser.add_argument('--num_encoder_layers', type=int, default=12, help='Number of encoder layers') - parser.add_argument('--num_decoder_layers', type=int, default=12, help='Number of decoder layers') - parser.add_argument('--eog_weight', type=float, default=1.0, help='Weight for End of generation token') - 
parser.add_argument('--stage_one_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage one. On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks') - parser.add_argument('--stage_two_load_encodec_embedding', type=str, default=None, help='Path to load encodec embedding for stage two。 On our lab machine it is /saltpool0/scratch/pyp/VoiceEditor/encodec_embedding/24khz_8codebooks.pth, 8 is the n_codebooks') - parser.add_argument('--stage_two_load_ve_one_embedding', type=str, default=None, help='Path to load VoiceEditor_one audio embedding for stage two') - parser.add_argument('--load_model_from', type=str, default=None, help='Path to load model from, this will be effective last, so will overwrite all previous load, including resume') - parser.add_argument('--load_model_from_ve1', type=str, default=None, help='Path to load ve1 model weights from, this will be effective last, designed for loading the encoder weights of the VE7 from a pretrained VE1') - - - ## below are args for the new long model - parser.add_argument("--target_time_stretch_prob", type=float, default=0, help="the probability of time stretching the target audio") - parser.add_argument("--target_time_stretch_bound", type=float, default=0.1, help="the bound of the time stretching target audio, e.g. 0.1 means the audio will be stretched by 0.9 to 1.1") - parser.add_argument("--time_stretch_prob", type=float, default=0, help="the probability of time stretching the audio") - parser.add_argument("--time_stretch_bound", type=float, default=0.3, help="the bound of the time stretching, e.g. 0.3 means the audio will be stretched by 0.7 to 1.3") - parser.add_argument("--no_loss_on_prefix", type=int, default=0, help="if 1, will not calculate loss on the prefix acoustic tokens") - parser.add_argument("--x_sep_token", type=int, default=None, help="if not None, will use this token in between prompt text and target generation text") - parser.add_argument("--y_sep_token", type=int, default=None, help="if not None, will use this token in between prompt codec tokens and target codec tokens") - parser.add_argument("--neighbor_prompt_prob", type=float, default=0, help="the probability of using the prompt from the neighbor") - parser.add_argument("--neighbor_folder_name", type=str, default='neighbors',help="folder where the neighbors of the current audio files are stored, each row contains three tab separated entries: neighbor_fn, neighbor_temporal_distance, neighbor_duration") - parser.add_argument("--alignment_folder_name", type=str, default='alignment', help="folder where the forced alignment of the current audio files are stored, in csv format, each row contains five comma separated entries: begin, end, label, type, speaker, the first row is header") - parser.add_argument("--ipa_alignment_folder_name", type=str, default='ipa_alignment', help="folder where the forced alignment of the current audio files are stored, in txt format, each row contains three tab separated entries: begin, end, ipa phn sequence, generated using data/ll60k_preprocessing/step7_ipa_alignment.py") - parser.add_argument("--max_prompt_len", type=float, default=30, help="in sec., maximal prompt length selected from some neighboring file") - parser.add_argument("--min_prompt_len", type=float, default=0.5, help="in sec., minimal prompt length selected from some neighboring file") - parser.add_argument("--neighbor_selection_method", type=str, default="maxdist_60", 
help="maxdist_60 means uniformly select a neighbor that's within 60 sec of the current audio file") - parser.add_argument("--num_trial", type=int, default=5, help="number of tries to select a neighbor") - parser.add_argument("--prompt_start_from_begining_prob", type=float, default=0.5, help="the probability of starting the prompt from the beginning of the neighbor") - parser.add_argument("--min_alignment_len", type=int, default=5, help="in number of words") - parser.add_argument("--audio_folder_name", type=str, default='audio', help="folder where the audio files are stored") - - # rope parameters - parser.add_argument("--decoder_regular_rope", type=int, default=0, help="if 1, will use regular rope for the decoder (note that we always use regular rope for encoder). ") - parser.add_argument("--progress_no_multiple", type=int, default=0, help="if 1, will not multiple the percentage progress by the length of the key, see apply_rotary_pos_emb in models/modules/activation.py, this applies to both rope and sinusoidal positional encoding. Note that progress scale is still applied, i.e. when we only apply progress scale, but not multiple, the scaling factor is constant for every sample, rather than sample dependent") - parser.add_argument("--add_eos_to_text", type=int, default=0, help="if not 0, use this number as eos and add to the end of text token, usually use the second to last token in the vocab size") - parser.add_argument("--add_bos_to_text", type=int, default=0, help="if not 0, use this number as bos and add to the begining of text token, usually use the third to last token in the vocab size") - parser.add_argument("--use_sinusoidal", type=int, default=0, help="if 1, will use sinusoidal positional encoding, otherwise use rope. BUT if rope_base is None, will use sinusoidal") - parser.add_argument("--sinusoidal_base", type=int, default=1e4, help="the base of the exponential function, default is 1e4") - parser.add_argument("--use_sinusoidal_progress", type=int, default=0, help="if 1, will use sinusoidal positional encoding for progress, otherwise use rope") - parser.add_argument("--rope_base", type=int, default=None, help="the base of the exponential function, default is 1e4, if None, will not use rope") - parser.add_argument("--multiple_key_length", type=int, default=0, help="if 1, during progress calculation, will multiple the precentage progress by the length of the key, otherwise multiple with length of query. see models/rope_playground.ipynb") - parser.add_argument("--progress_scale", type=float, default=1.0, help="scale the progress, the smaller the value, the bigger the diagonal in attention score, see models/rope_playground.ipynb") - - # attention alignment loss - parser.add_argument("--attention_alignment_loss", type=float, default=0.0, help="the weight on the attention alignment loss, if 0, will not calculate the loss") - parser.add_argument("--alignment_loss_layer", type=str, default="['0-1', '2', '3']", help='the layers to calculate the alignment loss, e.g. ["0-1", "2", "3"]') - parser.add_argument("--alignment_loss_head", type=str, default="['0-1', '2', '3']", help='the attention heads to calculate the alignment loss, e.g. 
["0-1", "2", "3"]') - parser.add_argument("--alignment_blank_logit", type=float, default=-1.0, help="the logit for the blank token added to the attention weights") - - # inference parameters - parser.add_argument("--metrics", type=str, default="['spk_sim','wer','mcd','pitch','energy','pesq','utmos']") - parser.add_argument("--res_jsonl_root", type=str, default="/home/pyp/BoostedVoiceEditor/res") - parser.add_argument("--res_name", type=str, default="2jan25.jsonl") - parser.add_argument("--inference_seed", type=int, default=1) - parser.add_argument("--codec_audio_sr", type=int, default=16000) - parser.add_argument("--codec_sr", type=float, default=50) - parser.add_argument("--top_k", type=int, default=0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--temperature", type=float, default=1) - parser.add_argument("--silence_tokens", type=list, default=[]) - parser.add_argument("--kvcache", type=int, default=0) - parser.add_argument("--stop_repetition", type=int, default=3) - parser.add_argument("--sample_batch_size", type=int, default=1) - parser.add_argument("--inference_manifest_fns", type=str, default="['/home/pyp/BoostedVoiceEditor/manifests/debug.jsonl']") - parser.add_argument("--use_gt_duration", type=int, default=1) - parser.add_argument("--save_root", type=str, default="/data/scratch/pyp/exp_pyp/BoostedVoiceEditor/gens") - parser.add_argument("--encodec_signature", type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th") - parser.add_argument("--extra_cutoff", type=float, default=5, help="in rare cases where the model doesn't follow specified target duration (only happened in extrapolation cases), we will terminate generation once the extra duration exceeds this value") - parser.add_argument("--duration_margin", type=float, default=0.04, help="used along with extra_cutoff, when extra_cutoff is used (i.e. model doesn't follow specified target_duration), we terminate the generate, and cut the results to target_duration + duration_margin") - # add repeat_prompt and asr_model_name - parser.add_argument("--repeat_prompt", type=int_or_str, default=0, help="if 1, will repeat the prompt for each segment") - parser.add_argument("--asr_model_name", type=str, default="w2v2", help="the name of the asr model, if not None, will use the asr model to generate the prompt") - - # depth transformer parameters - parser.add_argument("--depth_dec_num_layers", type=int, default=0) - parser.add_argument("--depth_dec_d_model", type=int, default=768) - parser.add_argument("--depth_dec_nhead", type=int, default=12) - parser.add_argument("--moshi_depth", type=int, default=0, help="if 1, will use the same parameterization as moshi, i.e. 
temporal trm output will gets added to every transformed token embedding") - - parser.add_argument("--validation_sample_cap", type=int, default=None, help="cap the validation data to this number") - parser.add_argument("--no_libri_in_training", type=int, default=None, help="if 1, will not use librilight in training, only use in validation") - parser.add_argument("--uniform_weight_start_step", type=int, default=1e50, help="set all codebook weight to be uniform starting from this step") - - return parser \ No newline at end of file diff --git a/data/combined_dataset.py b/data/combined_dataset.py index b7925cf..70c12ea 100644 --- a/data/combined_dataset.py +++ b/data/combined_dataset.py @@ -11,6 +11,8 @@ import glob import numpy as np from data.tokenizer import TextTokenizer, tokenize_text, AudioTokenizer + + def find_files(root_dir, endswith=".wav"): files = [] # os.walk generates the file names in a directory tree @@ -24,16 +26,25 @@ def find_files(root_dir, endswith=".wav"): files.append(full_path) return files + class dataset(torch.utils.data.Dataset): def __init__(self, args, split): super().__init__() self.args = args - self.args.target_time_stretch_prob = getattr(self.args, "target_time_stretch_prob", 0) - self.args.target_time_stretch_bound = getattr(self.args, "target_time_stretch_bound", 0.1) + self.args.target_time_stretch_prob = getattr( + self.args, "target_time_stretch_prob", 0 + ) + self.args.target_time_stretch_bound = getattr( + self.args, "target_time_stretch_bound", 0.1 + ) self.split = split - - assert self.split in ['train', 'valid', 'test'], f"split should be one of ['train', 'valid', 'test'], but it's {split}" - + + assert self.split in [ + "train", + "valid", + "test", + ], f"split should be one of ['train', 'valid', 'test'], but it's {split}" + if "[" not in self.args.dataset_dir or "]" not in self.args.dataset_dir: self.dataset_dir = f"['{self.args.dataset_dir}']" else: @@ -46,29 +57,45 @@ def __init__(self, args, split): self.args.manifest_name = copy.deepcopy(self.args.manifest_name) self.manifest_name = eval(self.args.manifest_name) if len(self.manifest_name) != len(self.dataset_dir): - assert len(self.manifest_name) == 1, f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}" + assert ( + len(self.manifest_name) == 1 + ), f"len(self.manifest_name) should be 1 or equal to len(self.dataset_dir), but it's {len(self.manifest_name)}" self.manifest_name = self.manifest_name * len(self.dataset_dir) for i_data, dataset_dir in enumerate(self.dataset_dir): - if getattr(self.args, "no_libri_in_training", None) != None and ("librilight" in dataset_dir) and self.split == "train": + if ( + getattr(self.args, "no_libri_in_training", None) != None + and ("librilight" in dataset_dir) + and self.split == "train" + ): if not dist.is_initialized() or dist.get_rank() == 0: logging.info(f"skipping librilight in training split") continue n_datapoints = 0 - manifest_fn = os.path.join(dataset_dir, self.manifest_name[i_data], self.split+".txt") + manifest_fn = os.path.join( + dataset_dir, self.manifest_name[i_data], self.split + ".txt" + ) if not os.path.isfile(manifest_fn): all_manifest_fn = glob.glob(manifest_fn.replace(".txt", "_*=*.txt")) if len(all_manifest_fn) == 0: - logging.info(f"no manifest file found for {split} split in {dataset_dir}") + logging.info( + f"no manifest file found for {split} split in {dataset_dir}" + ) continue if self.args.debug: - logging.info(f"debugging mode, only using the frist found manifest file: 
{all_manifest_fn[0]}") + logging.info( + f"debugging mode, only using the frist found manifest file: {all_manifest_fn[0]}" + ) all_manifest_fn = all_manifest_fn[:1] else: if dist.is_initialized() and dist.get_rank() == 0: - logging.info(f"Combining found manifest files for {split}: {all_manifest_fn}") + logging.info( + f"Combining found manifest files for {split}: {all_manifest_fn}" + ) for cur_manifest_fn in all_manifest_fn: with open(cur_manifest_fn, "r") as rf: - tmp = [l.strip().split("\t") + [i_data] for l in rf.readlines()] # i_data is the index of the dataset + tmp = [ + l.strip().split("\t") + [i_data] for l in rf.readlines() + ] # i_data is the index of the dataset n_datapoints += len(tmp) data += tmp else: @@ -77,15 +104,22 @@ def __init__(self, args, split): data += tmp n_datapoints += len(tmp) if dist.is_initialized() and dist.get_rank() == 0: - logging.info(f"number of data points for {split} split in {dataset_dir}: {n_datapoints}") + logging.info( + f"number of data points for {split} split in {dataset_dir}: {n_datapoints}" + ) assert len(data) > 0, f"no data found for {split} split" - lengths_list = [int(item[1]) for item in data] # use 1 because there might be more than 1 columns (for gigaspeech we have 3 columns: path, duration, selfsim) + lengths_list = [ + int(item[1]) for item in data + ] # use 1 because there might be more than 1 columns (for gigaspeech we have 3 columns: path, duration, selfsim) self.data = [] self.lengths_list = [] total_duration = 0 for d, l in zip(data, lengths_list): - if l >= self.args.encodec_sr*self.args.audio_min_length: - if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length: + if l >= self.args.encodec_sr * self.args.audio_min_length: + if ( + self.args.drop_long + and l > self.args.encodec_sr * self.args.audio_max_length + ): continue self.data.append(d) self.lengths_list.append(l) @@ -94,8 +128,12 @@ def __init__(self, args, split): # self.data = self.data[:1000] # self.lengths_list = self.lengths_list[:1000] if dist.is_initialized() and dist.get_rank() == 0: - logging.info(f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}") - logging.info(f"TOTAL duration for {self.split} split: {total_duration:.1f} hours") + logging.info( + f"TOTAL number of data points for {self.split} split: {len(self.lengths_list)}" + ) + logging.info( + f"TOTAL duration for {self.split} split: {total_duration:.1f} hours" + ) # phoneme vocabulary phn_set = set() for dataset_dir in self.dataset_dir: @@ -103,25 +141,41 @@ def __init__(self, args, split): with open(vocab_fn, "r") as f: temp = [l.strip().split("\t") for l in f.readlines() if len(l) != 0] phn_set.update([item[-1] for item in temp]) - self.phn2num = {item:i for i, item in enumerate(phn_set)} - assert self.args.text_vocab_size > len(self.phn2num), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}" + self.phn2num = {item: i for i, item in enumerate(phn_set)} + assert self.args.text_vocab_size > len( + self.phn2num + ), f"need self.args.text_vocab_size to be bigger than number of phns in vocab to handle OOD phn, but the former is {self.args.text_vocab_size} while the latter is {len(self.phn2num)}" - if (self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0) or self.args.target_time_stretch_prob > 0: + if ( + self.args.neighbor_prompt_prob > 0 and self.args.time_stretch_prob > 0 + ) or 
self.args.target_time_stretch_prob > 0: userdir = os.path.expanduser("~") - encodec_signature = getattr(self.args, "encodec_signature", os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th")) - self.audio_tokenizer = AudioTokenizer(signature=encodec_signature, device=torch.device("cpu"), encode_only=True) - assert self.audio_tokenizer.sample_rate == self.args.codec_audio_sr, f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.encodec_sr: {self.args.encodec_sr}" + encodec_signature = getattr( + self.args, + "encodec_signature", + os.path.join(userdir, "VoiceStar", "pretrained", "encodec_6f79c6a8.th"), + ) + self.audio_tokenizer = AudioTokenizer( + signature=encodec_signature, + device=torch.device("cpu"), + encode_only=True, + ) + assert ( + self.audio_tokenizer.sample_rate == self.args.codec_audio_sr + ), f"audio_tokenizer.sample_rate: {self.audio_tokenizer.sample_rate}, self.args.encodec_sr: {self.args.encodec_sr}" if dist.is_initialized() and dist.get_rank() == 0: - logging.info(f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}") - + logging.info( + f"rank: {dist.get_rank()}, audio_tokenizer device: {self.audio_tokenizer._device}" + ) + def __len__(self): return len(self.lengths_list) - + def _load_phn_enc(self, index): item = self.data[index] dataset_dir = self.dataset_dir[item[-1]] - pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0]+".txt") - ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0]+".txt") + pf = os.path.join(dataset_dir, self.args.phn_folder_name, item[0] + ".txt") + ef = os.path.join(dataset_dir, self.args.encodec_folder_name, item[0] + ".txt") # with certain probability, we load the audio, and time stretch it, note that we should not hit self.args.audio_max_length if "/librilight" in dataset_dir: audio_ext = ".flac" @@ -129,25 +183,54 @@ def _load_phn_enc(self, index): audio_ext = ".mp3" else: raise NotImplementedError(f"dataset_dir: {dataset_dir}") - - audio_fn = os.path.join(dataset_dir, self.args.audio_folder_name, item[0].replace(".txt", "")+audio_ext) - speed_factor = random.uniform(-self.args.target_time_stretch_bound, self.args.target_time_stretch_bound) + 1 - length_ok = (float(item[1]) / self.args.encodec_sr) / speed_factor < self.args.audio_max_length # NOTE to calculate the maximal duration after time stretching, we should be used as orig/(1-bound), rather than orig*(1+bound) - if self.args.target_time_stretch_prob > 0 and random.random() < self.args.target_time_stretch_prob and os.path.isfile(audio_fn) and length_ok: + + audio_fn = os.path.join( + dataset_dir, + self.args.audio_folder_name, + item[0].replace(".txt", "") + audio_ext, + ) + speed_factor = ( + random.uniform( + -self.args.target_time_stretch_bound, + self.args.target_time_stretch_bound, + ) + + 1 + ) + length_ok = ( + float(item[1]) / self.args.encodec_sr + ) / speed_factor < self.args.audio_max_length # NOTE to calculate the maximal duration after time stretching, we should be used as orig/(1-bound), rather than orig*(1+bound) + if ( + self.args.target_time_stretch_prob > 0 + and random.random() < self.args.target_time_stretch_prob + and os.path.isfile(audio_fn) + and length_ok + ): try: with open(pf, "r") as p: phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") - x = [self.phn2num[item] for item in all_phns if item in self.phn2num] + x = [ + self.phn2num[item] for item in all_phns if item in self.phn2num + ] except: - 
logging.info(f"loading failed for {pf}, maybe files don't exist or are corrupted") + logging.info( + f"loading failed for {pf}, maybe files don't exist or are corrupted" + ) return [], [[]], dataset_dir, audio_ext # time stretch try: process = ( - ffmpeg.input(audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr) - .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) + ffmpeg.input( + audio_fn, ss=0, t=float(item[1]) / self.args.encodec_sr + ) + .output( + "pipe:1", + format="f32le", + ac=1, + ar=self.audio_tokenizer.sample_rate, + filter="atempo={}".format(speed_factor), + ) .run_async(pipe_stdout=True, pipe_stderr=True) ) # Read the processed audio from ffmpeg stdout @@ -159,11 +242,19 @@ def _load_phn_enc(self, index): # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = waveform.unsqueeze(0).unsqueeze(0) - assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape + assert ( + waveform.ndim == 3 + and waveform.shape[0] == 1 + and waveform.shape[1] == 1 + ), waveform.shape with torch.no_grad(): - encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) - assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" - encos = encos.cpu().squeeze(0).numpy().tolist() # [K, T] + encos = self.audio_tokenizer.encode( + waveform.to(self.audio_tokenizer._device) + ) + assert ( + encos.shape[1] == self.args.n_codebooks + ), f"encos.shape: {encos.shape}" + encos = encos.cpu().squeeze(0).numpy().tolist() # [K, T] if self.args.special_first: raise NotImplementedError # y = [[int(n)+self.args.n_special for n in l] for l in encos] @@ -171,7 +262,9 @@ def _load_phn_enc(self, index): y = [[int(n) for n in l] for l in encos] return x, y, dataset_dir, audio_ext except Exception as e: - logging.info(f"failed with time stretch and codec encode for {audio_fn}") + logging.info( + f"failed with time stretch and codec encode for {audio_fn}" + ) logging.info(f"error: {e}") pass @@ -180,9 +273,15 @@ def _load_phn_enc(self, index): phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") - x = [self.phn2num[item] for item in all_phns if item in self.phn2num] # we assume that OOD will not happen, because phn vocab is small - encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks] - + x = [ + self.phn2num[item] for item in all_phns if item in self.phn2num + ] # we assume that OOD will not happen, because phn vocab is small + encos = [ + l.strip().split() + for k, l in enumerate(e.readlines()) + if k < self.args.n_codebooks + ] + assert len(encos) == self.args.n_codebooks, ef if self.args.special_first: @@ -191,7 +290,9 @@ def _load_phn_enc(self, index): else: y = [[int(n) for n in l] for l in encos] except: - logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted") + logging.info( + f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted" + ) return [], [[]], dataset_dir, audio_ext return x, y, dataset_dir, audio_ext @@ -199,16 +300,29 @@ def _load_phn_enc(self, index): # this uses the output of step7_ipa_alignment.py def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): neighbor = random.choice(neighbors) - neighbor_enc_fn = os.path.join(dataset_dir, self.args.encodec_folder_name, neighbor[0]) + neighbor_enc_fn = os.path.join( + dataset_dir, 
self.args.encodec_folder_name, neighbor[0] + ) if not os.path.isfile(neighbor_enc_fn): return None, None - neighbor_audio_path = os.path.join(dataset_dir, self.args.audio_folder_name, neighbor[0].replace(".txt", audio_ext)) - if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile(neighbor_audio_path): + neighbor_audio_path = os.path.join( + dataset_dir, + self.args.audio_folder_name, + neighbor[0].replace(".txt", audio_ext), + ) + if getattr(self.args, "time_stretch_prob", 0) > 0 and not os.path.isfile( + neighbor_audio_path + ): logging.info(f"audio file not found: {neighbor_audio_path}") return None, None if random.random() < getattr(self.args, "time_stretch_prob", 0): time_stretch_flag = True - speed_factor = random.uniform(-self.args.time_stretch_bound, self.args.time_stretch_bound) + 1 + speed_factor = ( + random.uniform( + -self.args.time_stretch_bound, self.args.time_stretch_bound + ) + + 1 + ) duration_factor = 1 / speed_factor else: time_stretch_flag = False @@ -221,17 +335,27 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): if "/emilia" in dataset_dir: # get neighbor duration neighbor_dur = float(neighbor[2]) - if neighbor_dur * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length or neighbor_dur * duration_factor < self.args.min_prompt_len: + if ( + neighbor_dur * duration_factor + y_len / self.args.encodec_sr + > self.args.audio_max_length + or neighbor_dur * duration_factor < self.args.min_prompt_len + ): return None, None try: - neighbor_pf = os.path.join(dataset_dir, self.args.phn_folder_name, neighbor[0]) + neighbor_pf = os.path.join( + dataset_dir, self.args.phn_folder_name, neighbor[0] + ) with open(neighbor_pf, "r") as p: phns = [l.strip() for l in p.readlines()] assert len(phns) == 1, phns all_phns = phns[0].split(" ") - phn_token = [self.phn2num[item] for item in all_phns if item in self.phn2num] + phn_token = [ + self.phn2num[item] for item in all_phns if item in self.phn2num + ] except: - logging.info(f"loading failed for {neighbor_pf}, maybe files don't exist") + logging.info( + f"loading failed for {neighbor_pf}, maybe files don't exist" + ) return None, None # if do not stretch the audio if not time_stretch_flag: @@ -239,21 +363,27 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): neighbor_enc = [l.strip().split() for l in f.readlines()] if len(neighbor_enc) != self.args.n_codebooks: return None, None - # if too long + # if too long else: if self.args.special_first: raise NotImplementedError # neighbor_enc = [[int(n)+self.args.n_special for n in l] for l in neighbor_enc] else: neighbor_enc = [[int(n) for n in l] for l in neighbor_enc] - + return phn_token, neighbor_enc - else: # stretch the audio with ffmpeg-python + else: # stretch the audio with ffmpeg-python process = ( ffmpeg.input(neighbor_audio_path, ss=0, t=neighbor_dur) - .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) - .run_async(pipe_stdout=True, pipe_stderr=True) + .output( + "pipe:1", + format="f32le", + ac=1, + ar=self.audio_tokenizer.sample_rate, + filter="atempo={}".format(speed_factor), ) + .run_async(pipe_stdout=True, pipe_stderr=True) + ) # Read the processed audio from ffmpeg stdout output, _ = process.communicate() @@ -263,34 +393,58 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = 
waveform.unsqueeze(0).unsqueeze(0) - assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1, waveform.shape + assert ( + waveform.ndim == 3 + and waveform.shape[0] == 1 + and waveform.shape[1] == 1 + ), waveform.shape with torch.no_grad(): - encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) - assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" - neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] + encos = self.audio_tokenizer.encode( + waveform.to(self.audio_tokenizer._device) + ) + assert ( + encos.shape[1] == self.args.n_codebooks + ), f"encos.shape: {encos.shape}" + neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] return phn_token, neighbor_enc ####################### TODO for now always use the entire neighbor for emilia ####################### TODO for now always use the entire neighbor for emilia - ipa_alignment_fn = os.path.join(dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0]) + ipa_alignment_fn = os.path.join( + dataset_dir, self.args.ipa_alignment_folder_name, neighbor[0] + ) if not os.path.isfile(ipa_alignment_fn): # print(f"file not found: {ipa_alignment_fn}", flush=True) return None, None with open(ipa_alignment_fn, "r") as f: alignments = [l.strip().split("\t") for l in f.readlines()] - alignments = [[float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3] - alignments = [l for l in alignments if self.args.min_prompt_len < (l[1] - l[0]) * duration_factor < self.args.max_prompt_len] + alignments = [ + [float(l[0]), float(l[1]), l[2]] for l in alignments if len(l) == 3 + ] + alignments = [ + l + for l in alignments + if self.args.min_prompt_len + < (l[1] - l[0]) * duration_factor + < self.args.max_prompt_len + ] if len(alignments) == 0: # print(f"no valid alignment found for {ipa_alignment_fn}") return None, None idx = random.choice(range(len(alignments))) - while (alignments[idx][1] - alignments[idx][0]) * duration_factor + y_len / self.args.encodec_sr > self.args.audio_max_length: + while ( + (alignments[idx][1] - alignments[idx][0]) * duration_factor + + y_len / self.args.encodec_sr + > self.args.audio_max_length + ): idx -= 1 if idx < 0: # print(f"too long combined with y_len {ipa_alignment_fn=}, and {y_len=}") return None, None - if (alignments[idx][1] - alignments[idx][0]) * duration_factor < self.args.min_prompt_len: + if ( + alignments[idx][1] - alignments[idx][0] + ) * duration_factor < self.args.min_prompt_len: return None, None - + start_time, end_time = alignments[idx][:2] phn = alignments[idx][2].split(" ") phn_token = [self.phn2num[item] for item in phn if item in self.phn2num] @@ -301,7 +455,13 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): duration = end_time - start_time process = ( ffmpeg.input(neighbor_audio_path, ss=start_time, t=duration) - .output('pipe:1', format='f32le', ac=1, ar=self.audio_tokenizer.sample_rate, filter='atempo={}'.format(speed_factor)) + .output( + "pipe:1", + format="f32le", + ac=1, + ar=self.audio_tokenizer.sample_rate, + filter="atempo={}".format(speed_factor), + ) .run_async(pipe_stdout=True, pipe_stderr=True) ) # Read the processed audio from ffmpeg stdout @@ -313,15 +473,23 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): # Reshape the numpy array back to the expected shape (1, samples for mono) waveform = torch.from_numpy(output_np) waveform = waveform.unsqueeze(0).unsqueeze(0) - assert waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 
1, waveform.shape + assert ( + waveform.ndim == 3 and waveform.shape[0] == 1 and waveform.shape[1] == 1 + ), waveform.shape try: with torch.no_grad(): - encos = self.audio_tokenizer.encode(waveform.to(self.audio_tokenizer._device)) + encos = self.audio_tokenizer.encode( + waveform.to(self.audio_tokenizer._device) + ) except: - logging.info(f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds") + logging.info( + f"failed with time stretch for {neighbor_audio_path}, from {start_time} to {end_time} with duration factor {duration_factor}, which leads to {duration*duration_factor} seconds" + ) return None, None - assert encos.shape[1] == self.args.n_codebooks, f"encos.shape: {encos.shape}" - neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] + assert ( + encos.shape[1] == self.args.n_codebooks + ), f"encos.shape: {encos.shape}" + neighbor_enc = encos.cpu().squeeze(0).numpy().tolist() # [K, T] return phn_token, neighbor_enc else: # get encodec codes from storage @@ -334,7 +502,9 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): # trim the encodec codes to the segment start_enc_frame = int(start_time * self.args.encodec_sr) end_enc_frame = int(end_time * self.args.encodec_sr) - neighbor_enc = [l[start_enc_frame:end_enc_frame] for l in neighbor_enc] + neighbor_enc = [ + l[start_enc_frame:end_enc_frame] for l in neighbor_enc + ] if len(neighbor_enc[0]) == 0: # print(f"no valid encodec codes found for {neighbor_enc_fn}") return None, None @@ -347,27 +517,33 @@ def find_neighbor(self, neighbors, y_len, dataset_dir, audio_ext): def __getitem__(self, index): x, y, dataset_dir, audio_ext = self._load_phn_enc(index) x_len, y_len = len(x), len(y[0]) - extra_ret = {'x_sep_token_position': 0, 'y_sep_token_position': 0} + extra_ret = {"x_sep_token_position": 0, "y_sep_token_position": 0} if x_len == 0 or y_len == 0: ret = { - "x": None, - "x_len": None, - "y": None, + "x": None, + "x_len": None, + "y": None, "y_len": None, - } + } ret.update(extra_ret) return ret - while y_len < self.args.encodec_sr*self.args.audio_min_length: + while y_len < self.args.encodec_sr * self.args.audio_min_length: assert not self.args.dynamic_batching - index = random.choice(range(len(self))) # regenerate an index + index = random.choice(range(len(self))) # regenerate an index x, y, dataset_dir, audio_ext = self._load_phn_enc(index) x_len, y_len = len(x), len(y[0]) - + # if use neighbor prompt x_neighbor, y_neighbor = None, None use_neighbor_prob = random.random() - neighbor_fn = os.path.join(dataset_dir, self.args.neighbor_folder_name, self.data[index][0]+".txt") - if self.args.neighbor_prompt_prob > 0 and use_neighbor_prob < self.args.neighbor_prompt_prob and os.path.isfile(neighbor_fn): # it might not exist, just because we didn't find neighbor for this file (other than itself, which is common for emilia) + neighbor_fn = os.path.join( + dataset_dir, self.args.neighbor_folder_name, self.data[index][0] + ".txt" + ) + if ( + self.args.neighbor_prompt_prob > 0 + and use_neighbor_prob < self.args.neighbor_prompt_prob + and os.path.isfile(neighbor_fn) + ): # it might not exist, just because we didn't find neighbor for this file (other than itself, which is common for emilia) with open(neighbor_fn, "r") as f: neighbors = [l.strip().split("\t") for l in f.readlines()] # select neighbors @@ -379,26 +555,38 @@ def __getitem__(self, index): raise NotImplementedError x_neighbor, 
y_neighbor = None, None if len(neighbors) > 0: - x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext) + x_neighbor, y_neighbor = self.find_neighbor( + neighbors, y_len, dataset_dir, audio_ext + ) i_trial = 0 - while x_neighbor is None and i_trial < self.args.num_trial and i_trial < len(neighbors): - x_neighbor, y_neighbor = self.find_neighbor(neighbors, y_len, dataset_dir, audio_ext) + while ( + x_neighbor is None + and i_trial < self.args.num_trial + and i_trial < len(neighbors) + ): + x_neighbor, y_neighbor = self.find_neighbor( + neighbors, y_len, dataset_dir, audio_ext + ) i_trial += 1 - + if x_neighbor != None: if self.args.x_sep_token != None: x = x_neighbor + [self.args.x_sep_token] + x else: x = x_neighbor + x if self.args.y_sep_token != None: - y = [y_neighbor[i] + [self.args.y_sep_token] + y[i] for i in range(len(y))] + y = [ + y_neighbor[i] + [self.args.y_sep_token] + y[i] + for i in range(len(y)) + ] else: y = [y_neighbor[i] + y[i] for i in range(len(y))] - extra_ret['y_sep_token_position'] = len(y_neighbor[0]) + 1 # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in loss computation, it's fine that we use the position of the token right before it - extra_ret['x_sep_token_position'] = len(x_neighbor) + 1 + extra_ret["y_sep_token_position"] = ( + len(y_neighbor[0]) + 1 + ) # if using y_sep_token, this is actually the position of the token right before the y_sep_token, but since y_sep_token is ignored in loss computation, it's fine that we use the position of the token right before it + extra_ret["x_sep_token_position"] = len(x_neighbor) + 1 x_len, y_len = len(x), len(y[0]) - # consider adding eos to the end of the text if self.args.add_eos_to_text != 0: x.append(self.args.add_eos_to_text) @@ -411,37 +599,46 @@ def __getitem__(self, index): # adjust the length of encodec codes, pad to max_len or randomly crop orig_y_len = copy.copy(y_len) max_len = int(self.args.audio_max_length * self.args.encodec_sr) - if y_len > max_len + 10: # give it some margin for rounding error + if y_len > max_len + 10: # give it some margin for rounding error raise RuntimeError(f"audio is too long, {y_len=}, {max_len=}") else: audio_start = 0 if not self.args.dynamic_batching: - pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len) + pad = ( + [0] * (max_len - y_len) + if self.args.sep_special_token + else [self.args.audio_pad_token] * (max_len - y_len) + ) for i in range(len(y)): y[i] = y[i] + pad - + if self.args.pad_x and x_len <= self.args.text_max_length: - pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len) + pad = ( + [0] * (self.args.text_max_length - x_len) + if self.args.sep_special_token + else [self.args.text_pad_token] * (self.args.text_max_length - x_len) + ) x = x + pad - + ret = { - "x": torch.LongTensor(x), - "x_len": x_len, - "y": torch.LongTensor(y), + "x": torch.LongTensor(x), + "x_len": x_len, + "y": torch.LongTensor(y), "y_len": y_len, - } + } ret.update(extra_ret) return ret - def collate(self, batch): # make sure keys in every batch is the same for batch1, batch2 in zip(batch[:-1], batch[1:]): - assert set(batch1.keys()) == set(batch2.keys()), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different" - out = {key:[] for key in batch[0]} + assert set(batch1.keys()) == set( + batch2.keys() 
+ ), f"keys in batch1: {batch1.keys()} and keys in batch2: {batch2.keys()} are different" + out = {key: [] for key in batch[0]} for item in batch: - if item['x'] == None: # deal with load failure + if item["x"] == None: # deal with load failure continue for key, val in item.items(): out[key].append(val) @@ -449,18 +646,27 @@ def collate(self, batch): if self.args.pad_x: res["x"] = torch.stack(out["x"], dim=0) else: - res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token) + res["x"] = torch.nn.utils.rnn.pad_sequence( + out["x"], batch_first=True, padding_value=self.args.text_pad_token + ) res["x_lens"] = torch.LongTensor(out["x_len"]) if self.args.dynamic_batching: - res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token) - res['y'] = res['y'].permute(1,2,0) # T B K -> B K T + res["y"] = torch.nn.utils.rnn.pad_sequence( + [item.transpose(1, 0) for item in out["y"]], + padding_value=self.args.audio_pad_token, + ) + res["y"] = res["y"].permute(1, 2, 0) # T B K -> B K T else: - res['y'] = torch.stack(out['y'], dim=0) + res["y"] = torch.stack(out["y"], dim=0) res["y_lens"] = torch.LongTensor(out["y_len"]) - res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1) - res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1) + res["text_padding_mask"] = torch.arange(res["x"][0].shape[-1]).unsqueeze( + 0 + ) >= res["x_lens"].unsqueeze(1) + res["audio_padding_mask"] = torch.arange(res["y"][0].shape[-1]).unsqueeze( + 0 + ) >= res["y_lens"].unsqueeze(1) if "y_sep_token_position" in out: res["y_sep_token_position"] = torch.LongTensor(out["y_sep_token_position"]) if "x_sep_token_position" in out: res["x_sep_token_position"] = torch.LongTensor(out["x_sep_token_position"]) - return res \ No newline at end of file + return res diff --git a/data/emilia_preprocessing/encodec.py b/data/emilia_preprocessing/encodec.py index be65978..5c7579d 100644 --- a/data/emilia_preprocessing/encodec.py +++ b/data/emilia_preprocessing/encodec.py @@ -25,10 +25,13 @@ import warnings from einops import rearrange, repeat import omegaconf + # import flashy -CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', - 'time_group_norm']) +CONV_NORMALIZATIONS = frozenset( + ["none", "weight_norm", "spectral_norm", "time_group_norm"] +) + def dict_from_config(cfg: omegaconf.DictConfig) -> dict: """Convenience function to map an omegaconf configuration to a dictionary. @@ -42,6 +45,7 @@ def dict_from_config(cfg: omegaconf.DictConfig) -> dict: assert isinstance(dct, dict) return dct + @dataclass class QuantizedResult: x: torch.Tensor @@ -50,9 +54,9 @@ class QuantizedResult: penalty: tp.Optional[torch.Tensor] = None metrics: dict = field(default_factory=dict) + class BaseQuantizer(nn.Module): - """Base class for quantizers. - """ + """Base class for quantizers.""" def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: """ @@ -85,17 +89,19 @@ def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" raise NotImplementedError() + class CompressionModel(ABC, nn.Module): """Base API for all compression model that aim at being used as audio tokenizers with a language model. """ @abstractmethod - def forward(self, x: torch.Tensor) -> QuantizedResult: - ... + def forward(self, x: torch.Tensor) -> QuantizedResult: ... 
@abstractmethod - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """See `EncodecModel.encode`.""" ... @@ -111,44 +117,39 @@ def decode_latent(self, codes: torch.Tensor): @property @abstractmethod - def channels(self) -> int: - ... + def channels(self) -> int: ... @property @abstractmethod - def frame_rate(self) -> float: - ... + def frame_rate(self) -> float: ... @property @abstractmethod - def sample_rate(self) -> int: - ... + def sample_rate(self) -> int: ... @property @abstractmethod - def cardinality(self) -> int: - ... + def cardinality(self) -> int: ... @property @abstractmethod - def num_codebooks(self) -> int: - ... + def num_codebooks(self) -> int: ... @property @abstractmethod - def total_codebooks(self) -> int: - ... + def total_codebooks(self) -> int: ... @abstractmethod def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer.""" ... -def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): + +def apply_parametrization_norm(module: nn.Module, norm: str = "none"): assert norm in CONV_NORMALIZATIONS - if norm == 'weight_norm': + if norm == "weight_norm": return weight_norm(module) - elif norm == 'spectral_norm': + elif norm == "spectral_norm": return spectral_norm(module) else: # We already check was in CONV_NORMALIZATION, so any other choice @@ -156,12 +157,14 @@ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): return module -def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +): """Return the proper normalization module. If causal is True, this will ensure the returned module is causal, or return an error if the normalization doesn't support causal evaluation. """ assert norm in CONV_NORMALIZATIONS - if norm == 'time_group_norm': + if norm == "time_group_norm": if causal: raise ValueError("GroupNorm doesn't support causal evaluation.") assert isinstance(module, nn.modules.conv._ConvNd) @@ -170,8 +173,9 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', return nn.Identity() -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 @@ -179,7 +183,9 @@ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, return ideal_length - length -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): """Pad for a convolution to make sure that the last window is full. Extra padding is added at the end. 
This is required to ensure that we can rebuild an output of the same length, as otherwise, even with padding, some time steps @@ -194,14 +200,19 @@ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total return F.pad(x, (0, extra_padding)) -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "constant", + value: float = 0.0, +): """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. """ length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': + if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: @@ -220,15 +231,22 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right - return x[..., padding_left: end] + return x[..., padding_left:end] class NormConv1d(nn.Module): """Wrapper around Conv1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) @@ -244,7 +262,14 @@ class NormConv2d(nn.Module): """Wrapper around Conv2d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) @@ -260,10 +285,19 @@ class NormConvTranspose1d(nn.Module): """Wrapper around ConvTranspose1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) self.norm_type = norm @@ -277,9 +311,18 @@ class NormConvTranspose2d(nn.Module): """Wrapper around ConvTranspose2d and normalization applied to this conv to provide a uniform interface across normalization approaches. 
""" - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) def forward(self, x): @@ -292,19 +335,40 @@ class StreamableConv1d(nn.Module): """Conv1d with some builtin handling of asymmetric or causal padding and normalization. """ - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, dilation: int = 1, - groups: int = 1, bias: bool = True, causal: bool = False, - norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, - pad_mode: str = 'reflect'): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): super().__init__() # warn user on unusual setup between dilation and stride if stride > 1 and dilation > 1: - warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" - f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") - self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, - dilation=dilation, groups=groups, bias=bias, causal=causal, - norm=norm, norm_kwargs=norm_kwargs) + warnings.warn( + "StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.pad_mode = pad_mode @@ -313,9 +377,13 @@ def forward(self, x): kernel_size = self.conv.conv.kernel_size[0] stride = self.conv.conv.stride[0] dilation = self.conv.conv.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + kernel_size = ( + kernel_size - 1 + ) * dilation + 1 # effective kernel size with dilations padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) if self.causal: # Left padding for causal x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) @@ -323,7 +391,9 @@ def forward(self, x): # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) return self.conv(x) @@ -331,18 +401,34 @@ class StreamableConvTranspose1d(nn.Module): """ConvTranspose1d with some builtin handling of asymmetric or causal padding and normalization. 
""" - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, causal: bool = False, - norm: str = 'none', trim_right_ratio: float = 1., - norm_kwargs: tp.Dict[str, tp.Any] = {}): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): super().__init__() - self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, - causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.trim_right_ratio = trim_right_ratio - assert self.causal or self.trim_right_ratio == 1., \ - "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" - assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 def forward(self, x): kernel_size = self.convtr.convtr.kernel_size[0] @@ -373,6 +459,7 @@ class StreamableLSTM(nn.Module): """LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout. """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): super().__init__() self.skip = skip @@ -404,12 +491,25 @@ class SEANetResnetBlock(nn.Module): true_skip (bool): Whether to use true skip connection or a simple (streamable) convolution as the skip connection. """ - def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], - activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, - pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): super().__init__() - assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" act = getattr(nn, activation) hidden = dim // compress block = [] @@ -418,17 +518,31 @@ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp. 
out_chs = dim if i == len(kernel_sizes) - 1 else hidden block += [ act(**activation_params), - StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, - norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.block = nn.Sequential(*block) self.shortcut: nn.Module if true_skip: self.shortcut = nn.Identity() else: - self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode) + self.shortcut = StreamableConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) def forward(self, x): return self.shortcut(x) + self.block(x) @@ -462,12 +576,29 @@ class SEANetEncoder(nn.Module): disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. For the encoder, it corresponds to the N first blocks. """ - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + ): super().__init__() self.channels = channels self.dimension = dimension @@ -478,36 +609,61 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = 1 model: tp.List[nn.Module] = [ - StreamableConv1d(channels, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + channels, + mult * n_filters, + kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] # Downsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm + block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - norm=block_norm, norm_params=norm_params, - activation=activation, activation_params=activation_params, - causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=block_norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] # Add downsampling layers model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, mult * n_filters * 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] mult *= 2 @@ -516,9 +672,17 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, dimension, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.model = nn.Sequential(*model) @@ -557,13 +721,32 @@ class SEANetDecoder(nn.Module): trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. If equal to 1.0, it means that all the trimming is done at the right. 
""" - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + trim_right_ratio: float = 1.0, + ): super().__init__() self.dimension = dimension self.channels = channels @@ -574,16 +757,28 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = int(2 ** len(self.ratios)) model: tp.List[nn.Module] = [ - StreamableConv1d(dimension, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] if lstm: @@ -591,40 +786,63 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, # Upsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm + block_norm = ( + "none" + if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) + else norm + ) # Add upsampling layers model += [ act(**activation_params), - StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, trim_right_ratio=trim_right_ratio), + StreamableConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), ] # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - activation=activation, activation_params=activation_params, - norm=block_norm, norm_params=norm_params, causal=causal, - pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=block_norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] mult //= 2 # Add final layers model += [ act(**activation_params), - StreamableConv1d(n_filters, channels, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + n_filters, + channels, + last_kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] # Add optional final activation to decoder (eg. 
tanh) if final_activation is not None: final_act = getattr(nn, final_activation) final_activation_params = final_activation_params or {} - model += [ - final_act(**final_activation_params) - ] + model += [final_act(**final_activation_params)] self.model = nn.Sequential(*model) def forward(self, z): @@ -675,10 +893,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): means = sample_vectors(samples, num_clusters) for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) @@ -700,7 +916,7 @@ def orthogonal_loss_fn(t): normed_codes = l2norm(t) identity = torch.eye(n, device=t.device) cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + return ((cosine_sim - identity) ** 2).sum() / (n**2) class EuclideanCodebook(nn.Module): @@ -719,6 +935,7 @@ class EuclideanCodebook(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ + def __init__( self, dim: int, @@ -731,7 +948,9 @@ def __init__( ): super().__init__() self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size @@ -862,6 +1081,7 @@ class VectorQuantization(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. 
""" + def __init__( self, dim: int, @@ -873,7 +1093,7 @@ def __init__( kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, channels_last: bool = False, - commitment_weight: float = 1., + commitment_weight: float = 1.0, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: tp.Optional[int] = None, @@ -882,8 +1102,12 @@ def __init__( _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) self.epsilon = epsilon self.commitment_weight = commitment_weight @@ -892,10 +1116,15 @@ def __init__( self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) self.codebook_size = codebook_size self.channels_last = channels_last @@ -956,8 +1185,13 @@ def forward(self, x): codebook = codebook[unique_code_ids] num_codes = codebook.shape[0] - if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] + if ( + exists(self.orthogonal_reg_max_codes) + and num_codes > self.orthogonal_reg_max_codes + ): + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] codebook = codebook[rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -974,17 +1208,20 @@ class ResidualVectorQuantization(nn.Module): Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__(self, *, num_quantizers, **kwargs): super().__init__() - codebook_size = kwargs.pop('codebook_size', None) + codebook_size = kwargs.pop("codebook_size", None) if codebook_size is None: raise ValueError("codebook_size must be provided in kwargs") if type(codebook_size) != list: codebook_size = [codebook_size] * num_quantizers self.layers = nn.ModuleList( - [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)] + [ + VectorQuantization(codebook_size=cur_codebook_size, **kwargs) + for _, cur_codebook_size in zip(range(num_quantizers), codebook_size) + ] ) - # self.layers = nn.ModuleList( # [VectorQuantization(**kwargs) for _ in range(num_quantizers)] @@ -1058,6 +1295,7 @@ class ResidualVectorQuantizer(BaseQuantizer): orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. for orthogonal regularization. 
""" + def __init__( self, dimension: int = 256, @@ -1096,7 +1334,7 @@ def __init__( orthogonal_reg_weight=self.orthogonal_reg_weight, orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, - channels_last=False + channels_last=False, ) def forward(self, x: torch.Tensor, frame_rate: int): @@ -1144,15 +1382,18 @@ def set_num_codebooks(self, n: int): assert n > 0 and n <= self.max_n_q self.n_q = n + class DummyQuantizer(BaseQuantizer): - """Fake quantizer that actually does not perform any quantization. - """ + """Fake quantizer that actually does not perform any quantization.""" + def __init__(self): super().__init__() def forward(self, x: torch.Tensor, frame_rate: int): q = x.unsqueeze(1) - return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) + return QuantizedResult( + x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x) + ) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode a given input tensor with the specified sample rate at the given bandwidth. @@ -1180,7 +1421,9 @@ def num_codebooks(self): def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" - raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") + raise AttributeError( + "Cannot override the number of codebooks for the dummy quantizer" + ) class EncodecModel(CompressionModel): @@ -1196,21 +1439,24 @@ class EncodecModel(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - decoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.decoder = decoder @@ -1223,7 +1469,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1244,7 +1490,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1256,9 +1504,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1285,7 +1533,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1323,6 +1573,7 @@ def decode_latent(self, codes: torch.Tensor): """Decode from the discrete codes to continuous latent space.""" return self.quantizer.decode(codes) + class EncodecModel_encode_only(CompressionModel): """Encodec model operating on the raw waveform. Encode only, so no decoder @@ -1335,20 +1586,23 @@ class EncodecModel_encode_only(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.quantizer = quantizer @@ -1360,7 +1614,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1381,7 +1635,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1393,9 +1649,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1422,7 +1678,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1462,21 +1720,22 @@ def decode_latent(self, codes: torch.Tensor): raise NotImplementedError("Decode is not supported for encode only model") return self.quantizer.decode(codes) -def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer: - klass = { - 'no_quant': DummyQuantizer, - 'rvq': ResidualVectorQuantizer - }[quantizer] + +def get_quantizer( + quantizer: str, cfg: omegaconf.DictConfig, dimension: int +) -> BaseQuantizer: + klass = {"no_quant": DummyQuantizer, "rvq": ResidualVectorQuantizer}[quantizer] kwargs = dict_from_config(getattr(cfg, quantizer)) - if quantizer != 'no_quant': - kwargs['dimension'] = dimension + if quantizer != "no_quant": + kwargs["dimension"] = dimension return klass(**kwargs) + def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): - if encoder_name == 'seanet': - kwargs = dict_from_config(getattr(cfg, 'seanet')) - encoder_override_kwargs = kwargs.pop('encoder') - decoder_override_kwargs = kwargs.pop('decoder') + if encoder_name == "seanet": + kwargs = dict_from_config(getattr(cfg, "seanet")) + encoder_override_kwargs = kwargs.pop("encoder") + decoder_override_kwargs = kwargs.pop("decoder") encoder_kwargs = {**kwargs, **encoder_override_kwargs} decoder_kwargs = {**kwargs, **decoder_override_kwargs} encoder = SEANetEncoder(**encoder_kwargs) @@ -1489,56 +1748,93 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel: """Instantiate a compression model.""" if device == None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - state = torch.load(ckpt_fn, map_location='cpu') - cfg = state['xp.cfg'] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + state = torch.load(ckpt_fn, map_location="cpu") + cfg = state["xp.cfg"] cfg.device = str(device) - weights = state['best_state']['model'] - assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now." 
+ weights = state["best_state"]["model"] + assert ( + cfg.compression_model == "encodec" + ), "Only Encodec model is supported for now." if encode_only: all_keys = list(weights.keys()) for key in all_keys: - if key.startswith('decoder'): + if key.startswith("decoder"): del weights[key] - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, _ = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel_encode_only(encoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel_encode_only( + encoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model else: - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel(encoder, decoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel( + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + renormalize=renormalize, + **kwargs, + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model + if __name__ == "__main__": import torchaudio + ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th" - audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"] - audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", 
"/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + audio_in_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley.wav", + ] + audio_out_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav", + ] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) model = get_compression_model(ckpt_fn, device=device) for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns): @@ -1551,4 +1847,4 @@ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> Compressi audio_in = audio_in.to(torch.float32).to(device) codes = model.encode(audio_in)[0] audio_out = model.decode(codes)[0].cpu() - torchaudio.save(audio_out_fn, audio_out, model.sample_rate) \ No newline at end of file + torchaudio.save(audio_out_fn, audio_out, model.sample_rate) diff --git a/data/emilia_preprocessing/sha256hash.py b/data/emilia_preprocessing/sha256hash.py index 42802a2..3e43ad1 100644 --- a/data/emilia_preprocessing/sha256hash.py +++ b/data/emilia_preprocessing/sha256hash.py @@ -1,6 +1,7 @@ import hashlib import sys + def sha256_hash_file(filename): sha256_hash = hashlib.sha256() with open(filename, "rb") as file: @@ -9,6 +10,7 @@ def sha256_hash_file(filename): sha256_hash.update(byte_block) return sha256_hash.hexdigest() + # Usage example filename = sys.argv[1] -print(sha256_hash_file(filename)) \ No newline at end of file +print(sha256_hash_file(filename)) diff --git a/data/emilia_preprocessing/step1_download.py b/data/emilia_preprocessing/step1_download.py index 601b4de..258f294 100644 --- a/data/emilia_preprocessing/step1_download.py +++ b/data/emilia_preprocessing/step1_download.py @@ -1,9 +1,19 @@ # conda activate emilia from datasets import load_dataset import fire -def main(root: str="/data/scratch/pyp/datasets/emilia"): + + +def main(root: str = "/data/scratch/pyp/datasets/emilia"): path = "EN/*.tar.gz*" - dataset = load_dataset("amphion/Emilia-Dataset", data_files={"en": path}, split="en", streaming=False, revision="fc71e07e8572f5f3be1dbd02ed3172a4d298f152", cache_dir=root) + dataset = load_dataset( + "amphion/Emilia-Dataset", + data_files={"en": path}, + split="en", + streaming=False, + revision="fc71e07e8572f5f3be1dbd02ed3172a4d298f152", + cache_dir=root, + ) + if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + fire.Fire(main) diff --git a/data/emilia_preprocessing/step4_construct_manifest.py b/data/emilia_preprocessing/step4_construct_manifest.py index b97d868..aaae973 100644 --- a/data/emilia_preprocessing/step4_construct_manifest.py +++ b/data/emilia_preprocessing/step4_construct_manifest.py @@ -12,17 +12,22 @@ from multiprocessing import Pool import glob, os from collections import defaultdict + + def 
write_jsonl(data, fn): with open(fn, "w") as file: for entry in data: file.write(json.dumps(entry, ensure_ascii=False) + "\n") + + def read_jsonl(file_path): cur_data = [] - with open(file_path, 'r', encoding='utf-8-sig') as file: + with open(file_path, "r", encoding="utf-8-sig") as file: for line in file: cur_data.append(json.loads(line.strip())) return cur_data + def repetition_found(text, length=2, tolerance=10): pattern_count = defaultdict(int) for i in range(len(text) - length + 1): @@ -34,7 +39,6 @@ def repetition_found(text, length=2, tolerance=10): return False - out_en = { "EN_B00013_S00913", "EN_B00042_S00120", @@ -114,6 +118,7 @@ def repetition_found(text, length=2, tolerance=10): from multiprocessing import Pool + def process_meta_item(item, root, sub_root, audio_folder, audio_ext, text_ext): global filtered_duration, filtered_count, total_duration, total_count # Data filtering following Yushen's approach @@ -123,13 +128,15 @@ def process_meta_item(item, root, sub_root, audio_folder, audio_ext, text_ext): or repetition_found(item["text"], length=4) ): return None, item["duration"], 1, 0, 0, (None, None) # Return filtered results - + # Trim leading space from text if exists if item["text"].startswith(" "): item["text"] = item["text"][1:] - + # write text to text file - text_fn = os.path.join(root, sub_root, audio_folder, item["wav"].replace(audio_ext, text_ext)) + text_fn = os.path.join( + root, sub_root, audio_folder, item["wav"].replace(audio_ext, text_ext) + ) os.makedirs(os.path.dirname(text_fn), exist_ok=True) with open(text_fn, "w") as f: f.write(item["text"]) @@ -141,24 +148,29 @@ def process_meta_item(item, root, sub_root, audio_folder, audio_ext, text_ext): 0, item["duration"], 1, - (item['speaker'], item) + (item["speaker"], item), ) # Return processed results -def parallel_process_meta(meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext): +def parallel_process_meta( + meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext +): with Pool(num_workers) as pool: results = pool.starmap( process_meta_item, - [(item, root, sub_root, audio_folder, audio_ext, text_ext) for item in meta], + [ + (item, root, sub_root, audio_folder, audio_ext, text_ext) + for item in meta + ], ) - + processed_items = [] spkitem = [] filtered_duration = 0 filtered_count = 0 total_duration = 0 total_count = 0 - + for result in results: if result[0]: # If the item was processed processed_items.append(result[0]) @@ -167,8 +179,15 @@ def parallel_process_meta(meta, root, sub_root, audio_folder, num_workers, audio total_duration += result[3] total_count += result[4] spkitem.append(result[5]) - - return processed_items, filtered_duration, filtered_count, total_duration, total_count, spkitem + + return ( + processed_items, + filtered_duration, + filtered_count, + total_duration, + total_count, + spkitem, + ) def main( @@ -189,7 +208,7 @@ def main( ] print(f"found {len(all_fns)} untarred segments") print(f"{all_fns[:3]}") - + res = [] total_duration = 0 total_count = 0 @@ -200,12 +219,12 @@ def main( spk2info = defaultdict(list) metafn = os.path.join(root, "EN", os.path.basename(fn) + ".jsonl") meta = read_jsonl(metafn) - + # Parallel process metadata processed_items, fd, fc, td, tc, spkitem = parallel_process_meta( meta, root, sub_root, audio_folder, num_workers, audio_ext, text_ext ) - + # Aggregate results res.extend(processed_items) filtered_duration += fd @@ -216,7 +235,7 @@ def main( for spk, item in spkitem: if spk: spk2info[spk].append(item) - + # Save neighbor files 
for spk in spk2info: for item in spk2info[spk]: @@ -227,11 +246,15 @@ def main( item["wav"].replace(audio_ext, text_ext), ) os.makedirs(os.path.dirname(neighbor_fn), exist_ok=True) - tobe_write = [f"{neighbor_item['wav'].replace(audio_ext, text_ext)}\t0\t{neighbor_item['duration']}\n" for neighbor_item in spk2info[spk] if neighbor_item["wav"] != item["wav"]] + tobe_write = [ + f"{neighbor_item['wav'].replace(audio_ext, text_ext)}\t0\t{neighbor_item['duration']}\n" + for neighbor_item in spk2info[spk] + if neighbor_item["wav"] != item["wav"] + ] if tobe_write: with open(neighbor_fn, "w") as f: f.writelines(tobe_write) - + print( f"total duration: {total_duration / 3600:.2f} hours, total count: {total_count}" ) @@ -248,4 +271,4 @@ def main( if __name__ == "__main__": import fire - fire.Fire(main) \ No newline at end of file + fire.Fire(main) diff --git a/data/emilia_preprocessing/step5_phonemize.py b/data/emilia_preprocessing/step5_phonemize.py index bf0d046..ae8148b 100644 --- a/data/emilia_preprocessing/step5_phonemize.py +++ b/data/emilia_preprocessing/step5_phonemize.py @@ -6,9 +6,11 @@ from multiprocessing import Pool import glob, os, fire from collections import defaultdict + sys.path.insert(0, "../../") from data.tokenizer import TextTokenizer, tokenize_text + def write_jsonl(data, fn): with open(fn, "w") as file: for entry in data: @@ -17,7 +19,7 @@ def write_jsonl(data, fn): def read_jsonl(file_path): cur_data = [] - with open(file_path, 'r', encoding='utf-8-sig') as file: + with open(file_path, "r", encoding="utf-8-sig") as file: for line in file: cur_data.append(json.loads(line.strip())) return cur_data @@ -32,9 +34,21 @@ def phonemize_and_save(text, fn, text_tokenizer): return set(phn) -def process_item(item, root, sub_root, audio_folder, phn_folder, audio_ext, text_ext, phn_ext, text_tokenizer): +def process_item( + item, + root, + sub_root, + audio_folder, + phn_folder, + audio_ext, + text_ext, + phn_ext, + text_tokenizer, +): """Worker function to process a single item.""" - text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext)) + text_path = os.path.join( + root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext) + ) if not os.path.exists(text_path): return {"missing_text": text_path, "success": False, "cur_phn_set": set()} @@ -42,7 +56,9 @@ def process_item(item, root, sub_root, audio_folder, phn_folder, audio_ext, text text = [line.strip() for line in f.readlines()] text = " ".join(text) - phn_path = os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext)) + phn_path = os.path.join( + root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext) + ) cur_phn_set = phonemize_and_save(text, phn_path, text_tokenizer) return {"missing_text": None, "success": True, "cur_phn_set": cur_phn_set} @@ -51,6 +67,7 @@ def process_item_star(args): """Unpacks arguments for `process_item` to work with `imap`.""" return process_item(*args) + def main( root="/data/scratch/pyp/datasets/emilia", sub_root="preprocessed", @@ -73,7 +90,7 @@ def main( for fn in all_fns: with open(fn, "r") as f: data += [line.strip().split("\t") for line in f] - + vocab = set() ################## parallel processing ################## @@ -118,7 +135,9 @@ def main( ################## sequential processing ################## missing_text = [] for item in tqdm.tqdm(data): - text_path = os.path.join(root, sub_root, audio_folder, item[0].replace(audio_ext, text_ext)) + text_path = os.path.join( + root, sub_root, audio_folder, 
item[0].replace(audio_ext, text_ext) + ) if not os.path.exists(text_path): missing_text.append(text_path) continue @@ -129,7 +148,13 @@ def main( except: print(f"Error reading {text_path}") continue - cur_phn_set = phonemize_and_save(text, os.path.join(root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext)), text_tokenizer) + cur_phn_set = phonemize_and_save( + text, + os.path.join( + root, sub_root, phn_folder, item[0].replace(audio_ext, phn_ext) + ), + text_tokenizer, + ) vocab.update(cur_phn_set) ################## sequential processing ################## ################## sequential processing ################## @@ -145,14 +170,10 @@ def main( # Collect missing text paths print(f"Missing text files: {len(missing_text)}") if missing_text: - print("Some missing files:", missing_text[:10]) # Print the first 10 missing files as an example - + print( + "Some missing files:", missing_text[:10] + ) # Print the first 10 missing files as an example + if __name__ == "__main__": fire.Fire(main) - - - - - - diff --git a/data/emilia_preprocessing/step6_encodec_encode.py b/data/emilia_preprocessing/step6_encodec_encode.py index 952d422..a7e3037 100644 --- a/data/emilia_preprocessing/step6_encodec_encode.py +++ b/data/emilia_preprocessing/step6_encodec_encode.py @@ -1,27 +1,70 @@ import argparse from email.policy import default + + def parse_args(): parser = argparse.ArgumentParser(description="encode the dataset using codec model") - parser.add_argument('--root', type=str, default="/data/scratch/pyp/datasets/emilia", help="Path to the directory") - parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory") - parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model") - parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes") - parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus") - parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate') - parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate') - parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate') - parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate') - parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number') - parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number') - parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing') - parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test']) + parser.add_argument( + "--root", + type=str, + default="/data/scratch/pyp/datasets/emilia", + help="Path to the directory", + ) + parser.add_argument( + "--sub_root", type=str, default="preprocessed", help="sub directory" + ) + parser.add_argument( + "--encodec_name", + type=str, + default="encodec_6f79c6a8.th", + help="name of the codec model", + ) + parser.add_argument( + "--n_workers", type=int, default=16, help="Number of parallel worker processes" + ) + parser.add_argument( + "--batch_size", + type=int, + default=16, + help="batch size for codec encoding, decrease it if OOM. 
This is the sum of batch size *over each gpu*, so increase it if you are using more gpus", + ) + parser.add_argument( + "--audio_sr", type=int, default=16000, help="input audio sample rate" + ) + parser.add_argument( + "--model_sr", type=int, default=16000, help="encodec input audio sample rate" + ) + parser.add_argument( + "--downsample_rate", type=int, default=320, help="encodec downsample rate" + ) + parser.add_argument( + "--model_code_sr", type=float, default=50, help="codec model code sample rate" + ) + parser.add_argument( + "--len_cap", + type=float, + default=1000, + help="will drop audios that are longer than this number", + ) + parser.add_argument( + "--min_len", + type=float, + default=0.5, + help="will drop audios that are shorter than this number", + ) + parser.add_argument( + "--partition", type=str, default="1/1", help="split for parallel processing" + ) + parser.add_argument( + "--split", type=str, default="train", choices=["train", "valid", "test"] + ) return parser.parse_args() + if __name__ == "__main__": import logging - formatter = ( - "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" - ) + + formatter = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) import os, sys @@ -35,31 +78,43 @@ def parse_args(): def sort_by_audio_len(lens): inds = np.argsort(lens).tolist() - - logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.") - logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.") - logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.") - logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.") + + logging.info( + f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec." + ) + logging.info( + f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec." + ) + logging.info( + f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec." + ) + logging.info( + f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec." 
+ ) return inds[::-1] def write_array_to_txt_file(array, filename): - with open(filename, 'w') as f: + with open(filename, "w") as f: for a in array[:-1]: - f.write(' '.join(map(str, a))+'\n') - f.write(' '.join(map(str, array[-1]))) + f.write(" ".join(map(str, a)) + "\n") + f.write(" ".join(map(str, array[-1]))) class mydataset(torch.utils.data.Dataset): def __init__(self, split): super().__init__() self.split = split self.audio_dir = audio_dir - manifest_fn = os.path.join(encodec_manifest_dir, split+".txt") - cur_sp = int(args.partition.split("/")[0])-1 + manifest_fn = os.path.join(encodec_manifest_dir, split + ".txt") + cur_sp = int(args.partition.split("/")[0]) - 1 total_sp = int(args.partition.split("/")[1]) with open(manifest_fn, "r") as rf: - self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp] + self.data = [l.strip().split("\t") for l in rf.readlines()][ + cur_sp::total_sp + ] + def __len__(self): return len(self.data) + def __getitem__(self, ind): try: afn = self.data[ind][0] @@ -72,8 +127,13 @@ def __getitem__(self, ind): except Exception as e: # logging.info(f"{e}") return None, None, None - assert audio.ndim==2 and audio.shape[0] == 1, audio.shape - return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0] + assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape + return ( + audio.type(torch.float32).squeeze(0), + audio.shape[-1], + os.path.splitext(afn)[0], + ) + def collate(self, batch): lens, audios, segment_ids = [], [], [] for item in batch: @@ -82,49 +142,70 @@ def collate(self, batch): lens.append(item[1]) segment_ids.append(item[2]) return audios, lens, segment_ids - + # roots sub_root = args.sub_root encodec_manifest_dir = os.path.join(args.root, sub_root, "manifest_for_codec") audio_dir = os.path.join(args.root, sub_root, "audio") - save_manifest_dir = os.path.join(args.root, sub_root,"manifest_final_encodec") + save_manifest_dir = os.path.join(args.root, sub_root, "manifest_final_encodec") if args.encodec_name == "encodec_6f79c6a8.th": - save_codes_dir = os.path.join(args.root, sub_root,"encodec_4cb") + save_codes_dir = os.path.join(args.root, sub_root, "encodec_4cb") elif args.encodec_name == "encodec_8cb1024_giga.th": - save_codes_dir = os.path.join(args.root, sub_root,"encodec_8cb") + save_codes_dir = os.path.join(args.root, sub_root, "encodec_8cb") os.makedirs(save_manifest_dir, exist_ok=True) os.makedirs(save_codes_dir, exist_ok=True) - + def import_encodec(): from encodec import get_compression_model + userdir = os.path.expanduser("~") - model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda") + model = get_compression_model( + os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), + encode_only=True, + device="cuda", + ) model = torch.nn.DataParallel(model) return model + model = import_encodec() - + # setup dataloader mega_batch_size = 2048 batch_size = args.batch_size - + dataset = mydataset(args.split) if len(dataset) == 0: logging.info(f"no data found for split {args.split} partition {args.partition}") sys.exit(0) - loader = torch.torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate) + loader = torch.torch.utils.data.DataLoader( + dataset, + batch_size=mega_batch_size, + shuffle=False, + drop_last=False, + num_workers=args.n_workers, + collate_fn=dataset.collate, + ) split = args.split skip = 0 
logging.info(f"now processing split {split} partition {args.partition}...") mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size)) # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size)) - logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples") - mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt") - logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}") + logging.info( + f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples" + ) + mani_fn = os.path.join( + save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt" + ) + logging.info( + f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}" + ) with open(mani_fn, "w") as mani_wf: - # with open(mani_fn, "a") as mani_wf: # resume from where we failed - for m, mega_batch in enumerate(tqdm.tqdm(loader, mininterval=60, maxinterval=60)): + # with open(mani_fn, "a") as mani_wf: # resume from where we failed + for m, mega_batch in enumerate( + tqdm.tqdm(loader, mininterval=60, maxinterval=60) + ): logging.info(f"====================================") logging.info(f"====================================") @@ -134,44 +215,65 @@ def import_encodec(): lengths = np.array(mega_batch[1]) sorted_inds = sort_by_audio_len(lengths) for j in range(len(sorted_inds))[::-1]: - if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) + if ( + lengths[sorted_inds[j]] < args.model_sr * args.min_len + or lengths[sorted_inds[j]] > args.model_sr * args.len_cap + ): # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) skip += 1 del sorted_inds[j] - + n_steps = int(np.ceil(len(sorted_inds) / batch_size)) for n in tqdm.tqdm(range(n_steps), disable=True): - inds_used = sorted_inds[n*batch_size:(n+1)*batch_size] + inds_used = sorted_inds[n * batch_size : (n + 1) * batch_size] wav_batch = [mega_batch[0][id] for id in inds_used] all_lens = [mega_batch[1][id] for id in inds_used] segment_id_batch = [mega_batch[2][id] for id in inds_used] - padded_wav = torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] + padded_wav = torch.nn.utils.rnn.pad_sequence( + wav_batch, batch_first=True + ).unsqueeze( + 1 + ) # [B, T] -> [B, 1, T] # Extract discrete codes from EnCodec with torch.no_grad(): - if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time + if ( + max(all_lens) > 300000 and len(all_lens) > 1 + ): # if utterances are long, simply pass half of them at a time codes = [] inwav = padded_wav.cuda() - codes.append(model(inwav[:len(inwav)//2])[0].cpu()) - codes.append(model(inwav[len(inwav)//2:])[0].cpu()) + codes.append(model(inwav[: len(inwav) // 2])[0].cpu()) + codes.append(model(inwav[len(inwav) // 2 :])[0].cpu()) codes = torch.cat(codes, dim=0) else: - encoded_frames = model(padded_wav.cuda()) - codes = encoded_frames[0].cpu() # [B, n_codebook, T] + encoded_frames = model(padded_wav.cuda()) + codes = encoded_frames[0].cpu() # [B, n_codebook, T] for i, length in enumerate(all_lens): - save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt") - actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this 
model - cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() + save_fn = os.path.join( + save_codes_dir, segment_id_batch[i] + ".txt" + ) + actual_len = round( + length / args.downsample_rate + ) # 320 is downsample rate for this model + cur_code = ( + codes[i].tolist() + if type(codes) == list + else codes[i, :, :actual_len].tolist() + ) os.makedirs(os.path.dirname(save_fn), exist_ok=True) write_array_to_txt_file(cur_code, save_fn) - mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file + mani_wf.write( + f"{segment_id_batch[i]}\t{len(cur_code[0])}\n" + ) # write to manifest file # if i == 10: # raise except Exception as e: - print(f'exception!! at {m+1}') + print(f"exception!! at {m+1}") print(e) continue # break - logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short") - # break + logging.info( + f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short" + ) + # break diff --git a/data/encodec.py b/data/encodec.py index be65978..fc74ecc 100644 --- a/data/encodec.py +++ b/data/encodec.py @@ -25,10 +25,13 @@ import warnings from einops import rearrange, repeat import omegaconf + # import flashy -CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', - 'time_group_norm']) +CONV_NORMALIZATIONS = frozenset( + ["none", "weight_norm", "spectral_norm", "time_group_norm"] +) + def dict_from_config(cfg: omegaconf.DictConfig) -> dict: """Convenience function to map an omegaconf configuration to a dictionary. @@ -42,6 +45,7 @@ def dict_from_config(cfg: omegaconf.DictConfig) -> dict: assert isinstance(dct, dict) return dct + @dataclass class QuantizedResult: x: torch.Tensor @@ -50,9 +54,9 @@ class QuantizedResult: penalty: tp.Optional[torch.Tensor] = None metrics: dict = field(default_factory=dict) + class BaseQuantizer(nn.Module): - """Base class for quantizers. - """ + """Base class for quantizers.""" def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: """ @@ -85,17 +89,19 @@ def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" raise NotImplementedError() + class CompressionModel(ABC, nn.Module): """Base API for all compression model that aim at being used as audio tokenizers with a language model. """ @abstractmethod - def forward(self, x: torch.Tensor) -> QuantizedResult: - ... + def forward(self, x: torch.Tensor) -> QuantizedResult: ... @abstractmethod - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """See `EncodecModel.encode`.""" ... @@ -111,44 +117,39 @@ def decode_latent(self, codes: torch.Tensor): @property @abstractmethod - def channels(self) -> int: - ... + def channels(self) -> int: ... @property @abstractmethod - def frame_rate(self) -> float: - ... + def frame_rate(self) -> float: ... @property @abstractmethod - def sample_rate(self) -> int: - ... + def sample_rate(self) -> int: ... @property @abstractmethod - def cardinality(self) -> int: - ... + def cardinality(self) -> int: ... @property @abstractmethod - def num_codebooks(self) -> int: - ... + def num_codebooks(self) -> int: ... @property @abstractmethod - def total_codebooks(self) -> int: - ... + def total_codebooks(self) -> int: ... 
@abstractmethod def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer.""" ... -def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): + +def apply_parametrization_norm(module: nn.Module, norm: str = "none"): assert norm in CONV_NORMALIZATIONS - if norm == 'weight_norm': + if norm == "weight_norm": return weight_norm(module) - elif norm == 'spectral_norm': + elif norm == "spectral_norm": return spectral_norm(module) else: # We already check was in CONV_NORMALIZATION, so any other choice @@ -156,12 +157,14 @@ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): return module -def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +): """Return the proper normalization module. If causal is True, this will ensure the returned module is causal, or return an error if the normalization doesn't support causal evaluation. """ assert norm in CONV_NORMALIZATIONS - if norm == 'time_group_norm': + if norm == "time_group_norm": if causal: raise ValueError("GroupNorm doesn't support causal evaluation.") assert isinstance(module, nn.modules.conv._ConvNd) @@ -170,8 +173,9 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', return nn.Identity() -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 @@ -179,7 +183,9 @@ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, return ideal_length - length -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): """Pad for a convolution to make sure that the last window is full. Extra padding is added at the end. This is required to ensure that we can rebuild an output of the same length, as otherwise, even with padding, some time steps @@ -194,14 +200,19 @@ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total return F.pad(x, (0, extra_padding)) -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "constant", + value: float = 0.0, +): """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. 
""" length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': + if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: @@ -220,15 +231,22 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right - return x[..., padding_left: end] + return x[..., padding_left:end] class NormConv1d(nn.Module): """Wrapper around Conv1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) @@ -244,7 +262,14 @@ class NormConv2d(nn.Module): """Wrapper around Conv2d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) @@ -260,10 +285,19 @@ class NormConvTranspose1d(nn.Module): """Wrapper around ConvTranspose1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) self.norm_type = norm @@ -277,9 +311,18 @@ class NormConvTranspose2d(nn.Module): """Wrapper around ConvTranspose2d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) def forward(self, x): @@ -292,19 +335,40 @@ class StreamableConv1d(nn.Module): """Conv1d with some builtin handling of asymmetric or causal padding and normalization. 
""" - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, dilation: int = 1, - groups: int = 1, bias: bool = True, causal: bool = False, - norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, - pad_mode: str = 'reflect'): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): super().__init__() # warn user on unusual setup between dilation and stride if stride > 1 and dilation > 1: - warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" - f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") - self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, - dilation=dilation, groups=groups, bias=bias, causal=causal, - norm=norm, norm_kwargs=norm_kwargs) + warnings.warn( + "StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.pad_mode = pad_mode @@ -313,9 +377,13 @@ def forward(self, x): kernel_size = self.conv.conv.kernel_size[0] stride = self.conv.conv.stride[0] dilation = self.conv.conv.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + kernel_size = ( + kernel_size - 1 + ) * dilation + 1 # effective kernel size with dilations padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) if self.causal: # Left padding for causal x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) @@ -323,7 +391,9 @@ def forward(self, x): # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) return self.conv(x) @@ -331,18 +401,34 @@ class StreamableConvTranspose1d(nn.Module): """ConvTranspose1d with some builtin handling of asymmetric or causal padding and normalization. 
""" - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, causal: bool = False, - norm: str = 'none', trim_right_ratio: float = 1., - norm_kwargs: tp.Dict[str, tp.Any] = {}): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): super().__init__() - self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, - causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.trim_right_ratio = trim_right_ratio - assert self.causal or self.trim_right_ratio == 1., \ - "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" - assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 def forward(self, x): kernel_size = self.convtr.convtr.kernel_size[0] @@ -373,6 +459,7 @@ class StreamableLSTM(nn.Module): """LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout. """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): super().__init__() self.skip = skip @@ -404,12 +491,25 @@ class SEANetResnetBlock(nn.Module): true_skip (bool): Whether to use true skip connection or a simple (streamable) convolution as the skip connection. """ - def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], - activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, - pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): super().__init__() - assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" act = getattr(nn, activation) hidden = dim // compress block = [] @@ -418,17 +518,31 @@ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp. 
out_chs = dim if i == len(kernel_sizes) - 1 else hidden block += [ act(**activation_params), - StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, - norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.block = nn.Sequential(*block) self.shortcut: nn.Module if true_skip: self.shortcut = nn.Identity() else: - self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode) + self.shortcut = StreamableConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) def forward(self, x): return self.shortcut(x) + self.block(x) @@ -462,12 +576,29 @@ class SEANetEncoder(nn.Module): disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. For the encoder, it corresponds to the N first blocks. """ - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + ): super().__init__() self.channels = channels self.dimension = dimension @@ -478,36 +609,61 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = 1 model: tp.List[nn.Module] = [ - StreamableConv1d(channels, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + channels, + mult * n_filters, + kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] # Downsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm + block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - norm=block_norm, norm_params=norm_params, - activation=activation, activation_params=activation_params, - causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=block_norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] # Add downsampling layers model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, mult * n_filters * 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] mult *= 2 @@ -516,9 +672,17 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, dimension, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.model = nn.Sequential(*model) @@ -557,13 +721,32 @@ class SEANetDecoder(nn.Module): trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. If equal to 1.0, it means that all the trimming is done at the right. 
""" - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + trim_right_ratio: float = 1.0, + ): super().__init__() self.dimension = dimension self.channels = channels @@ -574,16 +757,28 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = int(2 ** len(self.ratios)) model: tp.List[nn.Module] = [ - StreamableConv1d(dimension, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] if lstm: @@ -591,40 +786,63 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, # Upsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm + block_norm = ( + "none" + if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) + else norm + ) # Add upsampling layers model += [ act(**activation_params), - StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, trim_right_ratio=trim_right_ratio), + StreamableConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), ] # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - activation=activation, activation_params=activation_params, - norm=block_norm, norm_params=norm_params, causal=causal, - pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=block_norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] mult //= 2 # Add final layers model += [ act(**activation_params), - StreamableConv1d(n_filters, channels, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + n_filters, + channels, + last_kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] # Add optional final activation to decoder (eg. 
tanh) if final_activation is not None: final_act = getattr(nn, final_activation) final_activation_params = final_activation_params or {} - model += [ - final_act(**final_activation_params) - ] + model += [final_act(**final_activation_params)] self.model = nn.Sequential(*model) def forward(self, z): @@ -675,10 +893,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): means = sample_vectors(samples, num_clusters) for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) @@ -700,7 +916,7 @@ def orthogonal_loss_fn(t): normed_codes = l2norm(t) identity = torch.eye(n, device=t.device) cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + return ((cosine_sim - identity) ** 2).sum() / (n**2) class EuclideanCodebook(nn.Module): @@ -719,6 +935,7 @@ class EuclideanCodebook(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ + def __init__( self, dim: int, @@ -731,7 +948,9 @@ def __init__( ): super().__init__() self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size @@ -862,6 +1081,7 @@ class VectorQuantization(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. 
""" + def __init__( self, dim: int, @@ -873,7 +1093,7 @@ def __init__( kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, channels_last: bool = False, - commitment_weight: float = 1., + commitment_weight: float = 1.0, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: tp.Optional[int] = None, @@ -882,8 +1102,12 @@ def __init__( _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) self.epsilon = epsilon self.commitment_weight = commitment_weight @@ -892,10 +1116,15 @@ def __init__( self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) self.codebook_size = codebook_size self.channels_last = channels_last @@ -956,8 +1185,13 @@ def forward(self, x): codebook = codebook[unique_code_ids] num_codes = codebook.shape[0] - if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] + if ( + exists(self.orthogonal_reg_max_codes) + and num_codes > self.orthogonal_reg_max_codes + ): + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] codebook = codebook[rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -974,17 +1208,20 @@ class ResidualVectorQuantization(nn.Module): Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__(self, *, num_quantizers, **kwargs): super().__init__() - codebook_size = kwargs.pop('codebook_size', None) + codebook_size = kwargs.pop("codebook_size", None) if codebook_size is None: raise ValueError("codebook_size must be provided in kwargs") if type(codebook_size) != list: codebook_size = [codebook_size] * num_quantizers self.layers = nn.ModuleList( - [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)] + [ + VectorQuantization(codebook_size=cur_codebook_size, **kwargs) + for _, cur_codebook_size in zip(range(num_quantizers), codebook_size) + ] ) - # self.layers = nn.ModuleList( # [VectorQuantization(**kwargs) for _ in range(num_quantizers)] @@ -1058,6 +1295,7 @@ class ResidualVectorQuantizer(BaseQuantizer): orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. for orthogonal regularization. 
""" + def __init__( self, dimension: int = 256, @@ -1096,7 +1334,7 @@ def __init__( orthogonal_reg_weight=self.orthogonal_reg_weight, orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, - channels_last=False + channels_last=False, ) def forward(self, x: torch.Tensor, frame_rate: int): @@ -1144,15 +1382,18 @@ def set_num_codebooks(self, n: int): assert n > 0 and n <= self.max_n_q self.n_q = n + class DummyQuantizer(BaseQuantizer): - """Fake quantizer that actually does not perform any quantization. - """ + """Fake quantizer that actually does not perform any quantization.""" + def __init__(self): super().__init__() def forward(self, x: torch.Tensor, frame_rate: int): q = x.unsqueeze(1) - return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) + return QuantizedResult( + x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x) + ) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode a given input tensor with the specified sample rate at the given bandwidth. @@ -1180,7 +1421,9 @@ def num_codebooks(self): def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" - raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") + raise AttributeError( + "Cannot override the number of codebooks for the dummy quantizer" + ) class EncodecModel(CompressionModel): @@ -1196,21 +1439,24 @@ class EncodecModel(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - decoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.decoder = decoder @@ -1223,7 +1469,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1244,7 +1490,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1256,9 +1504,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1285,7 +1533,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1323,6 +1573,7 @@ def decode_latent(self, codes: torch.Tensor): """Decode from the discrete codes to continuous latent space.""" return self.quantizer.decode(codes) + class EncodecModel_encode_only(CompressionModel): """Encodec model operating on the raw waveform. Encode only, so no decoder @@ -1335,20 +1586,23 @@ class EncodecModel_encode_only(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.quantizer = quantizer @@ -1360,7 +1614,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1381,7 +1635,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1393,9 +1649,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1422,7 +1678,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1462,21 +1720,22 @@ def decode_latent(self, codes: torch.Tensor): raise NotImplementedError("Decode is not supported for encode only model") return self.quantizer.decode(codes) -def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer: - klass = { - 'no_quant': DummyQuantizer, - 'rvq': ResidualVectorQuantizer - }[quantizer] + +def get_quantizer( + quantizer: str, cfg: omegaconf.DictConfig, dimension: int +) -> BaseQuantizer: + klass = {"no_quant": DummyQuantizer, "rvq": ResidualVectorQuantizer}[quantizer] kwargs = dict_from_config(getattr(cfg, quantizer)) - if quantizer != 'no_quant': - kwargs['dimension'] = dimension + if quantizer != "no_quant": + kwargs["dimension"] = dimension return klass(**kwargs) + def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): - if encoder_name == 'seanet': - kwargs = dict_from_config(getattr(cfg, 'seanet')) - encoder_override_kwargs = kwargs.pop('encoder') - decoder_override_kwargs = kwargs.pop('decoder') + if encoder_name == "seanet": + kwargs = dict_from_config(getattr(cfg, "seanet")) + encoder_override_kwargs = kwargs.pop("encoder") + decoder_override_kwargs = kwargs.pop("decoder") encoder_kwargs = {**kwargs, **encoder_override_kwargs} decoder_kwargs = {**kwargs, **decoder_override_kwargs} encoder = SEANetEncoder(**encoder_kwargs) @@ -1489,56 +1748,95 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel: """Instantiate a compression model.""" if device == None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - state = torch.load(ckpt_fn, map_location='cpu') - cfg = state['xp.cfg'] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + state = torch.load( + ckpt_fn, map_location="cpu", weights_only=False + ) # TODO: Convert to SafeTensors + cfg = state["xp.cfg"] cfg.device = str(device) - weights = state['best_state']['model'] - assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now." 
+ weights = state["best_state"]["model"] + assert ( + cfg.compression_model == "encodec" + ), "Only Encodec model is supported for now." if encode_only: all_keys = list(weights.keys()) for key in all_keys: - if key.startswith('decoder'): + if key.startswith("decoder"): del weights[key] - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, _ = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel_encode_only(encoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel_encode_only( + encoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model else: - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel(encoder, decoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel( + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + renormalize=renormalize, + **kwargs, + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model + if __name__ == "__main__": import torchaudio + ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th" - audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"] - audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", 
"/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + audio_in_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley.wav", + ] + audio_out_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav", + ] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) model = get_compression_model(ckpt_fn, device=device) for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns): @@ -1551,4 +1849,4 @@ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> Compressi audio_in = audio_in.to(torch.float32).to(device) codes = model.encode(audio_in)[0] audio_out = model.decode(codes)[0].cpu() - torchaudio.save(audio_out_fn, audio_out, model.sample_rate) \ No newline at end of file + torchaudio.save(audio_out_fn, audio_out, model.sample_rate) diff --git a/data/ll60k_preprocessing/encodec.py b/data/ll60k_preprocessing/encodec.py index be65978..5c7579d 100644 --- a/data/ll60k_preprocessing/encodec.py +++ b/data/ll60k_preprocessing/encodec.py @@ -25,10 +25,13 @@ import warnings from einops import rearrange, repeat import omegaconf + # import flashy -CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', - 'time_group_norm']) +CONV_NORMALIZATIONS = frozenset( + ["none", "weight_norm", "spectral_norm", "time_group_norm"] +) + def dict_from_config(cfg: omegaconf.DictConfig) -> dict: """Convenience function to map an omegaconf configuration to a dictionary. @@ -42,6 +45,7 @@ def dict_from_config(cfg: omegaconf.DictConfig) -> dict: assert isinstance(dct, dict) return dct + @dataclass class QuantizedResult: x: torch.Tensor @@ -50,9 +54,9 @@ class QuantizedResult: penalty: tp.Optional[torch.Tensor] = None metrics: dict = field(default_factory=dict) + class BaseQuantizer(nn.Module): - """Base class for quantizers. - """ + """Base class for quantizers.""" def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: """ @@ -85,17 +89,19 @@ def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" raise NotImplementedError() + class CompressionModel(ABC, nn.Module): """Base API for all compression model that aim at being used as audio tokenizers with a language model. """ @abstractmethod - def forward(self, x: torch.Tensor) -> QuantizedResult: - ... + def forward(self, x: torch.Tensor) -> QuantizedResult: ... @abstractmethod - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """See `EncodecModel.encode`.""" ... @@ -111,44 +117,39 @@ def decode_latent(self, codes: torch.Tensor): @property @abstractmethod - def channels(self) -> int: - ... 
+ def channels(self) -> int: ... @property @abstractmethod - def frame_rate(self) -> float: - ... + def frame_rate(self) -> float: ... @property @abstractmethod - def sample_rate(self) -> int: - ... + def sample_rate(self) -> int: ... @property @abstractmethod - def cardinality(self) -> int: - ... + def cardinality(self) -> int: ... @property @abstractmethod - def num_codebooks(self) -> int: - ... + def num_codebooks(self) -> int: ... @property @abstractmethod - def total_codebooks(self) -> int: - ... + def total_codebooks(self) -> int: ... @abstractmethod def set_num_codebooks(self, n: int): """Set the active number of codebooks used by the quantizer.""" ... -def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): + +def apply_parametrization_norm(module: nn.Module, norm: str = "none"): assert norm in CONV_NORMALIZATIONS - if norm == 'weight_norm': + if norm == "weight_norm": return weight_norm(module) - elif norm == 'spectral_norm': + elif norm == "spectral_norm": return spectral_norm(module) else: # We already check was in CONV_NORMALIZATION, so any other choice @@ -156,12 +157,14 @@ def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): return module -def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +): """Return the proper normalization module. If causal is True, this will ensure the returned module is causal, or return an error if the normalization doesn't support causal evaluation. """ assert norm in CONV_NORMALIZATIONS - if norm == 'time_group_norm': + if norm == "time_group_norm": if causal: raise ValueError("GroupNorm doesn't support causal evaluation.") assert isinstance(module, nn.modules.conv._ConvNd) @@ -170,8 +173,9 @@ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', return nn.Identity() -def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, - padding_total: int = 0) -> int: +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: """See `pad_for_conv1d`.""" length = x.shape[-1] n_frames = (length - kernel_size + padding_total) / stride + 1 @@ -179,7 +183,9 @@ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, return ideal_length - length -def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): """Pad for a convolution to make sure that the last window is full. Extra padding is added at the end. This is required to ensure that we can rebuild an output of the same length, as otherwise, even with padding, some time steps @@ -194,14 +200,19 @@ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total return F.pad(x, (0, extra_padding)) -def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "constant", + value: float = 0.0, +): """Tiny wrapper around F.pad, just to allow for reflect padding on small input. If this is the case, we insert extra 0 padding to the right before the reflection happen. 
""" length = x.shape[-1] padding_left, padding_right = paddings assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == 'reflect': + if mode == "reflect": max_pad = max(padding_left, padding_right) extra_pad = 0 if length <= max_pad: @@ -220,15 +231,22 @@ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) assert (padding_left + padding_right) <= x.shape[-1] end = x.shape[-1] - padding_right - return x[..., padding_left: end] + return x[..., padding_left:end] class NormConv1d(nn.Module): """Wrapper around Conv1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) @@ -244,7 +262,14 @@ class NormConv2d(nn.Module): """Wrapper around Conv2d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) @@ -260,10 +285,19 @@ class NormConvTranspose1d(nn.Module): """Wrapper around ConvTranspose1d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, causal: bool = False, norm: str = 'none', - norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) self.norm_type = norm @@ -277,9 +311,18 @@ class NormConvTranspose2d(nn.Module): """Wrapper around ConvTranspose2d and normalization applied to this conv to provide a uniform interface across normalization approaches. """ - def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): super().__init__() - self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) def forward(self, x): @@ -292,19 +335,40 @@ class StreamableConv1d(nn.Module): """Conv1d with some builtin handling of asymmetric or causal padding and normalization. 
""" - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, dilation: int = 1, - groups: int = 1, bias: bool = True, causal: bool = False, - norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, - pad_mode: str = 'reflect'): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): super().__init__() # warn user on unusual setup between dilation and stride if stride > 1 and dilation > 1: - warnings.warn("StreamableConv1d has been initialized with stride > 1 and dilation > 1" - f" (kernel_size={kernel_size} stride={stride}, dilation={dilation}).") - self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, - dilation=dilation, groups=groups, bias=bias, causal=causal, - norm=norm, norm_kwargs=norm_kwargs) + warnings.warn( + "StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.pad_mode = pad_mode @@ -313,9 +377,13 @@ def forward(self, x): kernel_size = self.conv.conv.kernel_size[0] stride = self.conv.conv.stride[0] dilation = self.conv.conv.dilation[0] - kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + kernel_size = ( + kernel_size - 1 + ) * dilation + 1 # effective kernel size with dilations padding_total = kernel_size - stride - extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) if self.causal: # Left padding for causal x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) @@ -323,7 +391,9 @@ def forward(self, x): # Asymmetric padding required for odd strides padding_right = padding_total // 2 padding_left = padding_total - padding_right - x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) return self.conv(x) @@ -331,18 +401,34 @@ class StreamableConvTranspose1d(nn.Module): """ConvTranspose1d with some builtin handling of asymmetric or causal padding and normalization. 
""" - def __init__(self, in_channels: int, out_channels: int, - kernel_size: int, stride: int = 1, causal: bool = False, - norm: str = 'none', trim_right_ratio: float = 1., - norm_kwargs: tp.Dict[str, tp.Any] = {}): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): super().__init__() - self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, - causal=causal, norm=norm, norm_kwargs=norm_kwargs) + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) self.causal = causal self.trim_right_ratio = trim_right_ratio - assert self.causal or self.trim_right_ratio == 1., \ - "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" - assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 def forward(self, x): kernel_size = self.convtr.convtr.kernel_size[0] @@ -373,6 +459,7 @@ class StreamableLSTM(nn.Module): """LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout. """ + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): super().__init__() self.skip = skip @@ -404,12 +491,25 @@ class SEANetResnetBlock(nn.Module): true_skip (bool): Whether to use true skip connection or a simple (streamable) convolution as the skip connection. """ - def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1], - activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False, - pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True): + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): super().__init__() - assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations' + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" act = getattr(nn, activation) hidden = dim // compress block = [] @@ -418,17 +518,31 @@ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp. 
out_chs = dim if i == len(kernel_sizes) - 1 else hidden block += [ act(**activation_params), - StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation, - norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.block = nn.Sequential(*block) self.shortcut: nn.Module if true_skip: self.shortcut = nn.Identity() else: - self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode) + self.shortcut = StreamableConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) def forward(self, x): return self.shortcut(x) + self.block(x) @@ -462,12 +576,29 @@ class SEANetEncoder(nn.Module): disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. For the encoder, it corresponds to the N first blocks. """ - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + ): super().__init__() self.channels = channels self.dimension = dimension @@ -478,36 +609,61 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = 1 model: tp.List[nn.Module] = [ - StreamableConv1d(channels, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + channels, + mult * n_filters, + kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] # Downsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm + block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - norm=block_norm, norm_params=norm_params, - activation=activation, activation_params=activation_params, - causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=block_norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] # Add downsampling layers model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, mult * n_filters * 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, pad_mode=pad_mode), + StreamableConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] mult *= 2 @@ -516,9 +672,17 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, model += [ act(**activation_params), - StreamableConv1d(mult * n_filters, dimension, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] self.model = nn.Sequential(*model) @@ -557,13 +721,32 @@ class SEANetDecoder(nn.Module): trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. If equal to 1.0, it means that all the trimming is done at the right. 
""" - def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3, - ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0}, - final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None, - norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7, - last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False, - pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0, - disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0): + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + trim_right_ratio: float = 1.0, + ): super().__init__() self.dimension = dimension self.channels = channels @@ -574,16 +757,28 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, self.hop_length = np.prod(self.ratios) self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks - assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \ - "Number of blocks for which to disable norm is invalid." \ + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) act = getattr(nn, activation) mult = int(2 ** len(self.ratios)) model: tp.List[nn.Module] = [ - StreamableConv1d(dimension, mult * n_filters, kernel_size, - norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) ] if lstm: @@ -591,40 +786,63 @@ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, # Upsample to raw audio scale for i, ratio in enumerate(self.ratios): - block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm + block_norm = ( + "none" + if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) + else norm + ) # Add upsampling layers model += [ act(**activation_params), - StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2, - kernel_size=ratio * 2, stride=ratio, - norm=block_norm, norm_kwargs=norm_params, - causal=causal, trim_right_ratio=trim_right_ratio), + StreamableConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), ] # Add residual layers for j in range(n_residual_layers): model += [ - SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1], - dilations=[dilation_base ** j, 1], - activation=activation, activation_params=activation_params, - norm=block_norm, norm_params=norm_params, causal=causal, - pad_mode=pad_mode, compress=compress, true_skip=true_skip)] + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=block_norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] mult //= 2 # Add final layers model += [ act(**activation_params), - StreamableConv1d(n_filters, channels, last_kernel_size, - norm='none' if self.disable_norm_outer_blocks >= 1 else norm, - norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode) + StreamableConv1d( + n_filters, + channels, + last_kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), ] # Add optional final activation to decoder (eg. 
tanh) if final_activation is not None: final_act = getattr(nn, final_activation) final_activation_params = final_activation_params or {} - model += [ - final_act(**final_activation_params) - ] + model += [final_act(**final_activation_params)] self.model = nn.Sequential(*model) def forward(self, z): @@ -675,10 +893,8 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10): means = sample_vectors(samples, num_clusters) for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) @@ -700,7 +916,7 @@ def orthogonal_loss_fn(t): normed_codes = l2norm(t) identity = torch.eye(n, device=t.device) cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) - return ((cosine_sim - identity) ** 2).sum() / (n ** 2) + return ((cosine_sim - identity) ** 2).sum() / (n**2) class EuclideanCodebook(nn.Module): @@ -719,6 +935,7 @@ class EuclideanCodebook(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ + def __init__( self, dim: int, @@ -731,7 +948,9 @@ def __init__( ): super().__init__() self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size @@ -862,6 +1081,7 @@ class VectorQuantization(nn.Module): that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. 
""" + def __init__( self, dim: int, @@ -873,7 +1093,7 @@ def __init__( kmeans_iters: int = 10, threshold_ema_dead_code: int = 2, channels_last: bool = False, - commitment_weight: float = 1., + commitment_weight: float = 1.0, orthogonal_reg_weight: float = 0.0, orthogonal_reg_active_codes_only: bool = False, orthogonal_reg_max_codes: tp.Optional[int] = None, @@ -882,8 +1102,12 @@ def __init__( _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) self.epsilon = epsilon self.commitment_weight = commitment_weight @@ -892,10 +1116,15 @@ def __init__( self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) self.codebook_size = codebook_size self.channels_last = channels_last @@ -956,8 +1185,13 @@ def forward(self, x): codebook = codebook[unique_code_ids] num_codes = codebook.shape[0] - if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes] + if ( + exists(self.orthogonal_reg_max_codes) + and num_codes > self.orthogonal_reg_max_codes + ): + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] codebook = codebook[rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -974,17 +1208,20 @@ class ResidualVectorQuantization(nn.Module): Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ + def __init__(self, *, num_quantizers, **kwargs): super().__init__() - codebook_size = kwargs.pop('codebook_size', None) + codebook_size = kwargs.pop("codebook_size", None) if codebook_size is None: raise ValueError("codebook_size must be provided in kwargs") if type(codebook_size) != list: codebook_size = [codebook_size] * num_quantizers self.layers = nn.ModuleList( - [VectorQuantization(codebook_size=cur_codebook_size, **kwargs) for _,cur_codebook_size in zip(range(num_quantizers), codebook_size)] + [ + VectorQuantization(codebook_size=cur_codebook_size, **kwargs) + for _, cur_codebook_size in zip(range(num_quantizers), codebook_size) + ] ) - # self.layers = nn.ModuleList( # [VectorQuantization(**kwargs) for _ in range(num_quantizers)] @@ -1058,6 +1295,7 @@ class ResidualVectorQuantizer(BaseQuantizer): orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. for orthogonal regularization. 
""" + def __init__( self, dimension: int = 256, @@ -1096,7 +1334,7 @@ def __init__( orthogonal_reg_weight=self.orthogonal_reg_weight, orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, - channels_last=False + channels_last=False, ) def forward(self, x: torch.Tensor, frame_rate: int): @@ -1144,15 +1382,18 @@ def set_num_codebooks(self, n: int): assert n > 0 and n <= self.max_n_q self.n_q = n + class DummyQuantizer(BaseQuantizer): - """Fake quantizer that actually does not perform any quantization. - """ + """Fake quantizer that actually does not perform any quantization.""" + def __init__(self): super().__init__() def forward(self, x: torch.Tensor, frame_rate: int): q = x.unsqueeze(1) - return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) + return QuantizedResult( + x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x) + ) def encode(self, x: torch.Tensor) -> torch.Tensor: """Encode a given input tensor with the specified sample rate at the given bandwidth. @@ -1180,7 +1421,9 @@ def num_codebooks(self): def set_num_codebooks(self, n: int): """Set the number of active codebooks.""" - raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") + raise AttributeError( + "Cannot override the number of codebooks for the dummy quantizer" + ) class EncodecModel(CompressionModel): @@ -1196,21 +1439,24 @@ class EncodecModel(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - decoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.decoder = decoder @@ -1223,7 +1469,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1244,7 +1490,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1256,9 +1504,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1285,7 +1533,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1323,6 +1573,7 @@ def decode_latent(self, codes: torch.Tensor): """Decode from the discrete codes to continuous latent space.""" return self.quantizer.decode(codes) + class EncodecModel_encode_only(CompressionModel): """Encodec model operating on the raw waveform. Encode only, so no decoder @@ -1335,20 +1586,23 @@ class EncodecModel_encode_only(CompressionModel): causal (bool): Whether to use a causal version of the model. renormalize (bool): Whether to renormalize the audio before running the model. """ + # we need assignment to override the property in the abstract class, # I couldn't find a better way... frame_rate: float = 0 sample_rate: int = 0 channels: int = 0 - def __init__(self, - encoder: nn.Module, - quantizer: BaseQuantizer, - frame_rate: int, - sample_rate: int, - channels: int, - causal: bool = False, - renormalize: bool = False): + def __init__( + self, + encoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): super().__init__() self.encoder = encoder self.quantizer = quantizer @@ -1360,7 +1614,7 @@ def __init__(self, if self.causal: # we force disabling here to avoid handling linear overlap of segments # as supported in original EnCodec codebase. 
- assert not self.renormalize, 'Causal model does not support renormalize' + assert not self.renormalize, "Causal model does not support renormalize" @property def total_codebooks(self): @@ -1381,7 +1635,9 @@ def cardinality(self): """Cardinality of each codebook.""" return self.quantizer.bins - def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: scale: tp.Optional[torch.Tensor] if self.renormalize: mono = x.mean(dim=1, keepdim=True) @@ -1393,9 +1649,9 @@ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torc scale = None return x, scale - def postprocess(self, - x: torch.Tensor, - scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: if scale is not None: assert self.renormalize x = x * scale.view(-1, 1, 1) @@ -1422,7 +1678,9 @@ def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: return q_res - def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: """Encode the given input tensor to quantized representation along with scale parameter. Args: @@ -1462,21 +1720,22 @@ def decode_latent(self, codes: torch.Tensor): raise NotImplementedError("Decode is not supported for encode only model") return self.quantizer.decode(codes) -def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> BaseQuantizer: - klass = { - 'no_quant': DummyQuantizer, - 'rvq': ResidualVectorQuantizer - }[quantizer] + +def get_quantizer( + quantizer: str, cfg: omegaconf.DictConfig, dimension: int +) -> BaseQuantizer: + klass = {"no_quant": DummyQuantizer, "rvq": ResidualVectorQuantizer}[quantizer] kwargs = dict_from_config(getattr(cfg, quantizer)) - if quantizer != 'no_quant': - kwargs['dimension'] = dimension + if quantizer != "no_quant": + kwargs["dimension"] = dimension return klass(**kwargs) + def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): - if encoder_name == 'seanet': - kwargs = dict_from_config(getattr(cfg, 'seanet')) - encoder_override_kwargs = kwargs.pop('encoder') - decoder_override_kwargs = kwargs.pop('decoder') + if encoder_name == "seanet": + kwargs = dict_from_config(getattr(cfg, "seanet")) + encoder_override_kwargs = kwargs.pop("encoder") + decoder_override_kwargs = kwargs.pop("decoder") encoder_kwargs = {**kwargs, **encoder_override_kwargs} decoder_kwargs = {**kwargs, **decoder_override_kwargs} encoder = SEANetEncoder(**encoder_kwargs) @@ -1489,56 +1748,93 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel: """Instantiate a compression model.""" if device == None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - state = torch.load(ckpt_fn, map_location='cpu') - cfg = state['xp.cfg'] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + state = torch.load(ckpt_fn, map_location="cpu") + cfg = state["xp.cfg"] cfg.device = str(device) - weights = state['best_state']['model'] - assert cfg.compression_model == 'encodec', "Only Encodec model is supported for now." 
+ weights = state["best_state"]["model"] + assert ( + cfg.compression_model == "encodec" + ), "Only Encodec model is supported for now." if encode_only: all_keys = list(weights.keys()) for key in all_keys: - if key.startswith('decoder'): + if key.startswith("decoder"): del weights[key] - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, _ = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel_encode_only(encoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel_encode_only( + encoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model else: - kwargs = dict_from_config(getattr(cfg, 'encodec')) - encoder_name = kwargs.pop('autoencoder') - quantizer_name = kwargs.pop('quantizer') + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) - frame_rate = kwargs['sample_rate'] // encoder.hop_length - renormalize = kwargs.pop('renormalize', False) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) # deprecated params - kwargs.pop('renorm', None) - compression_model = EncodecModel(encoder, decoder, quantizer, - frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) - assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match" + kwargs.pop("renorm", None) + compression_model = EncodecModel( + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + renormalize=renormalize, + **kwargs, + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" compression_model.load_state_dict(weights) compression_model.eval() return compression_model + if __name__ == "__main__": import torchaudio + ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th" - audio_in_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam.wav", "/home/pyp/BoostedVoiceEditor/demo/ray.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", "/home/pyp/BoostedVoiceEditor/demo/bible.wav", "/home/pyp/BoostedVoiceEditor/demo/miley.wav"] - audio_out_fns = ["/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", 
"/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav"] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + audio_in_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley.wav", + ] + audio_out_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav", + ] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) model = get_compression_model(ckpt_fn, device=device) for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns): @@ -1551,4 +1847,4 @@ def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> Compressi audio_in = audio_in.to(torch.float32).to(device) codes = model.encode(audio_in)[0] audio_out = model.decode(codes)[0].cpu() - torchaudio.save(audio_out_fn, audio_out, model.sample_rate) \ No newline at end of file + torchaudio.save(audio_out_fn, audio_out, model.sample_rate) diff --git a/data/ll60k_preprocessing/step2_resplit_long.py b/data/ll60k_preprocessing/step2_resplit_long.py index a2a53d3..9f1150b 100644 --- a/data/ll60k_preprocessing/step2_resplit_long.py +++ b/data/ll60k_preprocessing/step2_resplit_long.py @@ -5,29 +5,44 @@ import os, random, numpy as np, socket import json import tqdm + + def write_jsonl(data, fn): with open(fn, "w") as file: for entry in data: file.write(json.dumps(entry, ensure_ascii=False) + "\n") + + def read_jsonl(file_path): cur_data = [] - with open(file_path, 'r', encoding='utf-8-sig') as file: + with open(file_path, "r", encoding="utf-8-sig") as file: for line in file: cur_data.append(json.loads(line.strip())) return cur_data + + import os -dataroot=os.environ["DATAROOT"] -manifestroot=os.path.join(dataroot, "libriheavy") -tgt_names = ['libriheavy_cuts_dev.jsonl', 'libriheavy_cuts_test_clean.jsonl', 'libriheavy_cuts_test_other.jsonl'] -orig_names = ['libriheavy_long_original_cuts_small.jsonl', 'libriheavy_long_original_cuts_medium.jsonl', 'libriheavy_long_original_cuts_large.jsonl'] + +dataroot = os.environ["DATAROOT"] +manifestroot = os.path.join(dataroot, "libriheavy") +tgt_names = [ + "libriheavy_cuts_dev.jsonl", + "libriheavy_cuts_test_clean.jsonl", + "libriheavy_cuts_test_other.jsonl", +] +orig_names = [ + "libriheavy_long_original_cuts_small.jsonl", + "libriheavy_long_original_cuts_medium.jsonl", + "libriheavy_long_original_cuts_large.jsonl", +] id2split = {} data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_dev.jsonl")) -dev_ids = set(["/".join(item['id'].split("/")[:3]) for item in data]) +dev_ids = set(["/".join(item["id"].split("/")[:3]) for item in data]) data = read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_clean.jsonl")) -test_clean_ids = set(["/".join(item['id'].split("/")[:3]) for item in data]) +test_clean_ids = set(["/".join(item["id"].split("/")[:3]) for item in data]) data = 
read_jsonl(os.path.join(manifestroot, "libriheavy_cuts_test_other.jsonl")) -test_other_ids = set(["/".join(item['id'].split("/")[:3]) for item in data]) +test_other_ids = set(["/".join(item["id"].split("/")[:3]) for item in data]) long_dev = [] long_test_clean = [] @@ -36,16 +51,20 @@ def read_jsonl(file_path): keep = [] data = read_jsonl(os.path.join(manifestroot, orig_name)) for item in tqdm.tqdm(data): - if "/".join(item['id'].split("/")[:3]) in dev_ids: + if "/".join(item["id"].split("/")[:3]) in dev_ids: long_dev.append(item) - elif "/".join(item['id'].split("/")[:3]) in test_clean_ids: + elif "/".join(item["id"].split("/")[:3]) in test_clean_ids: long_test_clean.append(item) - elif "/".join(item['id'].split("/")[:3]) in test_other_ids: + elif "/".join(item["id"].split("/")[:3]) in test_other_ids: long_test_other.append(item) else: keep.append(item) write_jsonl(keep, os.path.join(manifestroot, orig_name.replace("_original", ""))) write_jsonl(long_dev, os.path.join(manifestroot, "libriheavy_long_cuts_dev.jsonl")) -write_jsonl(long_test_clean, os.path.join(manifestroot, "libriheavy_long_cuts_test_clean.jsonl")) -write_jsonl(long_test_other, os.path.join(manifestroot, "libriheavy_long_cuts_test_other.jsonl")) \ No newline at end of file +write_jsonl( + long_test_clean, os.path.join(manifestroot, "libriheavy_long_cuts_test_clean.jsonl") +) +write_jsonl( + long_test_other, os.path.join(manifestroot, "libriheavy_long_cuts_test_other.jsonl") +) diff --git a/data/ll60k_preprocessing/step3_seg_phn_manifest.py b/data/ll60k_preprocessing/step3_seg_phn_manifest.py index 6743eef..e64fabf 100644 --- a/data/ll60k_preprocessing/step3_seg_phn_manifest.py +++ b/data/ll60k_preprocessing/step3_seg_phn_manifest.py @@ -15,33 +15,42 @@ import os, random, numpy as np, socket import json import tqdm + + def write_jsonl(data, fn): with open(fn, "w") as file: for entry in data: file.write(json.dumps(entry, ensure_ascii=False) + "\n") + + def read_jsonl(file_path): cur_data = [] - with open(file_path, 'r', encoding='utf-8-sig') as file: + with open(file_path, "r", encoding="utf-8-sig") as file: for line in file: cur_data.append(json.loads(line.strip())) return cur_data + + def save_audio(seq, fn): output = seq os.makedirs(os.path.dirname(fn), exist_ok=True) sf.write(fn, output, samplerate=16000) + def save_text(text, fn): os.makedirs(os.path.dirname(fn), exist_ok=True) with open(fn, "w") as wwf: wwf.writelines(text) + def phonemize_and_save(text, fn): phn = tokenize_text(text_tokenizer, text) os.makedirs(os.path.dirname(fn), exist_ok=True) with open(fn, "w") as f: - f.write(' '.join(phn)) + f.write(" ".join(phn)) return set(phn) + def cut_sequence(task): in_audio_fn, output_dir, metadata = task if not os.path.isfile(in_audio_fn): @@ -52,85 +61,103 @@ def cut_sequence(task): assert samplerate == 16000 all_phns = set() for item in metadata: - out_fn = item['file_id'] + out_fn = item["file_id"] out_audio_fn = os.path.join(output_dir, "audio", out_fn) out_text_fn = os.path.join(output_dir, "audio", out_fn.replace(".flac", ".txt")) - out_phn_fn = os.path.join(output_dir, "phoneme", out_fn.replace(".flac", ".txt")) - save_audio(data[int(item['vad'][0]*samplerate):int(item['vad'][1]*samplerate)], out_audio_fn) - save_text(item['text'], out_text_fn) - phns = phonemize_and_save(item['text'], out_phn_fn) + out_phn_fn = os.path.join( + output_dir, "phoneme", out_fn.replace(".flac", ".txt") + ) + save_audio( + data[int(item["vad"][0] * samplerate) : int(item["vad"][1] * samplerate)], + out_audio_fn, + ) + 
save_text(item["text"], out_text_fn) + phns = phonemize_and_save(item["text"], out_phn_fn) all_phns.update(phns) - + return all_phns from collections import defaultdict + + # Function to create a defaultdict recursively def nested_defaultdict(levels, inner_type): if levels <= 1: return defaultdict(inner_type) - return defaultdict(lambda: nested_defaultdict(levels-1, inner_type)) + return defaultdict(lambda: nested_defaultdict(levels - 1, inner_type)) def open_mani(fn): print("load segmentation and transcription metadata...") stime = time.time() data = [] - with gzip.open(fn, 'rt', encoding='utf-8') as f: + with gzip.open(fn, "rt", encoding="utf-8") as f: for line in f: data.append(json.loads(line)) print(f"loading done, took {time.time() - stime:.4f} seconds") return data -def cut(split, - audio_dir, - mani_dir, - output_dir, - n_process=32, - percent=0.5): + +def cut(split, audio_dir, mani_dir, output_dir, n_process=32, percent=0.5): split2manifest = { - "train": [ - "libriheavy_long_cuts_small.jsonl", - "libriheavy_long_cuts_medium.jsonl", - "libriheavy_long_cuts_large.jsonl", - "libriheavy_cuts_small.jsonl", - "libriheavy_cuts_medium.jsonl", - "libriheavy_cuts_large.jsonl", - ], - "valid": [ - "libriheavy_cuts_dev.jsonl", - "libriheavy_long_cuts_dev.jsonl" - ], - "test": [ - "libriheavy_cuts_test_clean.jsonl", - "libriheavy_cuts_test_other.jsonl", - "libriheavy_long_cuts_test_clean.jsonl", - "libriheavy_long_cuts_test_other.jsonl" - ] - } + "train": [ + "libriheavy_long_cuts_small.jsonl", + "libriheavy_long_cuts_medium.jsonl", + "libriheavy_long_cuts_large.jsonl", + "libriheavy_cuts_small.jsonl", + "libriheavy_cuts_medium.jsonl", + "libriheavy_cuts_large.jsonl", + ], + "valid": ["libriheavy_cuts_dev.jsonl", "libriheavy_long_cuts_dev.jsonl"], + "test": [ + "libriheavy_cuts_test_clean.jsonl", + "libriheavy_cuts_test_other.jsonl", + "libriheavy_long_cuts_test_clean.jsonl", + "libriheavy_long_cuts_test_other.jsonl", + ], + } print("organize data by recording_id (i.e. the original big .flac file name)...") stime = time.time() organized_data = nested_defaultdict(4, list) - manifest_fn = os.path.join(output_dir, "manifest_mimi", split+".txt") + manifest_fn = os.path.join(output_dir, "manifest_mimi", split + ".txt") os.makedirs(os.path.join(output_dir, "manifest_mimi"), exist_ok=True) with open(manifest_fn, "w") as wf: for mani_fn in split2manifest[split]: # data = open_mani(os.path.join(mani_dir, mani_fn)) data = read_jsonl(os.path.join(mani_dir, mani_fn)) for item in data: - file_id = item['supervisions'][0]['id'] + '.flac' - recording_id = item['recording']['id'] + '.flac' - sizeSplit, spk, book, flac = recording_id.split("/") # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb' + file_id = item["supervisions"][0]["id"] + ".flac" + recording_id = item["recording"]["id"] + ".flac" + sizeSplit, spk, book, flac = recording_id.split( + "/" + ) # e.g. 
'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb' if os.path.isfile(os.path.join(audio_dir, recording_id)): - vad = (item['start'], item['start']+item['duration']) - text = item['supervisions'][0]['custom']['texts'][0] - file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac" - organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad":vad, "text": text}) + vad = (item["start"], item["start"] + item["duration"]) + text = item["supervisions"][0]["custom"]["texts"][0] + file_id = ( + file_id.replace(".flac", "") + + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac" + ) + organized_data[sizeSplit][spk][book][recording_id].append( + {"file_id": file_id, "vad": vad, "text": text} + ) wf.writelines(f"{file_id}\t{item['duration']}\n") - + # #### take only a subet of tasks - tasks = [(os.path.join(audio_dir, recording_id), output_dir, organized_data[sizeSplit][spk][book][recording_id], spk) for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]] + tasks = [ + ( + os.path.join(audio_dir, recording_id), + output_dir, + organized_data[sizeSplit][spk][book][recording_id], + spk, + ) + for sizeSplit in organized_data + for spk in organized_data[sizeSplit] + for book in organized_data[sizeSplit][spk] + for recording_id in organized_data[sizeSplit][spk][book] + ] ntasks = len(tasks) spk2tasks = defaultdict(list) for task in tasks: @@ -146,7 +173,9 @@ def cut(split, if len(spk2tasks[spk]) == 0: continue tasks.append(spk2tasks[spk].pop()[:-1]) - print(f"take only {percent*100:.2f}% of the tasks, {len(tasks)} out of {ntasks} tasks") + print( + f"take only {percent*100:.2f}% of the tasks, {len(tasks)} out of {ntasks} tasks" + ) #### take only a subet of tasks print(f"organizing done, took {time.time() - stime:.4f} seconds") @@ -154,7 +183,9 @@ def cut(split, phn_vocab = set() cnt = 0 with multiprocessing.Pool(processes=n_process) as pool: - for phns in tqdm.tqdm(pool.imap_unordered(cut_sequence, tasks), total=len(tasks)): + for phns in tqdm.tqdm( + pool.imap_unordered(cut_sequence, tasks), total=len(tasks) + ): cnt += 1 if phns != None: phn_vocab.update(phns) @@ -168,21 +199,46 @@ def cut(split, f.write(f"{str(i)}\t{phn}\n") else: f.write(f"{str(i)}\t{phn}") - -def parse_args(): - parser = argparse.ArgumentParser(description="Cut a dataset in small " - "sequences using VAD files") - parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz") - parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example", - help="Path to the audio directory") - parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1") - parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", - help="Path to the output directory") - parser.add_argument('--n_workers', type=int, default=16, - help="Number of parallel worker processes") - parser.add_argument('--percent', type=float, default=0.5, help="take only this percent of the tasks, randomly sampled from each speaker") +def parse_args(): + parser = 
argparse.ArgumentParser( + description="Cut a dataset in small " "sequences using VAD files" + ) + parser.add_argument( + "--split", + type=str, + default="train", + choices=["train", "valid", "test"], + help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz", + ) + parser.add_argument( + "--audio_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight_example", + help="Path to the audio directory", + ) + parser.add_argument( + "--manifest_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight/libriheavy", + help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", + help="Path to the output directory", + ) + parser.add_argument( + "--n_workers", type=int, default=16, help="Number of parallel worker processes" + ) + parser.add_argument( + "--percent", + type=float, + default=0.5, + help="take only this percent of the tasks, randomly sampled from each speaker", + ) return parser.parse_args() @@ -191,4 +247,11 @@ def parse_args(): args = parse_args() pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True) text_tokenizer = TextTokenizer() - cut(args.split, args.audio_dir, args.manifest_dir, args.output_dir, args.n_workers, args.percent) \ No newline at end of file + cut( + args.split, + args.audio_dir, + args.manifest_dir, + args.output_dir, + args.n_workers, + args.percent, + ) diff --git a/data/ll60k_preprocessing/step4_encodec_encode.py b/data/ll60k_preprocessing/step4_encodec_encode.py index 7aaf62a..69d265a 100644 --- a/data/ll60k_preprocessing/step4_encodec_encode.py +++ b/data/ll60k_preprocessing/step4_encodec_encode.py @@ -1,27 +1,72 @@ import argparse from email.policy import default + + def parse_args(): - parser = argparse.ArgumentParser(description="encode the librilight dataset using codec model") - parser.add_argument('--dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", help="Path to the directory") - parser.add_argument('--sub_root', type=str, default="preprocessed", help="sub directory") - parser.add_argument('--encodec_name', type=str, default="encodec_6f79c6a8.th", help="name of the codec model") - parser.add_argument('--n_workers', type=int, default=16, help="Number of parallel worker processes") - parser.add_argument('--batch_size', type=int, default=16, help="batch size for codec encoding, decrease it if OOM. 
This is the sum of batch size *over each gpu*, so increase it if you are using more gpus") - parser.add_argument('--audio_sr', type=int, default=16000, help='input audio sample rate') - parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate') - parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate') - parser.add_argument('--model_code_sr', type=float, default=50, help='codec model code sample rate') - parser.add_argument('--len_cap', type=float, default=1000, help='will drop audios that are longer than this number') - parser.add_argument('--min_len', type=float, default=0.5, help='will drop audios that are shorter than this number') - parser.add_argument('--partition', type=str, default="1/1", help='split for parallel processing') - parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test']) + parser = argparse.ArgumentParser( + description="encode the librilight dataset using codec model" + ) + parser.add_argument( + "--dir", + type=str, + default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", + help="Path to the directory", + ) + parser.add_argument( + "--sub_root", type=str, default="preprocessed", help="sub directory" + ) + parser.add_argument( + "--encodec_name", + type=str, + default="encodec_6f79c6a8.th", + help="name of the codec model", + ) + parser.add_argument( + "--n_workers", type=int, default=16, help="Number of parallel worker processes" + ) + parser.add_argument( + "--batch_size", + type=int, + default=16, + help="batch size for codec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus", + ) + parser.add_argument( + "--audio_sr", type=int, default=16000, help="input audio sample rate" + ) + parser.add_argument( + "--model_sr", type=int, default=16000, help="encodec input audio sample rate" + ) + parser.add_argument( + "--downsample_rate", type=int, default=320, help="encodec downsample rate" + ) + parser.add_argument( + "--model_code_sr", type=float, default=50, help="codec model code sample rate" + ) + parser.add_argument( + "--len_cap", + type=float, + default=1000, + help="will drop audios that are longer than this number", + ) + parser.add_argument( + "--min_len", + type=float, + default=0.5, + help="will drop audios that are shorter than this number", + ) + parser.add_argument( + "--partition", type=str, default="1/1", help="split for parallel processing" + ) + parser.add_argument( + "--split", type=str, default="train", choices=["train", "valid", "test"] + ) return parser.parse_args() + if __name__ == "__main__": import logging - formatter = ( - "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" - ) + + formatter = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) import os, sys @@ -37,31 +82,47 @@ def sort_by_audio_len(lens): inds = np.argsort(lens).tolist() if len(inds) < 10: return inds[::-1] - logging.info(f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec.") - logging.info(f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec.") - logging.info(f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec.") - logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec 
codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec.") + logging.info( + f"longest: {lens[inds[-1]]/args.downsample_rate} encodec codes, {lens[inds[-1]]/args.model_sr:.2f} sec." + ) + logging.info( + f"shortest: {lens[inds[0]]/args.downsample_rate} encodec codes, {lens[inds[0]]/args.model_sr:.2f} sec." + ) + logging.info( + f"median: {lens[inds[len(inds)//2]]/args.downsample_rate} encodec codes, {lens[inds[len(inds)//2]]/args.model_sr:.2f} sec." + ) + logging.info( + f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]/args.downsample_rate} encodec codes, {lens[inds[int(len(inds)*0.95)]]/args.model_sr:.2f} sec." + ) return inds[::-1] def write_array_to_txt_file(array, filename): - with open(filename, 'w') as f: + with open(filename, "w") as f: for a in array[:-1]: - f.write(' '.join(map(str, a))+'\n') - f.write(' '.join(map(str, array[-1]))) + f.write(" ".join(map(str, a)) + "\n") + f.write(" ".join(map(str, array[-1]))) class mydataset(torch.utils.data.Dataset): def __init__(self, split): super().__init__() self.split = split self.audio_dir = audio_dir - manifest_fn = os.path.join(encodec_manifest_dir, split+".txt") - cur_sp = int(args.partition.split("/")[0])-1 + manifest_fn = os.path.join(encodec_manifest_dir, split + ".txt") + cur_sp = int(args.partition.split("/")[0]) - 1 total_sp = int(args.partition.split("/")[1]) with open(manifest_fn, "r") as rf: - self.data = [l.strip().split("\t") for l in rf.readlines()][cur_sp::total_sp] - self.data = [l for l in self.data if os.path.isfile(os.path.join(self.audio_dir, l[0]))] + self.data = [l.strip().split("\t") for l in rf.readlines()][ + cur_sp::total_sp + ] + self.data = [ + l + for l in self.data + if os.path.isfile(os.path.join(self.audio_dir, l[0])) + ] + def __len__(self): return len(self.data) + def __getitem__(self, ind): try: afn = self.data[ind][0] @@ -74,8 +135,13 @@ def __getitem__(self, ind): except Exception as e: # logging.info(f"{e}") return None, None, None - assert audio.ndim==2 and audio.shape[0] == 1, audio.shape - return audio.type(torch.float32).squeeze(0), audio.shape[-1], os.path.splitext(afn)[0] + assert audio.ndim == 2 and audio.shape[0] == 1, audio.shape + return ( + audio.type(torch.float32).squeeze(0), + audio.shape[-1], + os.path.splitext(afn)[0], + ) + def collate(self, batch): lens, audios, segment_ids = [], [], [] for item in batch: @@ -89,96 +155,138 @@ def collate(self, batch): sub_root = args.sub_root encodec_manifest_dir = os.path.join(args.dir, sub_root, "manifest_mimi") audio_dir = os.path.join(args.dir, sub_root, "audio") - save_manifest_dir = os.path.join(args.dir, sub_root,"manifest_final_encodec") + save_manifest_dir = os.path.join(args.dir, sub_root, "manifest_final_encodec") if args.encodec_name == "encodec_6f79c6a8.th": - save_codes_dir = os.path.join(args.dir, sub_root,"encodec_4cb") + save_codes_dir = os.path.join(args.dir, sub_root, "encodec_4cb") elif args.encodec_name == "encodec_8cb1024_giga.th": - save_codes_dir = os.path.join(args.dir, sub_root,"encodec_8cb") + save_codes_dir = os.path.join(args.dir, sub_root, "encodec_8cb") os.makedirs(save_manifest_dir, exist_ok=True) os.makedirs(save_codes_dir, exist_ok=True) - + # load the encodec model def import_encodec(): from encodec import get_compression_model + userdir = os.path.expanduser("~") - model = get_compression_model(os.path.join(userdir, "VoiceStar", f"pretrained/{args.encodec_name}"), encode_only=True, device="cuda") + model = get_compression_model( + os.path.join(userdir, "VoiceStar", 
f"pretrained/{args.encodec_name}"), + encode_only=True, + device="cuda", + ) model = torch.nn.DataParallel(model) return model + model = import_encodec() - + # setup dataloader mega_batch_size = 1024 batch_size = args.batch_size - + dataset = mydataset(args.split) if len(dataset) == 0: logging.info(f"no data found for split {args.split} partition {args.partition}") sys.exit(0) - loader = torch.torch.utils.data.DataLoader(dataset, batch_size=mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=dataset.collate) + loader = torch.torch.utils.data.DataLoader( + dataset, + batch_size=mega_batch_size, + shuffle=False, + drop_last=False, + num_workers=args.n_workers, + collate_fn=dataset.collate, + ) split = args.split skip = 0 logging.info(f"now processing split {split} partition {args.partition}...") mega_n_steps = int(np.ceil(len(loader.dataset) / mega_batch_size)) # mega_n_steps = int(np.ceil(len(gs) / mega_batch_size)) - logging.info(f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples") - mani_fn = os.path.join(save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt") - logging.info(f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}") + logging.info( + f"partition the split {split} into {mega_n_steps} parts, each has at most {mega_batch_size} samples" + ) + mani_fn = os.path.join( + save_manifest_dir, f"{split}_{args.partition.replace('/', '=')}.txt" + ) + logging.info( + f"manifest for split {split} partition {args.partition.replace('/', '=')}.txt will be saved at {mani_fn}" + ) with open(mani_fn, "w") as mani_wf: - # with open(mani_fn, "a") as mani_wf: # resume from where we failed + # with open(mani_fn, "a") as mani_wf: # resume from where we failed for m, mega_batch in enumerate(tqdm.tqdm(loader)): logging.info(f"====================================") logging.info(f"====================================") logging.info(f"now processing mega step {m+1}/{mega_n_steps}") try: - # if True: + # if True: lengths = np.array(mega_batch[1]) - if len(lengths) == 0: # the loader might not find any audio because step3 will write to manifest first, and then might selection a subset to cut and save audio + if ( + len(lengths) == 0 + ): # the loader might not find any audio because step3 will write to manifest first, and then might selection a subset to cut and save audio continue sorted_inds = sort_by_audio_len(lengths) for j in range(len(sorted_inds))[::-1]: - if lengths[sorted_inds[j]] < args.model_sr*args.min_len or lengths[sorted_inds[j]] > args.model_sr*args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) + if ( + lengths[sorted_inds[j]] < args.model_sr * args.min_len + or lengths[sorted_inds[j]] > args.model_sr * args.len_cap + ): # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s) skip += 1 del sorted_inds[j] - + n_steps = int(np.ceil(len(sorted_inds) / batch_size)) for n in tqdm.tqdm(range(n_steps), disable=True): - inds_used = sorted_inds[n*batch_size:(n+1)*batch_size] + inds_used = sorted_inds[n * batch_size : (n + 1) * batch_size] while len(inds_used) < batch_size: - inds_used += sorted_inds[:batch_size-len(inds_used)] + inds_used += sorted_inds[: batch_size - len(inds_used)] wav_batch = [mega_batch[0][id] for id in inds_used] all_lens = [mega_batch[1][id] for id in inds_used] segment_id_batch = [mega_batch[2][id] for id in inds_used] - padded_wav = 
torch.nn.utils.rnn.pad_sequence(wav_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T] + padded_wav = torch.nn.utils.rnn.pad_sequence( + wav_batch, batch_first=True + ).unsqueeze( + 1 + ) # [B, T] -> [B, 1, T] # Extract discrete codes from EnCodec with torch.no_grad(): - if max(all_lens) > 300000 and len(all_lens) > 1: # if utterances are long, simply pass half of them at a time + if ( + max(all_lens) > 300000 and len(all_lens) > 1 + ): # if utterances are long, simply pass half of them at a time codes = [] inwav = padded_wav.cuda() - codes.append(model(inwav[:len(inwav)//2])[0].cpu()) - codes.append(model(inwav[len(inwav)//2:])[0].cpu()) + codes.append(model(inwav[: len(inwav) // 2])[0].cpu()) + codes.append(model(inwav[len(inwav) // 2 :])[0].cpu()) codes = torch.cat(codes, dim=0) else: - encoded_frames = model(padded_wav.cuda()) - codes = encoded_frames[0].cpu() # [B, n_codebook, T] + encoded_frames = model(padded_wav.cuda()) + codes = encoded_frames[0].cpu() # [B, n_codebook, T] for i, length in enumerate(all_lens): - save_fn = os.path.join(save_codes_dir, segment_id_batch[i]+".txt") - actual_len = round(length / args.downsample_rate) # 320 is downsample rate for this model - cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist() + save_fn = os.path.join( + save_codes_dir, segment_id_batch[i] + ".txt" + ) + actual_len = round( + length / args.downsample_rate + ) # 320 is downsample rate for this model + cur_code = ( + codes[i].tolist() + if type(codes) == list + else codes[i, :, :actual_len].tolist() + ) os.makedirs(os.path.dirname(save_fn), exist_ok=True) write_array_to_txt_file(cur_code, save_fn) - mani_wf.write(f"{segment_id_batch[i]}\t{len(cur_code[0])}\n") # write to manifest file + mani_wf.write( + f"{segment_id_batch[i]}\t{len(cur_code[0])}\n" + ) # write to manifest file # if i == 10: # raise except Exception as e: - print(f'exception!! at {m+1}') + print(f"exception!! 
at {m+1}") print(e) continue # break - logging.info(f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short") - # break \ No newline at end of file + logging.info( + f"split {split} partition {args.partition} has {len(loader.dataset)} samples in total, skipped {skip} due to utterance being too long or too short" + ) + # break diff --git a/data/ll60k_preprocessing/step5_find_nearest_neighbor.py b/data/ll60k_preprocessing/step5_find_nearest_neighbor.py index 9118941..8cc195b 100644 --- a/data/ll60k_preprocessing/step5_find_nearest_neighbor.py +++ b/data/ll60k_preprocessing/step5_find_nearest_neighbor.py @@ -18,44 +18,50 @@ import tqdm import json import tqdm + + def write_jsonl(data, fn): with open(fn, "w") as file: for entry in data: file.write(json.dumps(entry, ensure_ascii=False) + "\n") + + def read_jsonl(file_path): cur_data = [] - with open(file_path, 'r', encoding='utf-8-sig') as file: + with open(file_path, "r", encoding="utf-8-sig") as file: for line in file: cur_data.append(json.loads(line.strip())) return cur_data + + from collections import defaultdict + + # Function to create a defaultdict recursively def nested_defaultdict(levels, inner_type): if levels <= 1: return defaultdict(inner_type) - return defaultdict(lambda: nested_defaultdict(levels-1, inner_type)) + return defaultdict(lambda: nested_defaultdict(levels - 1, inner_type)) + def find_neighbor(args): split2manifest = { - "train": [ - "libriheavy_cuts_small.jsonl", - "libriheavy_cuts_medium.jsonl", - "libriheavy_cuts_large.jsonl", - "libriheavy_long_cuts_small.jsonl", - "libriheavy_long_cuts_medium.jsonl", - "libriheavy_long_cuts_large.jsonl" - ], - "valid": [ - "libriheavy_cuts_dev.jsonl", - "libriheavy_long_cuts_dev.jsonl" - ], - "test": [ - "libriheavy_cuts_test_clean.jsonl", - "libriheavy_cuts_test_other.jsonl", - "libriheavy_long_cuts_test_clean.jsonl", - "libriheavy_long_cuts_test_other.jsonl" - ] - } + "train": [ + "libriheavy_cuts_small.jsonl", + "libriheavy_cuts_medium.jsonl", + "libriheavy_cuts_large.jsonl", + "libriheavy_long_cuts_small.jsonl", + "libriheavy_long_cuts_medium.jsonl", + "libriheavy_long_cuts_large.jsonl", + ], + "valid": ["libriheavy_cuts_dev.jsonl", "libriheavy_long_cuts_dev.jsonl"], + "test": [ + "libriheavy_cuts_test_clean.jsonl", + "libriheavy_cuts_test_other.jsonl", + "libriheavy_long_cuts_test_clean.jsonl", + "libriheavy_long_cuts_test_other.jsonl", + ], + } stime = time.time() organized_data = nested_defaultdict(4, list) @@ -64,15 +70,21 @@ def find_neighbor(args): mani_full_fn = os.path.join(args.manifest_dir, mani_fn) data = read_jsonl(mani_full_fn) for item in data: - file_id = item['supervisions'][0]['id'] + '.flac' - recording_id = item['recording']['id'] + '.flac' - sizeSplit, spk, book, flac = recording_id.split("/") # e.g. 'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb' + file_id = item["supervisions"][0]["id"] + ".flac" + recording_id = item["recording"]["id"] + ".flac" + sizeSplit, spk, book, flac = recording_id.split( + "/" + ) # e.g. 
'medium/100/emerald_city_librivox_64kb_mp3/emeraldcity_01_baum_64kb' if os.path.isfile(os.path.join(args.audio_dir, recording_id)): - vad = (item['start'], item['start']+item['duration']) - text = item['supervisions'][0]['custom']['texts'][0] - file_id = file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac" - organized_data[sizeSplit][spk][book][recording_id].append({"file_id": file_id, "vad":vad, "text": text}) - + vad = (item["start"], item["start"] + item["duration"]) + text = item["supervisions"][0]["custom"]["texts"][0] + file_id = ( + file_id.replace(".flac", "") + f"_{vad[0]:.2f}_{vad[1]:.2f}.flac" + ) + organized_data[sizeSplit][spk][book][recording_id].append( + {"file_id": file_id, "vad": vad, "text": text} + ) + # # for each recording_id, find the non-overlapping neighboring segments based on vad # for sizeSplit in organized_data: # for spk in organized_data[sizeSplit]: @@ -98,33 +110,70 @@ def find_neighbor(args): # f.write(f"{neighbor}\t{dist}\n") # use multiprocessing.Pool for the above - segments = [organized_data[sizeSplit][spk][book][recording_id] for sizeSplit in organized_data for spk in organized_data[sizeSplit] for book in organized_data[sizeSplit][spk] for recording_id in organized_data[sizeSplit][spk][book]] + segments = [ + organized_data[sizeSplit][spk][book][recording_id] + for sizeSplit in organized_data + for spk in organized_data[sizeSplit] + for book in organized_data[sizeSplit][spk] + for recording_id in organized_data[sizeSplit][spk][book] + ] # only keep those that are exist print(f"originally total {len(segments)} segments") - segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]),"audio", seg[0]['file_id']))] + segments = [ + seg + for seg in segments + if os.path.isfile( + os.path.join( + "/".join(args.output_dir.split("/")[:-1]), "audio", seg[0]["file_id"] + ) + ) + ] print(f"after check existance, total {len(segments)} segments") print(f"organizing took {(time.time()-stime)/60:.2f} minutes") with multiprocessing.Pool(processes=args.n_workers) as pool: - for _ in tqdm.tqdm(pool.imap_unordered(find_neighbor_each, segments), total=len(segments)): + for _ in tqdm.tqdm( + pool.imap_unordered(find_neighbor_each, segments), total=len(segments) + ): pass + # audio_root = "/data/scratch/pyp/datasets/librilight/preprocessed/audio" def find_neighbor_each(segments): # for each recording_id, find the non-overlapping neighboring segments based on vad # only keep segments that have audio files # actually only keep segments that have ipa_alignment files - segments = [seg for seg in segments if os.path.isfile(os.path.join("/".join(args.output_dir.split("/")[:-1]),"ipa_alignment", seg['file_id'].replace(".flac", ".txt")))] + segments = [ + seg + for seg in segments + if os.path.isfile( + os.path.join( + "/".join(args.output_dir.split("/")[:-1]), + "ipa_alignment", + seg["file_id"].replace(".flac", ".txt"), + ) + ) + ] if len(segments) <= 1: return for i in range(len(segments)): # for segment i, find the non-overlapping neighboring segments - write_fn = os.path.join(args.output_dir, f"{segments[i]['file_id'].replace('.flac', '.txt')}") + write_fn = os.path.join( + args.output_dir, f"{segments[i]['file_id'].replace('.flac', '.txt')}" + ) neighbors = [] distance = [] for j in range(len(segments)): - if segments[i]['vad'][1] < segments[j]['vad'][0] or segments[i]['vad'][0] > segments[j]['vad'][0]: + if ( + segments[i]["vad"][1] < segments[j]["vad"][0] + or segments[i]["vad"][0] > segments[j]["vad"][0] + ): 
neighbors.append(segments[j]) - distance.append(min(abs(segments[i]['vad'][1] - segments[j]['vad'][0]), abs(segments[i]['vad'][0] - segments[j]['vad'][1]))) + distance.append( + min( + abs(segments[i]["vad"][1] - segments[j]["vad"][0]), + abs(segments[i]["vad"][0] - segments[j]["vad"][1]), + ) + ) if len(neighbors) == 0: continue # order neighbors by distance @@ -134,24 +183,47 @@ def find_neighbor_each(segments): with open(write_fn, "w") as f: # note that there might be no neighbors, in which case the file is empty for neighbor, dist in neighbors_distance: - f.write(f"{neighbor['file_id'].replace('.flac', '.txt')}\t{dist}\t{neighbor['vad'][1] - neighbor['vad'][0]}\n") # file_id, distance, duration - + f.write( + f"{neighbor['file_id'].replace('.flac', '.txt')}\t{dist}\t{neighbor['vad'][1] - neighbor['vad'][0]}\n" + ) # file_id, distance, duration def parse_args(): - parser = argparse.ArgumentParser(description="Cut a dataset in small " - "sequences using VAD files") - parser.add_argument('--split', type=str, default='train', choices=['train', 'valid', 'test'], help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz") - parser.add_argument('--audio_dir', type=str, default="/data/scratch/pyp/datasets/librilight_example", - help="Path to the audio directory") - parser.add_argument('--manifest_dir', type=str, default="/data/scratch/pyp/datasets/librilight/libriheavy", help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1") - parser.add_argument('--output_dir', type=str, default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed/neighbors", - help="Path to the output directory") - parser.add_argument('--n_workers', type=int, default=16, - help="Number of parallel worker processes") + parser = argparse.ArgumentParser( + description="Cut a dataset in small " "sequences using VAD files" + ) + parser.add_argument( + "--split", + type=str, + default="train", + choices=["train", "valid", "test"], + help="train = libriheavy_cuts_{small,medium,large}.jsonl.gz, valid = libriheavy_cuts_dev_{clean,other}.jsonl.gz, test = libriheavy_cuts_test_{clean,other}.jsonl.gz", + ) + parser.add_argument( + "--audio_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight_example", + help="Path to the audio directory", + ) + parser.add_argument( + "--manifest_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight/libriheavy", + help="path to the transcription file's dir, can be downloaded https://huggingface.co/datasets/pkufool/libriheavy/tree/main/v0.1", + ) + parser.add_argument( + "--output_dir", + type=str, + default="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed/neighbors", + help="Path to the output directory", + ) + parser.add_argument( + "--n_workers", type=int, default=16, help="Number of parallel worker processes" + ) return parser.parse_args() + if __name__ == "__main__": args = parse_args() pathlib.Path(args.output_dir).mkdir(exist_ok=True, parents=True) - find_neighbor(args) \ No newline at end of file + find_neighbor(args) diff --git a/data/ll60k_preprocessing/step6_forced_alignment.py b/data/ll60k_preprocessing/step6_forced_alignment.py index 2606cac..31f700c 100644 --- a/data/ll60k_preprocessing/step6_forced_alignment.py +++ b/data/ll60k_preprocessing/step6_forced_alignment.py @@ -2,25 +2,46 @@ import subprocess, tqdm from concurrent.futures import 
ThreadPoolExecutor + def align_folders(audio_root, subfolder, subsubfolder): # Construct output folder path file_root = os.path.dirname(audio_root) out_folder = f"{file_root}/alignment/{subfolder}/{subsubfolder}" - + # Create the output directory os.makedirs(out_folder, exist_ok=True) - + # Construct the MFA align command command = [ - "mfa", "align", "--single_speaker", "-j", "8", "--clean", - f"{audio_root}/{subfolder}/{subsubfolder}", "english_us_arpa", "english_us_arpa", - out_folder, "--beam", "50", "--retry_beam", "400", "--output_format", "csv" + "mfa", + "align", + "--single_speaker", + "-j", + "8", + "--clean", + f"{audio_root}/{subfolder}/{subsubfolder}", + "english_us_arpa", + "english_us_arpa", + out_folder, + "--beam", + "50", + "--retry_beam", + "400", + "--output_format", + "csv", ] - + # Run the command subprocess.run(command, check=True) -def main(file_root = "/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", max_parallel_jobs=10, max_spk=100, partition="1/10", n_workers=64): + +def main( + file_root="/data/scratch/pyp/datasets/librilight/librilight_example_preprocessed", + max_parallel_jobs=10, + max_spk=100, + partition="1/10", + n_workers=64, +): # Find all subfolder/subsubfolder combinations tasks = [] audio_root = os.path.join(file_root, "audio") @@ -34,18 +55,28 @@ def main(file_root = "/data/scratch/pyp/datasets/librilight/librilight_example_p speaker_folder_map = {} for audio_root, subfolder, subsubfolder in tasks: if os.path.join(audio_root, subfolder) not in speaker_folder_map: - speaker_folder_map[os.path.join(audio_root, subfolder)] = [os.path.join(audio_root, subfolder, subsubfolder)] + speaker_folder_map[os.path.join(audio_root, subfolder)] = [ + os.path.join(audio_root, subfolder, subsubfolder) + ] else: - speaker_folder_map[os.path.join(audio_root, subfolder)].append(os.path.join(audio_root, subfolder, subsubfolder)) + speaker_folder_map[os.path.join(audio_root, subfolder)].append( + os.path.join(audio_root, subfolder, subsubfolder) + ) speaker_folder_partitions = [] for audio_root_subfolder, speaker_folders in speaker_folder_map.items(): - speaker_folder_partitions.extend([speaker_folders[i:i+max_spk] for i in range(0, len(speaker_folders), max_spk)]) + speaker_folder_partitions.extend( + [ + speaker_folders[i : i + max_spk] + for i in range(0, len(speaker_folders), max_spk) + ] + ) s, e = partition.split("/") - s, e = int(s)-1, int(e) + s, e = int(s) - 1, int(e) cur_tasks = speaker_folder_partitions[s::e] import secrets, string import soundfile, glob from joblib import Parallel, delayed + def delete_corrupted(fn): try: x = soundfile.read(fn) @@ -59,19 +90,29 @@ def delete_corrupted(fn): # assert that all subs are the same assert len(set(subs)) == 1, subs sub = subs[0] - # randomly generate a foldername + # randomly generate a foldername # generate a random character # make softlink from item in task to temp folder - random_string = ''.join(secrets.choice(string.ascii_letters + string.digits) for i in range(10)) + random_string = "".join( + secrets.choice(string.ascii_letters + string.digits) for i in range(10) + ) temp_folder = os.path.join(file_root, "softlink_audio", random_string) os.makedirs(temp_folder, exist_ok=True) out_folder = f"{file_root}/alignment/{sub}" - all_out_speaker_folders = [os.path.join(out_folder, os.path.basename(item)) for item in task] - if sum(os.path.isdir(curpath) for curpath in all_out_speaker_folders) == len(all_out_speaker_folders): + all_out_speaker_folders = [ + os.path.join(out_folder, 
os.path.basename(item)) for item in task + ] + if sum(os.path.isdir(curpath) for curpath in all_out_speaker_folders) == len( + all_out_speaker_folders + ): continue # remove audio files that are corrupted - all_audio_files = [audiofile for item in task for audiofile in glob.glob(item+"/*/*.flac")] - Parallel(n_jobs=n_workers)(delayed(delete_corrupted)(audiofn) for audiofn in all_audio_files) + all_audio_files = [ + audiofile for item in task for audiofile in glob.glob(item + "/*/*.flac") + ] + Parallel(n_jobs=n_workers)( + delayed(delete_corrupted)(audiofn) for audiofn in all_audio_files + ) for item in task: # make softlink from subsubfolder to a new folder in temp folder os.symlink(item, os.path.join(temp_folder, os.path.basename(item))) @@ -81,6 +122,8 @@ def delete_corrupted(fn): # delete the temp_folder os.system(f"rm -r {temp_folder}") + if __name__ == "__main__": import fire - fire.Fire(main) \ No newline at end of file + + fire.Fire(main) diff --git a/data/ll60k_preprocessing/step7_ipa_alignment.py b/data/ll60k_preprocessing/step7_ipa_alignment.py index 137c530..0dec3f5 100644 --- a/data/ll60k_preprocessing/step7_ipa_alignment.py +++ b/data/ll60k_preprocessing/step7_ipa_alignment.py @@ -1,6 +1,6 @@ -# we have raw transcript at +# we have raw transcript at # /data/scratch/pyp/datasets/librilight/preprocessed/audio -# we have word and ARPA alignment at +# we have word and ARPA alignment at # /data/scratch/pyp/datasets/librilight/preprocessed/alignment # we have manifest at /data/scratch/pyp/datasets/librilight/preprocessed/manifest_mimi @@ -15,12 +15,23 @@ def remove_punctuation(input_string): - translator = str.maketrans('', '', string.punctuation) + translator = str.maketrans("", "", string.punctuation) return input_string.translate(translator) - -def create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, save=False, prompt_dur=30): +def create_alignment( + fn, + trans_dir, + align_dir, + audio_ext, + trans_ext, + arpa_ext, + text_tokenizer, + use_prob, + ipa_alignment_fn, + save=False, + prompt_dur=30, +): os.makedirs(os.path.dirname(ipa_alignment_fn), exist_ok=True) trans_fn = os.path.join(trans_dir, fn.replace(audio_ext, trans_ext)) if not os.path.isfile(trans_fn): @@ -29,13 +40,13 @@ def create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, t if not os.path.isfile(align_fn): return [], True # get raw transcript - with open(trans_fn, 'r') as f: + with open(trans_fn, "r") as f: transcript = f.read().strip() raw_word_list = transcript.split(" ") # get word alignment - with open(align_fn, 'r') as f: + with open(align_fn, "r") as f: word_alignment = csv.reader(f) - word_alignment = [row for row in word_alignment if row[3]=='words'] + word_alignment = [row for row in word_alignment if row[3] == "words"] ipa_alignment = [] @@ -48,67 +59,83 @@ def create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, t # print(f"word from alignment csv: {word}, word from txt: {raw_word}") return ipa_alignment, True if random.random() < use_prob: - cur_words = " ".join(raw_word_list[:j+1]) + cur_words = " ".join(raw_word_list[: j + 1]) phn = tokenize_text(text_tokenizer, cur_words) if len(phn) == 0: continue phn = " ".join(phn) - start = 0 # at this point, we always start from the beginning of the sentence + start = ( + 0 # at this point, we always start from the beginning of the sentence + ) ipa_alignment.append([start, end, phn]) if save: if ipa_alignment: - with open(ipa_alignment_fn, 'w') as f: 
+ with open(ipa_alignment_fn, "w") as f: for item in ipa_alignment: f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n") else: return ipa_alignment, False - def main( - data_root: str = '/data/scratch/pyp/datasets/librilight/preprocessed', - audio_ext: str = '.flac', - arpa_ext: str = '.csv', - trans_ext: str = '.txt', - split: str = 'valid', + data_root: str = "/data/scratch/pyp/datasets/librilight/preprocessed", + audio_ext: str = ".flac", + arpa_ext: str = ".csv", + trans_ext: str = ".txt", + split: str = "valid", use_prob: float = 0.5, - max_dur: float = 30., # do not consider utterance longer than this - prompt_dur: float = 30., # do not consider prompt longer than this + max_dur: float = 30.0, # do not consider utterance longer than this + prompt_dur: float = 30.0, # do not consider prompt longer than this ): text_tokenizer = TextTokenizer() - trans_dir = f'{data_root}/audio' - align_dir = f'{data_root}/alignment' + trans_dir = f"{data_root}/audio" + align_dir = f"{data_root}/alignment" manifest_fn = f"{data_root}/manifest_final_encodec/{split}*=*.txt" manifest_fns = glob.glob(manifest_fn) - target_dir = f'{data_root}/ipa_alignment' + target_dir = f"{data_root}/ipa_alignment" encodec_sr = 50 os.makedirs(target_dir, exist_ok=True) manifest = [] for manifest_fn in manifest_fns: - with open(manifest_fn, 'r') as f: + with open(manifest_fn, "r") as f: temp = [l.strip().split("\t") for l in f.readlines()] - manifest += [l[0] + audio_ext for l in temp if float(l[1])/encodec_sr < max_dur] + manifest += [ + l[0] + audio_ext for l in temp if float(l[1]) / encodec_sr < max_dur + ] # # sequential processing n_flags = 0 zero_words = 0 for j, fn in enumerate(tqdm.tqdm(manifest)): - ipa_alignment_fn = os.path.join(target_dir, fn.replace(audio_ext, '.txt')) - ipa_alignment, flag = create_alignment(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, ipa_alignment_fn, prompt_dur=prompt_dur) + ipa_alignment_fn = os.path.join(target_dir, fn.replace(audio_ext, ".txt")) + ipa_alignment, flag = create_alignment( + fn, + trans_dir, + align_dir, + audio_ext, + trans_ext, + arpa_ext, + text_tokenizer, + use_prob, + ipa_alignment_fn, + prompt_dur=prompt_dur, + ) n_flags += flag if not ipa_alignment: zero_words += 1 # print(f"{n_flags} out of {j+1} utterances have mismatched words") # print(f"{zero_words} out of {j+1} utterances have zero words") if ipa_alignment: - with open(ipa_alignment_fn, 'w') as f: + with open(ipa_alignment_fn, "w") as f: for item in ipa_alignment: f.write(f"{item[0]}\t{item[1]}\t{item[2]}\n") - + # # # # do the above using joblib parallisim # print(f"Processing {len(manifest)} utterances") # from joblib import Parallel, delayed # Parallel(n_jobs=32, verbose=2)(delayed(create_alignment)(fn, trans_dir, align_dir, audio_ext, trans_ext, arpa_ext, text_tokenizer, use_prob, os.path.join(target_dir, fn.replace(audio_ext, '.txt')), save=True) for fn in manifest) - + + if __name__ == "__main__": import fire - fire.Fire(main) \ No newline at end of file + + fire.Fire(main) diff --git a/data/ll60k_preprocessing/tokenizer.py b/data/ll60k_preprocessing/tokenizer.py index ad9f114..cd39c6d 100644 --- a/data/ll60k_preprocessing/tokenizer.py +++ b/data/ll60k_preprocessing/tokenizer.py @@ -20,6 +20,7 @@ import numpy as np import torch import torchaudio + # from encodec import EncodecModel # from encodec.utils import convert_audio # from lhotse.features import FeatureExtractor @@ -62,9 +63,7 @@ def phonemize( phones = [] if self.backend == "pypinyin": for n, py in enumerate( - 
pinyin( - _text, style=Style.TONE3, neutral_tone_with_five=True - ) + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) ): if all([c in self.punctuation_marks for c in py[0]]): if len(phones): @@ -76,9 +75,7 @@ def phonemize( phones.extend([py[0], separator.syllable]) elif self.backend == "pypinyin_initials_finals": for n, py in enumerate( - pinyin( - _text, style=Style.TONE3, neutral_tone_with_five=True - ) + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) ): if all([c in self.punctuation_marks for c in py[0]]): if len(phones): @@ -89,10 +86,7 @@ def phonemize( if py[0][-1].isalnum(): initial = get_initials(py[0], strict=False) if py[0][-1].isdigit(): - final = ( - get_finals(py[0][:-1], strict=False) - + py[0][-1] - ) + final = get_finals(py[0][:-1], strict=False) + py[0][-1] else: final = get_finals(py[0], strict=False) phones.extend( @@ -155,8 +149,7 @@ def to_list(self, phonemized: str) -> List[str]: # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) fields.extend( - [p for p in pp if p != self.separator.phone] - + [self.separator.word] + [p for p in pp if p != self.separator.phone] + [self.separator.word] ) assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( self.separator.phone @@ -182,6 +175,7 @@ def remove_encodec_weight_norm(model): from encodec.modules import SConv1d from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock from torch.nn.utils import remove_weight_norm + encoder = model.encoder.model for key in encoder._modules: if isinstance(encoder._modules[key], SEANetResnetBlock): @@ -326,7 +320,7 @@ def remove_encodec_weight_norm(model): # # ret = self.codec.cpu().decode([(frames[0][0].cpu(),None)])[0].to(self._device) # # self.codec.to(self._device) # # return [ret] - + # def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1): # # Load and pre-process the audio waveform # if offset != -1 and num_frames!=-1: @@ -457,4 +451,4 @@ def remove_encodec_weight_norm(model): remove_encodec_weight_norm(model) codes_raw = model.encode(samples.cuda()) - assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) \ No newline at end of file + assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) diff --git a/data/tokenizer.py b/data/tokenizer.py index 46bcaa0..ce6b9ff 100644 --- a/data/tokenizer.py +++ b/data/tokenizer.py @@ -16,10 +16,11 @@ import re from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Pattern, Union - +from voicestar.data.encodec import get_compression_model import numpy as np import torch import torchaudio + # from encodec import EncodecModel # from encodec.utils import convert_audio # from lhotse.features import FeatureExtractor @@ -63,9 +64,7 @@ def phonemize( phones = [] if self.backend == "pypinyin": for n, py in enumerate( - pinyin( - _text, style=Style.TONE3, neutral_tone_with_five=True - ) + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) ): if all([c in self.punctuation_marks for c in py[0]]): if len(phones): @@ -77,9 +76,7 @@ def phonemize( phones.extend([py[0], separator.syllable]) elif self.backend == "pypinyin_initials_finals": for n, py in enumerate( - pinyin( - _text, style=Style.TONE3, neutral_tone_with_five=True - ) + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) ): if all([c in self.punctuation_marks for c in py[0]]): if len(phones): @@ -90,10 +87,7 @@ def phonemize( if py[0][-1].isalnum(): initial = get_initials(py[0], strict=False) if py[0][-1].isdigit(): - 
final = ( - get_finals(py[0][:-1], strict=False) - + py[0][-1] - ) + final = get_finals(py[0][:-1], strict=False) + py[0][-1] else: final = get_finals(py[0], strict=False) phones.extend( @@ -156,8 +150,7 @@ def to_list(self, phonemized: str) -> List[str]: # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) fields.extend( - [p for p in pp if p != self.separator.phone] - + [self.separator.word] + [p for p in pp if p != self.separator.phone] + [self.separator.word] ) assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( self.separator.phone @@ -183,6 +176,7 @@ def remove_encodec_weight_norm(model): from encodec.modules import SConv1d from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock from torch.nn.utils import remove_weight_norm + encoder = model.encoder.model for key in encoder._modules: if isinstance(encoder._modules[key], SEANetResnetBlock): @@ -212,14 +206,14 @@ class AudioTokenizer: def __init__( self, - bandwidth: float=6.0, + bandwidth: float = 6.0, device: Any = None, hificodec=False, - signature = None, - encode_only = False + signature=None, + encode_only=False, ) -> None: self.signature = signature - from data.encodec import get_compression_model + model = get_compression_model(signature, encode_only=encode_only, device=device) self.sample_rate = model.sample_rate self.channels = model.channels @@ -240,35 +234,44 @@ def device(self): def encode(self, wav: torch.Tensor) -> torch.Tensor: if self.signature != None: if self.signature == "lfsc": - if wav.ndim==3: - assert wav.shape[:2] == torch.Size((1,1)), wav.shape + if wav.ndim == 3: + assert wav.shape[:2] == torch.Size((1, 1)), wav.shape wav = wav.squeeze(0) - elif wav.ndim==2: + elif wav.ndim == 2: assert wav.shape[0] == 1, wav.shape else: raise ValueError(wav.shape) audio_len = torch.tensor([wav.shape[1]]).to(self.device) - codes, encoded_len = self.codec.encode(audio=wav.to(self.device), audio_len=audio_len) - return codes[:, :, :encoded_len[0]] + codes, encoded_len = self.codec.encode( + audio=wav.to(self.device), audio_len=audio_len + ) + return codes[:, :, : encoded_len[0]] else: codes = self.codec.encode(wav.to(self.device)) return codes[0] else: - assert wav.ndim==3 and wav.shape[:2] == torch.Size((1,1)), wav.shape + assert wav.ndim == 3 and wav.shape[:2] == torch.Size((1, 1)), wav.shape return self.codec.encode(wav.to(self.device)) def decode(self, frames: torch.Tensor) -> torch.Tensor: if self.signature != None and self.signature == "lfsc": encoded_len = torch.tensor([frames.shape[-1]]).to(self.device) - reconstructed_audio, decoded_len = self.codec.decode(tokens=frames, tokens_len=encoded_len) - return reconstructed_audio[:, :decoded_len[0]].unsqueeze(0) + reconstructed_audio, decoded_len = self.codec.decode( + tokens=frames, tokens_len=encoded_len + ) + return reconstructed_audio[:, : decoded_len[0]].unsqueeze(0) else: return self.codec.decode(frames) - -def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1): + + +def tokenize_audio( + tokenizer: AudioTokenizer, audio_path: str, offset=-1, num_frames=-1 +): # Load and pre-process the audio waveform - if offset != -1 and num_frames!=-1: - wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames) + if offset != -1 and num_frames != -1: + wav, sr = torchaudio.load( + audio_path, frame_offset=offset, num_frames=num_frames + ) else: wav, sr = torchaudio.load(audio_path) if sr != tokenizer.sample_rate: @@ -285,11 +288,17 @@ def 
tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_
 if __name__ == "__main__":
     # tok = AudioTokenizer(signature="lfsc", device="cpu")
-    tok = AudioTokenizer(signature="/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th", device="cpu")
+    tok = AudioTokenizer(
+        signature="/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th",
+        device="cpu",
+    )
     inaudio = "/home/pyp/BoostedVoiceEditor/demo/pam.wav"
     encoded_frames = tokenize_audio(tok, inaudio)
     print(encoded_frames.shape)
     # decode it back
     decoded_audio = tok.decode(encoded_frames)
-    torchaudio.save("/home/pyp/BoostedVoiceEditor/demo/pam_reconstructed_encodec_4cb_2nd.wav", decoded_audio[0], tok.sample_rate)
-
+    torchaudio.save(
+        "/home/pyp/BoostedVoiceEditor/demo/pam_reconstructed_encodec_4cb_2nd.wav",
+        decoded_audio[0],
+        tok.sample_rate,
+    )
diff --git a/docs/training.md b/docs/training.md
new file mode 100644
index 0000000..0e5957c
--- /dev/null
+++ b/docs/training.md
@@ -0,0 +1,66 @@
+# Training
+
+## Setup Environment
+
+First, set up the environment with the inference requirements:
+
+```bash
+conda create -n voicestar python=3.10
+conda activate voicestar # this seems to lead to much worse results in terms of wer and spksim (comparing e9_rerun and e9_rerun_newba_upgraded)
+pip install torch==2.5.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
+pip install numpy tqdm fire
+pip install phonemizer==3.2.1
+apt-get install espeak-ng # backend for the phonemizer
+pip install torchmetrics
+pip install einops
+pip install omegaconf==2.3.0
+pip install openai-whisper
+pip install gradio
+```
+
+
+* To avoid warnings like
+[WARNING] words_mismatch.py:88 || words count mismatch on 200.0% of the lines (2/1)
+```python
+# go to ~/miniconda3/envs/voicestar/lib/python3.10/site-packages/phonemizer/backend/espeak/words_mismatch.py
+# pass the warning like this
+    def _resume(self, nmismatch: int, nlines: int):
+        """Logs a high level undetailed warning"""
+        pass
+        # if nmismatch:
+        #     self._logger.warning(
+        #         'words count mismatch on %s%% of the lines (%s/%s)',
+        #         round(nmismatch / nlines, 2) * 100, nmismatch, nlines)
+```
+
+
+Install the additional packages required for training and data processing:
+
+```bash
+pip install huggingface_hub
+pip install datasets
+pip install tensorboard
+pip install wandb
+pip install matplotlib
+pip install ffmpeg-python
+pip install scipy
+pip install soundfile
+```
+
+## Download Models
+
+If you are training, you may need to download models manually:
+
+```bash
+# under VoiceStar root dir
+mkdir pretrained
+wget -O ./pretrained/encodec_6f79c6a8.th "https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th"
+wget -O ./pretrained/VoiceStar_840M_30s.pth "https://huggingface.co/pyp1/VoiceStar/resolve/main/VoiceStar_840M_30s.pth"
+wget -O ./pretrained/VoiceStar_840M_40s.pth "https://huggingface.co/pyp1/VoiceStar/resolve/main/VoiceStar_840M_40s.pth"
+```
+
+## Training
+
+TODO: Finish training docs
+
+Training scripts can be found in the `train` folder. The data processing scripts can be found in the `data` folder. An example training script can be found in `scripts/e1_840M_30s.sh`. 
\ No newline at end of file diff --git a/generated_tts/generated.wav b/generated_tts/generated.wav deleted file mode 100644 index a04b7e1..0000000 Binary files a/generated_tts/generated.wav and /dev/null differ diff --git a/inference_commandline.py b/inference_commandline.py deleted file mode 100644 index a0429d5..0000000 --- a/inference_commandline.py +++ /dev/null @@ -1,192 +0,0 @@ -import os -import torch -import torchaudio -import numpy as np -import random -import whisper -import fire -from argparse import Namespace - -from data.tokenizer import ( - AudioTokenizer, - TextTokenizer, -) - -from models import voice_star -from inference_tts_utils import inference_one_sample - -############################################################ -# Utility Functions -############################################################ - -def seed_everything(seed=1): - os.environ['PYTHONHASHSEED'] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - -def estimate_duration(ref_audio_path, text): - """ - Estimate duration based on seconds per character from the reference audio. - """ - info = torchaudio.info(ref_audio_path) - audio_duration = info.num_frames / info.sample_rate - length_text = max(len(text), 1) - spc = audio_duration / length_text # seconds per character - return len(text) * spc - -############################################################ -# Main Inference Function -############################################################ - -def run_inference( - reference_speech="./demo/5895_34622_000026_000002.wav", - target_text="I cannot believe that the same model can also do text to speech synthesis too! And you know what? this audio is 8 seconds long.", - # Model - model_name="VoiceStar_840M_30s", # or VoiceStar_840M_40s, the later model is trained on maximally 40s long speech - model_root="./pretrained", - # Additional optional - reference_text=None, # if None => run whisper on reference_speech - target_duration=None, # if None => estimate from reference_speech and target_text - # Default hyperparameters from snippet - codec_audio_sr=16000, # do not change - codec_sr=50, # do not change - top_k=10, # try 10, 20, 30, 40 - top_p=1, # do not change - min_p=1, # do not change - temperature=1, - silence_tokens=None, # do not change it - kvcache=1, # if OOM, set to 0 - multi_trial=None, # do not change it - repeat_prompt=1, # increase this to improve speaker similarity, but it reference speech duration in total adding target duration is longer than maximal training duration, quality may drop - stop_repetition=3, # will not use it - sample_batch_size=1, # do not change - # Others - seed=1, - output_dir="./generated_tts", - # Some snippet-based defaults - cut_off_sec=100, # do not adjust this, we always use the entire reference speech. If you wish to change, also make sure to change the reference_transcript, so that it's only the trasnscript of the speech remained -): - """ - Inference script using Fire. - - Example: - python inference_commandline.py \ - --reference_speech "./demo/5895_34622_000026_000002.wav" \ - --target_text "I cannot believe ... this audio is 10 seconds long." 
\ - --reference_text "(optional) text to use as prefix" \ - --target_duration (optional float) - """ - - # Seed everything - seed_everything(seed) - - # Load model, phn2num, and args - torch.serialization.add_safe_globals([Namespace]) - device = "cuda" if torch.cuda.is_available() else "cpu" - ckpt_fn = os.path.join(model_root, model_name+".pth") - if not os.path.exists(ckpt_fn): - # use wget to download - print(f"[Info] Downloading {model_name} checkpoint...") - os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}") - bundle = torch.load(ckpt_fn, map_location=device, weights_only=True) - args = bundle["args"] - phn2num = bundle["phn2num"] - model = voice_star.VoiceStar(args) - model.load_state_dict(bundle["model"]) - model.to(device) - model.eval() - - # If reference_text not provided, use whisper large-v3-turbo - if reference_text is None: - print("[Info] No reference_text provided, transcribing reference_speech with Whisper.") - wh_model = whisper.load_model("large-v3-turbo") - result = wh_model.transcribe(reference_speech) - prefix_transcript = result["text"] - print(f"[Info] Whisper transcribed text: {prefix_transcript}") - else: - prefix_transcript = reference_text - - # If target_duration not provided, estimate from reference speech + target_text - if target_duration is None: - target_generation_length = estimate_duration(reference_speech, target_text) - print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration.") - else: - target_generation_length = float(target_duration) - - # signature from snippet - if args.n_codebooks == 4: - signature = "./pretrained/encodec_6f79c6a8.th" - elif args.n_codebooks == 8: - signature = "./pretrained/encodec_8cb1024_giga.th" - else: - # fallback, just use the 6-f79c6a8 - signature = "./pretrained/encodec_6f79c6a8.th" - - if silence_tokens is None: - # default from snippet - silence_tokens = [] - - if multi_trial is None: - # default from snippet - multi_trial = [] - - delay_pattern_increment = args.n_codebooks + 1 # from snippet - - # We can compute prompt_end_frame if we want, from snippet - info = torchaudio.info(reference_speech) - prompt_end_frame = int(cut_off_sec * info.sample_rate) - - # Prepare tokenizers - audio_tokenizer = AudioTokenizer(signature=signature) - text_tokenizer = TextTokenizer(backend="espeak") - - # decode_config from snippet - decode_config = { - 'top_k': top_k, - 'top_p': top_p, - 'min_p': min_p, - 'temperature': temperature, - 'stop_repetition': stop_repetition, - 'kvcache': kvcache, - 'codec_audio_sr': codec_audio_sr, - 'codec_sr': codec_sr, - 'silence_tokens': silence_tokens, - 'sample_batch_size': sample_batch_size - } - - # Run inference - print("[Info] Running TTS inference...") - concated_audio, gen_audio = inference_one_sample( - model, args, phn2num, text_tokenizer, audio_tokenizer, - reference_speech, target_text, - device, decode_config, - prompt_end_frame=prompt_end_frame, - target_generation_length=target_generation_length, - delay_pattern_increment=delay_pattern_increment, - prefix_transcript=prefix_transcript, - multi_trial=multi_trial, - repeat_prompt=repeat_prompt, - ) - - # The model returns a list of waveforms, pick the first - concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() - - # Save the audio (just the generated portion, as the snippet does) - os.makedirs(output_dir, exist_ok=True) - out_filename = "generated.wav" - out_path = 
os.path.join(output_dir, out_filename) - torchaudio.save(out_path, gen_audio, codec_audio_sr) - - print(f"[Success] Generated audio saved to {out_path}") - - -def main(): - fire.Fire(run_inference) - -if __name__ == "__main__": - main() diff --git a/inference_gradio.py b/inference_gradio.py deleted file mode 100644 index 4bb2879..0000000 --- a/inference_gradio.py +++ /dev/null @@ -1,334 +0,0 @@ -#!/usr/bin/env python3 -""" -gradio_tts_app.py - -Run: - python gradio_tts_app.py - -Then open the printed local or public URL in your browser. -""" - -import os -import random -import numpy as np -import torch -import torchaudio -import whisper -import gradio as gr -from argparse import Namespace - -# --------------------------------------------------------------------- -# The following imports assume your local project structure: -# data/tokenizer.py -# models/voice_star.py -# inference_tts_utils.py -# Adjust if needed. -# --------------------------------------------------------------------- -from data.tokenizer import AudioTokenizer, TextTokenizer -from models import voice_star -from inference_tts_utils import inference_one_sample - - -############################################################ -# Utility Functions -############################################################ - -def seed_everything(seed=1): - os.environ['PYTHONHASHSEED'] = str(seed) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - -def estimate_duration(ref_audio_path, text): - """ - Estimate duration based on seconds per character from the reference audio. - """ - info = torchaudio.info(ref_audio_path) - audio_duration = info.num_frames / info.sample_rate - length_text = max(len(text), 1) - spc = audio_duration / length_text # seconds per character - return len(text) * spc - - -############################################################ -# Main Inference Function -############################################################ - -def run_inference( - # User-adjustable parameters (no "# do not change" in snippet) - reference_speech="./demo/5895_34622_000026_000002.wav", - target_text="VoiceStar is a very interesting model, it's duration controllable and can extrapolate", - model_name="VoiceStar_840M_40s", - model_root="./pretrained", - reference_text=None, # optional - target_duration=None, # optional - top_k=10, # can try 10, 20, 30, 40 - temperature=1, - kvcache=1, # if OOM, set to 0 - repeat_prompt=1, # use higher to improve speaker similarity - stop_repetition=3, # snippet says "will not use it" but not "do not change" - seed=1, - output_dir="./generated_tts", - - # Non-adjustable parameters (based on snippet instructions) - codec_audio_sr=16000, # do not change - codec_sr=50, # do not change - top_p=1, # do not change - min_p=1, # do not change - silence_tokens=None, # do not change it - multi_trial=None, # do not change it - sample_batch_size=1, # do not change - cut_off_sec=100, # do not adjust -): - """ - Inference script for VoiceStar TTS. - """ - # 1. Set seed - seed_everything(seed) - - # 2. 
Load model checkpoint - torch.serialization.add_safe_globals([Namespace]) - device = "cuda" if torch.cuda.is_available() else "cpu" - ckpt_fn = os.path.join(model_root, model_name + ".pth") - if not os.path.exists(ckpt_fn): - # use wget to download - print(f"[Info] Downloading {model_name} checkpoint...") - os.system(f"wget https://huggingface.co/pyp1/VoiceStar/resolve/main/{model_name}.pth?download=true -O {ckpt_fn}") - bundle = torch.load(ckpt_fn, map_location=device, weights_only=True) - args = bundle["args"] - phn2num = bundle["phn2num"] - - model = voice_star.VoiceStar(args) - model.load_state_dict(bundle["model"]) - model.to(device) - model.eval() - - # 3. If reference_text not provided, transcribe reference speech with Whisper - if reference_text is None: - print("[Info] No reference_text provided. Transcribing reference_speech with Whisper (large-v3-turbo).") - wh_model = whisper.load_model("large-v3-turbo") - result = wh_model.transcribe(reference_speech) - prefix_transcript = result["text"] - print(f"[Info] Whisper transcribed text: {prefix_transcript}") - else: - prefix_transcript = reference_text - - # 4. If target_duration not provided, estimate from reference speech + target_text - if target_duration is None: - target_generation_length = estimate_duration(reference_speech, target_text) - print(f"[Info] target_duration not provided, estimated as {target_generation_length:.2f}s. Provide --target_duration if needed.") - else: - target_generation_length = float(target_duration) - - # 5. Prepare signature from snippet - if args.n_codebooks == 4: - signature = "./pretrained/encodec_6f79c6a8.th" - elif args.n_codebooks == 8: - signature = "./pretrained/encodec_8cb1024_giga.th" - else: - signature = "./pretrained/encodec_6f79c6a8.th" - - if silence_tokens is None: - silence_tokens = [] - - if multi_trial is None: - multi_trial = [] - - delay_pattern_increment = args.n_codebooks + 1 # from snippet - - info = torchaudio.info(reference_speech) - prompt_end_frame = int(cut_off_sec * info.sample_rate) - - # 6. Tokenizers - audio_tokenizer = AudioTokenizer(signature=signature) - text_tokenizer = TextTokenizer(backend="espeak") - - # 7. decode_config - decode_config = { - "top_k": top_k, - "top_p": top_p, - "min_p": min_p, - "temperature": temperature, - "stop_repetition": stop_repetition, - "kvcache": kvcache, - "codec_audio_sr": codec_audio_sr, - "codec_sr": codec_sr, - "silence_tokens": silence_tokens, - "sample_batch_size": sample_batch_size, - } - - # 8. Run inference - print("[Info] Running TTS inference...") - concated_audio, gen_audio = inference_one_sample( - model, args, phn2num, text_tokenizer, audio_tokenizer, - reference_speech, target_text, - device, decode_config, - prompt_end_frame=prompt_end_frame, - target_generation_length=target_generation_length, - delay_pattern_increment=delay_pattern_increment, - prefix_transcript=prefix_transcript, - multi_trial=multi_trial, - repeat_prompt=repeat_prompt, - ) - - # The model returns a list of waveforms, pick the first - concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() - - # 9. 
Save generated audio - os.makedirs(output_dir, exist_ok=True) - out_filename = "generated.wav" - out_path = os.path.join(output_dir, out_filename) - torchaudio.save(out_path, gen_audio, codec_audio_sr) - - print(f"[Success] Generated audio saved to {out_path}") - return out_path # Return the path for Gradio to load - - -############################ -# Transcription function -############################ - -def transcribe_audio(reference_speech): - """ - Transcribe uploaded reference audio with Whisper, return text. - If no file, return empty string. - """ - if reference_speech is None: - return "" - audio_path = reference_speech # Because type="filepath" - - if not os.path.exists(audio_path): - return "File not found." - - print("[Info] Transcribing with Whisper...") - model = whisper.load_model("medium") # or "large-v2" etc. - result = model.transcribe(audio_path) - return result["text"] - -############################ -# Gradio UI -############################ - -def main(): - with gr.Blocks() as demo: - gr.Markdown("## VoiceStar TTS with Editable Reference Text") - - with gr.Row(): - reference_speech_input = gr.Audio( - label="Reference Speech", - type="filepath", - elem_id="ref_speech" - ) - transcribe_button = gr.Button("Transcribe") - - # The transcribed text appears here and can be edited - reference_text_box = gr.Textbox( - label="Reference Text (Editable)", - placeholder="Click 'Transcribe' to auto-fill from reference speech...", - lines=2 - ) - - target_text_box = gr.Textbox( - label="Target Text", - value="VoiceStar is a very interesting model, it's duration controllable and can extrapolate to unseen duration.", - lines=3 - ) - - model_name_box = gr.Textbox( - label="Model Name", - value="VoiceStar_840M_40s" - ) - - model_root_box = gr.Textbox( - label="Model Root Directory", - value="/data1/scratch/pyp/BoostedVoiceEditor/runs" - ) - - reference_duration_box = gr.Textbox( - label="Target Duration (Optional)", - placeholder="Leave empty for auto-estimate." - ) - - top_k_box = gr.Number(label="top_k", value=10) - temperature_box = gr.Number(label="temperature", value=1.0) - kvcache_box = gr.Number(label="kvcache (1 or 0)", value=1) - repeat_prompt_box = gr.Number(label="repeat_prompt", value=1) - stop_repetition_box = gr.Number(label="stop_repetition", value=3) - seed_box = gr.Number(label="Random Seed", value=1) - output_dir_box = gr.Textbox(label="Output Directory", value="./generated_tts") - - generate_button = gr.Button("Generate TTS") - output_audio = gr.Audio(label="Generated Audio", type="filepath") - - # 1) When user clicks "Transcribe", we call `transcribe_audio` - transcribe_button.click( - fn=transcribe_audio, - inputs=[reference_speech_input], - outputs=[reference_text_box], - ) - - # 2) The actual TTS generation function. 
- def gradio_inference( - reference_speech, - reference_text, - target_text, - model_name, - model_root, - target_duration, - top_k, - temperature, - kvcache, - repeat_prompt, - stop_repetition, - seed, - output_dir - ): - # Convert any empty strings to None for optional fields - dur = float(target_duration) if target_duration else None - - out_path = run_inference( - reference_speech=reference_speech, - reference_text=reference_text if reference_text else None, - target_text=target_text, - model_name=model_name, - model_root=model_root, - target_duration=dur, - top_k=int(top_k), - temperature=float(temperature), - kvcache=int(kvcache), - repeat_prompt=int(repeat_prompt), - stop_repetition=int(stop_repetition), - seed=int(seed), - output_dir=output_dir - ) - return out_path - - # 3) Link the "Generate TTS" button - generate_button.click( - fn=gradio_inference, - inputs=[ - reference_speech_input, - reference_text_box, - target_text_box, - model_name_box, - model_root_box, - reference_duration_box, - top_k_box, - temperature_box, - kvcache_box, - repeat_prompt_box, - stop_repetition_box, - seed_box, - output_dir_box - ], - outputs=[output_audio], - ) - - demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/inference_tts_utils.py b/inference_tts_utils.py deleted file mode 100644 index 412e7c2..0000000 --- a/inference_tts_utils.py +++ /dev/null @@ -1,155 +0,0 @@ -import argparse, pickle -import logging -import os, random -import numpy as np -import torch -import torchaudio - -from data.tokenizer import ( - AudioTokenizer, - TextTokenizer, - tokenize_audio, - tokenize_text -) -import argparse, time, tqdm - - -# this script only works for the musicgen architecture -def get_args(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") - parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") - parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") - parser.add_argument("--seed", type=int, default=1) - parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for') - parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes') - parser.add_argument("--top_k", type=int, default=0, help="sampling param") - parser.add_argument("--top_p", type=float, default=0.8, help="sampling param") - parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") - parser.add_argument("--output_dir", type=str, default=None) - parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--signature", type=str, default=None, help="path to the encodec model") - parser.add_argument("--crop_concat", type=int, default=0) - parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it") - parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without') - parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation") - parser.add_argument("--silence_tokens", type=str, 
default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default") - return parser.parse_args() - - -@torch.no_grad() -def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame, target_generation_length, delay_pattern_increment, prefix_transcript=None, quiet=False, repeat_prompt=0, multi_trial=[]): - # seq_len_thres = 500 # 10s, 26% of the data in seed tts - # encode audio - encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame) - # if sequence length is shorter than seq_len_thres, repeat the audio - # if encoded_frames.shape[2] < seq_len_thres: - # encoded_frames = torch.cat([encoded_frames, encoded_frames, encoded_frames], dim=2) - # doubled = True - single_encoded_frames = encoded_frames - - if isinstance(repeat_prompt, int) and repeat_prompt > 0: - cur_repeat_prompt = repeat_prompt - while cur_repeat_prompt > 0: - encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2) - cur_repeat_prompt -= 1 - elif isinstance(repeat_prompt, str) and repeat_prompt.lower() == "max": - repeat_prompt = 0 - while encoded_frames.shape[2] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment + single_encoded_frames.shape[2] < model_args.audio_max_length * decode_config['codec_sr']: - encoded_frames = torch.cat([encoded_frames, single_encoded_frames], dim=2) - repeat_prompt += 1 - if getattr(model_args, "y_sep_token", None) != None: - encoded_frames = torch.cat([encoded_frames, torch.LongTensor([model_args.y_sep_token]*model_args.n_codebooks).unsqueeze(0).unsqueeze(2).to(encoded_frames.device)], dim=2) - # print(encoded_frames.shape) - original_audio = encoded_frames.transpose(2,1) # [1,T,K] - assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape - - # phonemize - if isinstance(target_text, list): - text_tokens = [phn2num[phn] for phn in target_text if phn in phn2num] - else: - text_tokens = [phn2num[phn] for phn in - tokenize_text( - text_tokenizer, text=target_text.strip() - ) if phn in phn2num - ] - if getattr(model_args, "x_sep_token", None) != None: - assert prefix_transcript != None, "prefix_transcript must be provided if x_sep_token is not None" - if prefix_transcript is not None: - if isinstance(prefix_transcript, list): - prefix_tokens = [phn2num[phn] for phn in prefix_transcript if phn in phn2num] - else: - prefix_tokens = [phn2num[phn] for phn in - tokenize_text( - text_tokenizer, text=prefix_transcript.strip() - ) if phn in phn2num - ] - # if doubled: - # prefix_tokens = prefix_tokens + prefix_tokens + prefix_tokens - single_prefix_tokens = prefix_tokens - while repeat_prompt > 0: - prefix_tokens = prefix_tokens + single_prefix_tokens - repeat_prompt -= 1 - if getattr(model_args, "x_sep_token", None) != None: - text_tokens = prefix_tokens + [getattr(model_args, "x_sep_token", None)] + text_tokens - else: - text_tokens = prefix_tokens + text_tokens - if getattr(model_args, "add_eos_to_text", 0) != 0: - text_tokens.append(model_args.add_eos_to_text) - if getattr(model_args, "add_bos_to_text", 0) != 0: - text_tokens = [model_args.add_bos_to_text] + text_tokens - text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) - text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) - - if not quiet: - logging.info(f"original audio length: 
{original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.") - - - if getattr(model_args, "parallel_pattern", 0) != 0: - tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + 2)]) # parallel pattern, therefore only add the empty_token (i.e. the sos token) and eos (i.e. 2 more tokens). Note that the delayed pattern between, both sos and eos is counted (sos is counted in the n_codebooks, eos is counted in the 1) - else: - tgt_y_lens = torch.LongTensor([int(original_audio.shape[1] + decode_config['codec_sr'] * target_generation_length + delay_pattern_increment)]) # delay pattern increment has accounted for the added eos - - # forward - assert decode_config['sample_batch_size'] <= 1 - stime = time.time() - assert multi_trial == [] - if not quiet: - logging.info(f"running inference with batch size 1") - concat_frames, gen_frames = model.inference_tts( - text_tokens.to(device), - text_tokens_lens.to(device), - original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8] - tgt_y_lens = tgt_y_lens.to(device), - top_k=decode_config['top_k'], - top_p=decode_config['top_p'], - min_p=decode_config['min_p'], - temperature=decode_config['temperature'], - stop_repetition=decode_config['stop_repetition'], - kvcache=decode_config['kvcache'], - silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens'] - ) # output is [1,K,T] - if not quiet: - logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.") - - logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.") - - # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)): - # logging.info(f"{timestamp}: {codes.tolist()}") - # decode (both original and generated) - # concat_sample = audio_tokenizer.decode( - # [(concat_frames, None)] # [1,T,8] -> [1,8,T] - # ) - if getattr(model_args, "y_sep_token", None) != None: - concat_frames = torch.cat([concat_frames[:, :, :original_audio.shape[1]-1], concat_frames[:, :, original_audio.shape[1]:]], dim=2) - concat_sample = audio_tokenizer.decode( - concat_frames # [1,8,T] - ) - gen_sample = audio_tokenizer.decode( - gen_frames - ) - #Empty cuda cache between runs - if torch.cuda.is_available(): - torch.cuda.empty_cache() - # return - return concat_sample, gen_sample \ No newline at end of file diff --git a/models/voice_star.py b/models/voice_star.py deleted file mode 100644 index 7cc4806..0000000 --- a/models/voice_star.py +++ /dev/null @@ -1,784 +0,0 @@ -import random, os, copy -from typing import Dict, Iterator, List, Tuple, Union -import logging -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchmetrics.classification import MulticlassAccuracy -import torch.distributed as dist - -from .modules.utils import make_pad_mask, generate_partial_autoregressive_mask - -from .modules.embedding import SinePositionalEmbedding, TokenEmbedding, SinePositionalEmbedding_progress -from .modules.transformer import ( - AdaptiveLayerNorm, - LayerNorm, - TransformerDecoderLayer, - TransformerDecoder, - TransformerEncoder, - TransformerEncoderLayer, -) - -def top_k_top_p_filtering( - logits, top_k=0, top_p=1.0, min_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 -): - """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering - Args: - logits: logits distribution 
shape (batch size, vocabulary size) - if top_k > 0: keep only top k tokens with highest probability (top-k filtering). - if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). - Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) - Make sure we keep at least min_tokens_to_keep per batch example in the output - From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 - """ - if min_p < 1.0: - probs = F.softmax(logits, dim=-1) - indices_to_remove = probs < min_p - if not torch.any(indices_to_remove.sum(-1) == logits.size(-1)): - logits[indices_to_remove] = filter_value - top_k = 0 - top_p = 1.0 - # else will use other types of sampling, or no filtering - - # If top_k is a single integer - if isinstance(top_k, int) and top_k > 0: - # Safety check to ensure we don't ask for more than available - top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) - - # Remove all tokens with a probability less than the last token of the top-k - threshold = torch.topk(logits, top_k, dim=-1)[0][..., -1, None] - indices_to_remove = logits < threshold - logits[indices_to_remove] = filter_value - - # If top_k is a list, assume it has the same length as M - elif isinstance(top_k, list): - # Ensure the length matches the first dimension - assert len(top_k) == logits.size(0), \ - f"top_k list length ({len(top_k)}) must match logits.size(0) ({logits.size(0)})" - - for i in range(logits.size(0)): - k_i = top_k[i] - if k_i > 0: - # Safety check - k_i = min(max(k_i, min_tokens_to_keep), logits.size(-1)) - row_threshold = torch.topk(logits[i], k_i, dim=-1)[0][-1] - indices_to_remove_i = logits[i] < row_threshold - logits[i, indices_to_remove_i] = filter_value - - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum( - F.softmax(sorted_logits, dim=-1), dim=-1 - ) - - # Remove tokens with cumulative probability above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs > top_p - if min_tokens_to_keep > 1: - # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) - sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() - sorted_indices_to_remove[..., 0] = 0 - - return logits - - -def topk_sampling(logits, top_k=10, top_p=1.0, min_p=1.0, temperature=1.0): - # temperature: (`optional`) float - # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. - # top_k: (`optional`) int - # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. - # top_p: (`optional`) float - # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
- - # Temperature (higher temperature => more likely to sample low probability tokens) - if temperature != 1.0: - logits = logits / temperature - # Top-p/top-k filtering - logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p) - # Sample - token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) - return token - - - -class VoiceStar(nn.Module): - def __init__(self, args): - super().__init__() - self.args = args - assert self.args.enc_dec ^ self.args.dec, f"self.args.enc_dec: {self.args.enc_dec}, self.args.dec: {self.args.dec}" - if not getattr(self.args, "special_first", False): - self.args.special_first = 0 - if not getattr(self.args, "n_special", False): - self.args.n_special = 3 - self.args.eos = getattr(self.args, "eos", -1) - self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1] - if self.args.eos > 0: - assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos - self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] - if type(self.args.audio_vocab_size) == str: - self.args.audio_vocab_size = eval(self.args.audio_vocab_size) - if type(self.args.audio_vocab_size) == list: # otherwise they are all lists - assert self.args.special_first - - - self.n_text_tokens = self.args.text_vocab_size + 1 - assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}" - - if self.args.special_first and type(self.args.audio_vocab_size) == list: - self.n_audio_tokens = [tok + self.args.n_special for tok in self.args.audio_vocab_size] # special tokens: empty token, EOG token, audio pad token - assert self.args.empty_token == 0, self.args.empty_token - assert self.args.eog == 1, self.args.eog - assert self.args.audio_pad_token == 2, self.args.audio_pad_token - else: - self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token - assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token - assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog - assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token - - self.text_embedding = TokenEmbedding( - dim_model=self.args.d_model, - vocab_size=self.n_text_tokens, - dropout=self.args.text_embedding_dropout - ) - - self.audio_embedding = nn.ModuleList( - [ - TokenEmbedding( - dim_model=self.args.audio_embedding_dim, - vocab_size=self.n_audio_tokens[k], - dropout=self.args.audio_embedding_dropout - ) for k in range(self.args.n_codebooks) - ] - ) - - rope_base = getattr(self.args, "rope_base", None) - use_sinusoidal = getattr(self.args, "use_sinusoidal", False) - use_sinusoidal_progress = getattr(self.args, "use_sinusoidal_progress", False) - logging.info(f"rope_base: {rope_base}, use_sinusoidal: {use_sinusoidal}") - if use_sinusoidal: - self.text_positional_embedding = SinePositionalEmbedding( - self.args.d_model, - dropout=self.args.text_positional_embedding_dropout, - scale=False, - alpha=True, # learnable scaler, scale the volume of positional embedding - ) - self.audio_positional_embedding = SinePositionalEmbedding( - self.args.d_model, - dropout=self.args.audio_positional_embedding_dropout, - scale=False, - alpha=True, # learnable scaler, scale the volume of 
positional embedding - ) - elif use_sinusoidal_progress: - self.text_positional_embedding = SinePositionalEmbedding_progress( - self.args.d_model, - dropout=self.args.text_positional_embedding_dropout, - scale=False, - alpha=True, # learnable scaler, scale the volume of positional embedding - args = self.args - ) - self.audio_positional_embedding = SinePositionalEmbedding_progress( - self.args.d_model, - dropout=self.args.audio_positional_embedding_dropout, - scale=False, - alpha=True, # learnable scaler, scale the volume of positional embedding - args = self.args - ) - - else: - class NoOp: - def __init__(self): - pass - def __call__(self, *args, **kwargs): - return args[0] - self.text_positional_embedding = NoOp() - self.audio_positional_embedding = NoOp() - - if self.args.enc_dec: - enc_layer = TransformerEncoderLayer( - d_model=self.args.d_model, - nhead=self.args.nhead, - dim_feedforward=self.args.d_model*4, - dropout=self.args.trm_dropout, - batch_first=True, - norm_first=True, - layer_norm_cls=LayerNorm - ) # use the pre-norm arch - - self.encoder = TransformerEncoder( - encoder_layer=enc_layer, - num_layers=self.args.num_encoder_layers, - norm=LayerNorm(self.args.d_model), - rope_base = self.args.rope_base, - d_model = self.args.d_model, - nhead = self.args.nhead, - args = self.args - ) # use the pre-norm arch - - dec_layer = TransformerDecoderLayer( - d_model=self.args.d_model, - nhead=self.args.nhead, - dim_feedforward=self.args.d_model*4, - dropout=self.args.trm_dropout, - batch_first=True, - norm_first=True, - layer_norm_cls=LayerNorm - ) - - self.decoder = TransformerDecoder( - decoder_layer=dec_layer, - num_layers=self.args.num_decoder_layers, - norm=LayerNorm(self.args.d_model), - rope_base = self.args.rope_base, - d_model = self.args.d_model, - nhead = self.args.nhead, - args = self.args - ) # NOTE: this one I use torch.nn native implementation, as it's not implemented in .modules - - else: - dec_layer = TransformerEncoderLayer( - self.args.d_model, - self.args.nhead, - dim_feedforward=self.args.d_model * 4, - dropout=self.args.trm_dropout, - batch_first=True, - norm_first=True, - layer_norm_cls=LayerNorm - ) - self.decoder = TransformerEncoder( - dec_layer, - num_layers=self.args.num_decoder_layers, - norm=LayerNorm(self.args.d_model), - ) - - if type(self.args.audio_vocab_size) == int: - self.predict_layer = nn.ModuleList( - [ - nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks) - ] - ) - else: - self.predict_layer = nn.ModuleList( - [ - nn.Sequential(nn.Linear(self.args.d_model, self.args.d_model//2), nn.GELU(), nn.Linear(self.args.d_model//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks) - ] - ) - - self.accuracy_metrics = nn.ModuleList( - [MulticlassAccuracy( - self.n_audio_tokens[k], - top_k=10, - average="micro", - multidim_average="global", - ignore_index=None, - ) for k in range(self.args.n_codebooks)] - ) - - if self.args.eog_weight != 1: - raise NotImplementedError("now have different vocab_size for different codebooks, therefore currently don't support eog_weight") - self.class_weight = nn.Parameter(torch.ones(self.n_audio_tokens), requires_grad=False) - self.class_weight.data[self.args.eog] = self.args.eog_weight - - def dec_forward( - self, - x_input, - x_lens, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask, - need_weights=False, - past=None, - 
last_3_tokens=False - ): - x_attn_mask = F.pad( - x_attention_mask, - (0, new_y_lens.max()), - value=True, - ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper - y_attn_mask = F.pad( - y_attention_mask, - (x_lens.max(), 0), # y is padded at the front - value=False, - ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive - xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) - - # merge key padding and attention masks - bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max() - xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1) - _xy_padding_mask = ( - xy_padding_mask.view(bsz, 1, 1, src_len) - .expand(-1, self.args.nhead, -1, -1) - .reshape(bsz * self.args.nhead, 1, src_len) - ) - xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) - - new_attn_mask = torch.zeros_like(xy_attn_mask) - new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) - xy_attn_mask = new_attn_mask - - xy_input = torch.cat([x_input, y_input], dim=1) - if need_weights: - raise NotImplementedError("not implemented yet") - out, layer_attn_weights = self.decoder((xy_input, None), mask=xy_attn_mask, need_weights=True) - return layer_attn_weights - - if past == None: # do not use kvcache - out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) - return out[:, x_lens.max():], None - else: # use kvcache - if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet - if last_3_tokens: - xy_input = xy_input[:, -3:] - xy_attn_mask = xy_attn_mask[:, -3:] - else: - xy_input = xy_input[:, -1:] - xy_attn_mask = xy_attn_mask[:, -1:] - - out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past) - if isinstance(out, tuple): # get rid of stage_embedding - out = out[0] - - if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet - return out[:, x_lens.max():], present - else: # used kvcache - return out, present - - def enc_dec_forward( - self, - xa, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask, - tgt_y_lens=None, - need_weights=False, - past=None, - last_3_tokens=False - ): - assert not need_weights - if past != None and past.ndim > 3: - y_input = y_input[:, -1:] - y_attention_mask = y_attention_mask[-1:] - yhat, present = self.decoder(tgt=y_input, memory=xa, tgt_mask=y_attention_mask, tgt_key_padding_mask=y_padding_mask, memory_key_padding_mask=x_padding_mask, query_lens=tgt_y_lens, past=past) - return yhat, present - - def forward(self, batch, calc_loss = False): - """ - Args: - x: - A 2-D tensor of shape (N, S). - x_lens: - A 1-D tensor of shape (N,). It contains the number of tokens in `x` - before padding. - y: - A 3-D tensor of shape (N, K, T). - where K is the number of codebooks - y_lens: - A 1-D tensor of shape (N,). It contains the number of tokens in `x` - before padding. 
- """ - x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] - if len(x) == 0: - return None - x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x - y = y[...,:y_lens.max()] - assert x.ndim == 2, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape - assert y_lens.ndim == 1, y_lens.shape - x_padding_mask = make_pad_mask(x_lens).to(x.device) - x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device) - x_input = self.text_embedding(x) - x_input = self.text_positional_embedding(x_input, x_lens) - y_with_eos = [torch.cat([item[:, :y_lens[i]], self.eos], dim=-1) for i, item in enumerate(y)] - targets = y_with_eos - # apply delayed stacking on y - shifted_y = [] - patterns = [] - new_y_lens = [] - if getattr(self, "empty_tokens", None) == None: - self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K] - for i in range(len(y)): - tmp = torch.cat([y_with_eos[i], self.empty_tokens], dim=-1) # [K, T+n_codebooks] - for ii in range(self.args.n_codebooks): - tmp[ii] = torch.roll(tmp[ii], shifts=ii+1, dims=0) - shifted_y.append(tmp.transpose(1,0)) # [K, T+n_codebooks] -> [T+n_codebooks, K] - new_y_lens.append(y_with_eos[i].shape[1] + self.empty_tokens.shape[1]) - - new_y_lens = torch.LongTensor(new_y_lens).to(y.device) - - cated_y = torch.nn.utils.rnn.pad_sequence(shifted_y, batch_first=False, padding_value=self.args.audio_pad_token) - assert cated_y.shape == torch.Size([max(new_y_lens), len(y), self.args.n_codebooks]), cated_y.shape - cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B] - stacked_embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D] - assert stacked_embedded_y.shape[0] == self.args.n_codebooks and stacked_embedded_y.shape[2] == len(y) and stacked_embedded_y.shape[-1] == self.args.d_model, stacked_embedded_y.shape - embedded_y = stacked_embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D] - embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D] - assert embedded_y.shape[1:] == torch.Size([max(new_y_lens), self.args.d_model]), embedded_y.shape - y_input = self.audio_positional_embedding(embedded_y, new_y_lens) - y_padding_mask = make_pad_mask(new_y_lens).to(y.device) - y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device) - if self.args.dec: - y_out = self.dec_forward( - x_input, - x_lens, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask - ) - else: - xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask) - y_out = self.enc_dec_forward( - xa, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask - ) - y_out = y_out[0] # no kv-caching during training - assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] - logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card] - assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape - logits_use = [logit[:, :new_y_lens[i]] for i, logit in enumerate(logits)] # each of shape [K, T, card] - logits_final = [] - for i, logit in 
enumerate(logits_use): - logit_copy = logit.clone() - for ii in range(self.args.n_codebooks): - logit_copy[ii] = torch.roll(logit_copy[ii], shifts=-ii, dims=0) - logit = logit_copy[:, :-self.args.n_codebooks] # [K, T, card] -> [K, T-n_codebooks, card] - logits_final.append(logit) - if self.args.no_loss_on_prefix: - assert "y_sep_token_position" in batch, f"y_sep_token_position should be in batch, but it's not" - logit_temp = [] - target_temp = [] - for jj, (logit, target) in enumerate(zip(logits_final, targets)): - # TODO already taken into consideration in depth transformer - logit_temp.append(logit[:, batch['y_sep_token_position'][jj]:]) - target_temp.append(target[:, batch['y_sep_token_position'][jj]:]) - logits_final = logit_temp - targets = target_temp - logits = torch.cat(logits_final, dim=1) # [K, T1+T2+T3+..., card] - targets = torch.cat(targets, dim=1) # [K, T1+T2+T3+...] - - assert targets.shape[:2] == logits.shape[:2], f"{targets.shape}, {logits.shape}" - loss = [] - ntokens = [] - top10acc = [] - for k, (logit, target) in enumerate(zip(logits, targets)): # even though the loss and top10acc is calculated in a loop (loop through n_codebooks), validation is still taking a lot of mem, need to optimize this a little more - loss.append(F.cross_entropy(logit, target, reduction='mean', weight=self.class_weight.data if self.args.eog_weight!=1 else None, ignore_index=self.args.y_sep_token if self.args.y_sep_token != None else -100)) # ignore audio sep token as it's unpredictable (like the random early stop bug happened in 2023) - # NOTE have to ignore the sep token in the loss calculation - top10acc.append(self.accuracy_metrics[k](logit.detach(), target)) - ntokens.append(len(logit)) - - all_ntokens = sum(ntokens) - if self.args.codebook_weight != None: - codebook_weight = eval(self.args.codebook_weight) if isinstance(self.args.codebook_weight, str) else self.args.codebook_weight - else: - codebook_weight = [1.] * self.args.n_codebooks - perplexity_by_codebook = [torch.exp(l) for l in loss] - loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)]) - - top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)] - top10acc = sum(top10acc_by_codebook) - - ntokens = torch.tensor(all_ntokens).to(logits.device) - - ret = { - "loss": loss, - "perplexity_by_codebook": perplexity_by_codebook, - "top10acc": top10acc, - "top10acc_by_codebook": top10acc_by_codebook, - "effective_ntoken": ntokens, - } - - return ret - - def inference_tts( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: torch.Tensor, - tgt_y_lens: torch.Tensor, # - top_k: Union[int, list[int]]=-100, - top_p: float=1.0, - min_p: float=1.0, - temperature: float=1.0, - stop_repetition: int=3, - kvcache: int=1, - silence_tokens: list[int]=[], - multi_trial: list[int]=[], - *kargs - ) -> torch.Tensor: - """ - different from inference_tts, this implementation uses kvcache, which should have significant speed up - Args: - x: - A 2-D tensor of shape (1, L). - x_lens: - A 1-D tensor of shape (1,). It contains the number of tokens in `x` - before padding. - y: - A 3-D tensor of shape (1, T, K). - tgt_y_lens: - *new arg* this specify the target length of y - top_k: (`optional`) int - The number of highest probability tokens to keep for top-k-filtering. Default to -100. - top_p: (`optional`) float - For Neucleus sampling - min_p: (`optional`) float - For min_p filtered sampling - temperature: (`optional`) float - The value used to module the next token probabilities. Must be strictly positive. 
Default to 1.0. - multi_trial: (`optional`) list[int] - If not empty, it will be [n_trials, beam_size, trial_interval] - from the start and begining trial_interval, we duplicate the current sample by beam_size, - at the end of every trial_interval, we choose the sample with the highest log likelihood to keep and throw away the rest - """ - eog_inference = self.args.eos if self.args.eos>0 else self.args.eog - assert x.ndim == 2, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.ndim == 3, y.shape - if self.args.special_first: - y = y + int(self.args.n_special) - y = y.transpose(2,1) # [1,T,K] -> [1,K,T] - assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding - - # make x attention mask and x_input - x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device) - # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device) - x_input = self.text_embedding(x) - x_input = self.text_positional_embedding(x_input, x_lens) - - y_len = y.shape[2] - y_lens = torch.LongTensor([y_len]).to(y.device) - - # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario - rearranged_y = [[y[0]]] - assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape - - # # shift y to create the delayed pattern - if getattr(self, "empty_tokens", None) == None: - self.empty_tokens = torch.full((self.args.n_codebooks, self.args.n_codebooks), self.args.empty_token, dtype=torch.long).to(y.device) # [K, K] - temp = rearranged_y[0][0] - assert temp.ndim == 2 and temp.shape[0] == self.args.n_codebooks, temp.shape - temp = torch.cat([temp, self.empty_tokens], dim=-1) # [K, T+n_codebooks] - for ii in range(self.args.n_codebooks): - temp[ii] = torch.roll(temp[ii], shifts=ii+1, dims=0) - shifted_y = [[temp]] - - # below is different from forward or inference - # where we cut this shifted part - shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)] - assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0] - # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that - # next section is concate tensors of each sample to one tensor, which we also don't need - cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B] - new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device) - assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1)) - assert not (cated_y == self.args.audio_pad_token).any(), cated_y - - # replace tokens in y with the embeddings, add sum codebooks up - embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D] - assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape - assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape - embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] - embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D] - - # positional embedding - y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) - - # make attention mask and padding mask - y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) - - x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device) - y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) - - # entering the generation 
stage - # starting from line 708 - codebook_eog = [False] * self.args.n_codebooks - generated = [] # doesn't contain any empty token, contain eog - cur_generated = [] - # say 0 is empty, 4 is eog - # tensor([[ 1, 2, 3, 4, 0, 0], - # [ 0, 1, 2, 3, 4, 0], - # [ 0, 0, 1, 2, 3, 4]]) - num_gen = [] - cur_num_gen = 0 - ##################### silence repetition handling ##################### - ##################### silence repetition handling ##################### - # silence_tokens = [1388,1898,131] # [1388, 2045, 2041, 1996] - # silence_tokens = [] - consec_silence_count = 0 - prev_token = None - ##################### silence repetition handling ##################### - ##################### silence repetition handling ##################### - - def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen): - if n_eog == 0: - logits_adjust = logits - for jj in range(1,self.args.n_codebooks): - logits_adjust[jj][eog_inference] = -10000 - logits_adjust[jj][self.args.empty_token] = -10000 - if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early - logits_adjust[0][eog_inference] = -10000 - ##################### silence repetition handling ##################### - if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition: - if logits_adjust[0, prev_token] < 0: - logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1)) - else: - logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1)) - ##################### silence repetition handling ##################### - samples = topk_sampling( - logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature - ) # [K, 1] - assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}" - if cur_num_gen < self.args.n_codebooks-1: - for jj in range(1, self.args.n_codebooks - cur_num_gen): - samples[-jj, 0] = self.args.empty_token - - if ( - samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//4) - ) or self.args.rope_base is not None and not self.args.decoder_regular_rope and self.args.progress_no_multiple and cur_num_gen > (tgt_y_lens[0] + self.args.encodec_sr * getattr(self.args, "extra_cutoff", 5)): - # last one condition in the first bracket means y is already too long, shouldn't happen, but put it here - # the second bracket means we are using progress-monitoring RoPE, but the model is generating excessively long sequence (5 seconds more than specified), in which case we terminate the generation - samples[0,0] = eog_inference - codebook_eog[0] = True - ##################### silence repetition handling ##################### - if samples[0,0] in silence_tokens and samples[0,0] == prev_token: - consec_silence_count += 1 - else: - consec_silence_count = 0 - prev_token = samples[0,0] - ##################### silence repetition handling ##################### - return samples, codebook_eog, prev_token, consec_silence_count - else: - assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}" - logits_adjust = logits - for jj in range(n_eog+1,self.args.n_codebooks): - logits_adjust[jj][eog_inference] = -10000 - logits_adjust[jj][self.args.empty_token] = -10000 - samples = topk_sampling( - 
logits_adjust, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature - ) # [K, 1] - for jj in range(n_eog): - samples[jj, 0] = self.args.empty_token - samples[n_eog, 0] = eog_inference - codebook_eog[n_eog] = True - return samples, codebook_eog, prev_token, consec_silence_count - - # prepare the cache placeholder - # n_layers, 2, bsz, num_heads, src_len, head_dim, 2 means [key, value] - past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None - if self.args.enc_dec: - xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask) - while True: - if self.args.dec: - y_out, present = self.dec_forward( - x_input, - x_lens, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask, - past=past - ) - else: - y_out, present = self.enc_dec_forward( - xa, - x_attention_mask, - x_padding_mask, - y_input, - new_y_lens, - y_attention_mask, - y_padding_mask, - tgt_y_lens=tgt_y_lens, - past=past - ) - if past != None: - past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype) - - - y_out = y_out[:, -1:] # only take the last token - - logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card] - logits = logits.squeeze(0).squeeze(1) # [K card] - assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}" - - n_eog = sum(codebook_eog) - assert n_eog < self.args.n_codebooks - if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans - for jj in range(self.args.n_codebooks): - logits[jj][self.args.eog] = -10000. 
- - samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, min_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen) - # samples.shape is [K,1] - # ge samples_emb - samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D] - samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D] - - cur_num_gen += 1 - cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K] - - if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done - codebook_eog = [False] * self.args.n_codebooks - num_gen.append(cur_num_gen) - cur_num_gen = 0 - generated.append(cur_generated) - cur_generated = [] - break - else: - assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}" - - embedded_y = torch.cat([embedded_y, samples_emb], dim=1) - new_y_lens = torch.LongTensor([embedded_y.shape[1]]).to(y.device) - y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) # [B T D] - # make attention mask and padding mask - y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device) - y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device) - - assert len(generated) == 1, f"len(generated): {len(generated)}" - - # revert the pattern - flatten_gen = [] - for l, orig_span in enumerate(generated): - span = torch.stack(orig_span, dim=0) # [T, K] - span = span.transpose(1,0) # [K, T] - assert span.shape[0] == self.args.n_codebooks, span.shape - unshifted_span = [] - for j, s in enumerate(span): - start_from = j - end_at = - (self.args.n_codebooks - start_from) - unshifted_span.append(s[start_from:end_at]) - unshifted_span = torch.stack(unshifted_span, dim=0) - - assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}" - - flatten_gen.append(unshifted_span) - assert len(flatten_gen) == 1, len(flatten_gen) - - # combine - res = [y[0], flatten_gen[0]] - res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T] - expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen]) - assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. 
y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}" - - if self.args.special_first: - res = res - int(self.args.n_special) - flatten_gen = flatten_gen - int(self.args.n_special) - return res, flatten_gen[0].unsqueeze(0) \ No newline at end of file diff --git a/copy_codebase.py b/scripts/copy_codebase.py similarity index 94% rename from copy_codebase.py rename to scripts/copy_codebase.py index 0bee09d..ea2fcc5 100644 --- a/copy_codebase.py +++ b/scripts/copy_codebase.py @@ -1,8 +1,8 @@ - import os import shutil import fnmatch + def parse_gitignore(gitignore_path): """Parse a .gitignore file and return a list of patterns.""" patterns = [] @@ -16,6 +16,7 @@ def parse_gitignore(gitignore_path): patterns.append(line) return patterns + def file_matches_patterns(file_path, patterns): """Check if a file matches any of the patterns in .gitignore.""" for pattern in patterns: @@ -23,8 +24,9 @@ def file_matches_patterns(file_path, patterns): return True return False + def copy_codebase(src, dst, max_size_mb=5, gitignore_path=None): - """ Copy files from src to dst, skipping files larger than max_size_mb and matching .gitignore patterns. """ + """Copy files from src to dst, skipping files larger than max_size_mb and matching .gitignore patterns.""" if gitignore_path and os.path.exists(gitignore_path): patterns = parse_gitignore(gitignore_path) else: @@ -50,7 +52,6 @@ def copy_codebase(src, dst, max_size_mb=5, gitignore_path=None): print(f"Skipping {file_path} because it's larger than {max_size_mb}MB") continue - # Make sure the destination directory exists os.makedirs(os.path.dirname(dst_path), exist_ok=True) shutil.copy(file_path, dst_path) diff --git a/z_scripts_new/e1_840M_30s.sh b/scripts/e1_840M_30s.sh similarity index 100% rename from z_scripts_new/e1_840M_30s.sh rename to scripts/e1_840M_30s.sh diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..3429a83 --- /dev/null +++ b/setup.py @@ -0,0 +1,44 @@ +from setuptools import setup, find_packages +from voicestar import __version__ + +setup( + name="voicestar", + version=__version__, + description="VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/jasonppy/VoiceStar", + author="Puyuan Peng", + license="MIT", + packages=find_packages(), + install_requires=[ + "torch", + "torchaudio", + "numpy", + "tqdm", + "fire", + "phonemizer", + "torchmetrics", + "einops", + "omegaconf==2.3.0", + "openai-whisper", + "transformers[torch]", + "huggingface_hub", + "gradio", + "click", + "txtsplit", + ], + extras_require={ + "train": [ + "huggingface_hub", + "datasets", + "tensorboard", + "wandb", + "matplotlib", + "ffmpeg-python", + "scipy", + "soundfile", + ] + }, + entry_points={"console_scripts": ["voicestar=voicestar.cli:run_inference"]}, +) diff --git a/steps/optim.py b/steps/optim.py index 88bd02b..ab1a6b6 100644 --- a/steps/optim.py +++ b/steps/optim.py @@ -80,7 +80,9 @@ def batched_params(self, param_group, group_params_names): list ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - assert len(param_group) == len(group_params_names), f"len(param_group): {len(param_group)}, len(group_params_names): {len(group_params_names)}" + assert len(param_group) == len( + group_params_names + ), f"len(param_group): {len(param_group)}, len(group_params_names): {len(group_params_names)}" for p, named_p in zip(param_group, 
group_params_names): key = (str(p.dtype), *p.shape) batches[key].append(p) @@ -90,9 +92,7 @@ def batched_params(self, param_group, group_params_names): sorted_idx = sorted( range(len(batches_names)), key=lambda i: batches_names_keys[i] ) - batches_names = [ - batches_names[batches_names_keys[idx]] for idx in sorted_idx - ] + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] stacked_params_dict = dict() @@ -110,10 +110,7 @@ def batched_params(self, param_group, group_params_names): state = self.state[p] p_stacked = torch.stack(batch) grad = torch.stack( - [ - torch.zeros_like(p) if p.grad is None else p.grad - for p in batch - ] + [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked @@ -121,7 +118,7 @@ def batched_params(self, param_group, group_params_names): yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): + for (stacked_params, _state, _names), batch in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -227,13 +224,9 @@ def step(self, closure=None): batch = True - for group, group_params_names in zip( - self.param_groups, self.parameters_names - ): + for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params( - group["params"], group_params_names - ) as batches: + with self.batched_params(group["params"], group_params_names) as batches: # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to @@ -286,9 +279,7 @@ def _init_state(self, group: dict, p: Tensor, state: dict): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) batch_size = p.shape[0] numel = p.numel() // batch_size @@ -298,9 +289,7 @@ def _init_state(self, group: dict, p: Tensor, state: dict): # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = ( - (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) @@ -309,9 +298,7 @@ def _init_state(self, group: dict, p: Tensor, state: dict): ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. 
- state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] @@ -340,16 +327,14 @@ def _get_clipping_scale( clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: + for p, state, param_names in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += ( - grad ** 2 - ).sum() # sum() to change shape [1] to [] + tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() @@ -428,11 +413,11 @@ def _show_gradient_dominating_parameter( from tuples, we still pass it to save some time. """ all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: + for p, state, batch_param_names in tuples: # p is a stacked batch parameters. batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad ** 2 + batch_sumsq_orig = batch_grad**2 # Dummpy values used by following `zip` statement. batch_rms_orig = torch.ones(p.shape[0]) else: @@ -510,9 +495,7 @@ def _step_one_batch( if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) param_rms.copy_( - (p ** 2) - .mean(dim=list(range(1, p.ndim)), keepdim=True) - .sqrt() + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() ) if step > 0: # self._size_update() learns the overall scale on the @@ -557,32 +540,23 @@ def _size_update( size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2 ** size_update_period + beta2_corr = beta2**size_update_period - scale_exp_avg_sq = state[ - "scale_exp_avg_sq" - ] # shape: (batch_size, 1, 1, ..) + scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads ** 2).mean( - dim=0 - ), # mean over dim `size_update_period` + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr ** size_step + bias_correction2 = 1 - beta2_corr**size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr - * (bias_correction2 ** 0.5) - * scale_grads.sum(dim=0) - / denom - ) + scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom is_too_small = param_rms < param_min_rms is_too_large = param_rms > param_max_rms @@ -618,9 +592,7 @@ def _step(self, group: dict, p: Tensor, state: dict): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - this_step = state["step"] - ( - state["zero_step"] if "zero_step" in state else 0 - ) + this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) if bias_correction2 < 0.99: # note: not in-place. 
@@ -670,9 +642,7 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError( - "{} is not an Optimizer".format(type(optimizer).__name__) - ) + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) self.optimizer = optimizer self.verbose = verbose @@ -793,10 +763,9 @@ def __init__( def get_lr(self): factor = ( - (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 + (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) - ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 ) warmup_factor = ( 1.0 @@ -883,17 +852,11 @@ def __init__( if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError( - "Invalid beta parameter at index 0: {}".format(betas[0]) - ) + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError( - "Invalid beta parameter at index 1: {}".format(betas[1]) - ) + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0 <= weight_decay <= 0.1: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay) - ) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -929,9 +892,7 @@ def step(self, closure=None): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError( - "AdamW does not support sparse gradients" - ) + raise RuntimeError("AdamW does not support sparse gradients") state = self.state[p] @@ -958,7 +919,7 @@ def step(self, closure=None): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( group["eps"] ) @@ -969,9 +930,7 @@ def step(self, closure=None): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). 
- is_above_target_rms = p.norm() > ( - target_rms * (p.numel() ** 0.5) - ) + is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -984,6 +943,7 @@ def step(self, closure=None): return loss + def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear @@ -1003,10 +963,10 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_( - ans.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans + + def _test_scaled_adam(hidden_dim: int): import timeit @@ -1040,8 +1000,7 @@ def _test_scaled_adam(hidden_dim: int): 100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) - * output_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, ) for _ in range(20) ] diff --git a/steps/trainer.py b/steps/trainer.py index ee17c64..953416a 100644 --- a/steps/trainer.py +++ b/steps/trainer.py @@ -1,3 +1,12 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + import time, sys, subprocess, json, re from pathlib import Path import os, random @@ -14,77 +23,135 @@ import numpy as np from torch.utils.data.distributed import DistributedSampler import logging + # from data import librilight, gigaspeech, gigaspeech_waveform from data import combined_dataset -from models import voice_star - -from .trainer_utils import DistributedDynamicBatchSampler, StatefulDistributedSampler, StatefulSampler, AverageMeter, print_model_info +from voicestar import voicestar as voice_star # legacy compatability TODO: change + +from .trainer_utils import ( + DistributedDynamicBatchSampler, + StatefulDistributedSampler, + StatefulSampler, + AverageMeter, + print_model_info, +) from .optim import ScaledAdam, Eden import run_gen import wandb, socket + class Trainer: - + def __init__(self, args, world_size, rank, local_rank): self.start_time = time.time() self.args = args if self.args.val_max_num_tokens == None: - self.args.val_max_num_tokens = self.args.max_num_tokens + self.args.val_max_num_tokens = self.args.max_num_tokens self.world_size, self.rank, self.local_rank = world_size, rank, local_rank - self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" + ) if self.rank == 0: self.writer = SummaryWriter(args.exp_dir) - self.wandb = wandb.init(project="voice_editor", name=args.exp_dir.split("/")[-1], config=args, dir=args.exp_dir, entity=self.args.wandb_entity) + self.wandb = wandb.init( + project="voice_editor", + name=args.exp_dir.split("/")[-1], + config=args, + dir=args.exp_dir, + entity=self.args.wandb_entity, + ) self.seed_everything(seed=self.args.seed) self.meters = self._setup_meters() self.progress, self.total_progress = self._setup_progress() - self.model, self.trainables, self.optim_states, self.scheduler_states, self.phn2num = self._setup_models() - - self.train_dataset_length, self.train_sampler, self.train_loader, self.valid_loader = self._setup_dataloader() # both are use 
DistributedSampler, train sampler is stateful + ( + self.model, + self.trainables, + self.optim_states, + self.scheduler_states, + self.phn2num, + ) = self._setup_models() + + ( + self.train_dataset_length, + self.train_sampler, + self.train_loader, + self.valid_loader, + ) = ( + self._setup_dataloader() + ) # both are use DistributedSampler, train sampler is stateful if self.args.num_steps != None: self.total_step = self.args.num_steps - self.args.num_epochs = math.ceil(self.total_step / math.floor(self.train_dataset_length / self.args.batch_size)) if not self.args.dynamic_batching else None + self.args.num_epochs = ( + math.ceil( + self.total_step + / math.floor(self.train_dataset_length / self.args.batch_size) + ) + if not self.args.dynamic_batching + else None + ) else: - self.total_step = int(math.floor(self.train_dataset_length / self.args.batch_size))*self.args.num_epochs + self.total_step = ( + int(math.floor(self.train_dataset_length / self.args.batch_size)) + * self.args.num_epochs + ) self.optimizer, self.scheduler = self._setup_optimizer() self.scaler = torch.cuda.amp.GradScaler() - self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.local_rank], find_unused_parameters=False) + self.model = torch.nn.parallel.DistributedDataParallel( + self.model, device_ids=[self.local_rank], find_unused_parameters=False + ) self.early_stop_accu_steps = 0 if self.rank == 0: if self.args.dynamic_batching: - logging.info(f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}") + logging.info( + f"max number of tokens per GPU in a training batch: {self.args.max_num_tokens}, max number of tokens per GPU in a inference batch: {self.args.val_max_num_tokens}" + ) else: logging.info(f"batch size (per gpu): {self.args.batch_size}") - - self.args.inference_every_n_steps = getattr(self.args, "inference_every_n_steps", self.args.val_every_n_steps*5) - assert self.args.inference_every_n_steps > self.args.val_every_n_steps and self.args.inference_every_n_steps % self.args.val_every_n_steps == 0, "inference_every_n_steps should be divisible by val_every_n_steps, otherwise the code will not get a chance to run inference" + + self.args.inference_every_n_steps = getattr( + self.args, "inference_every_n_steps", self.args.val_every_n_steps * 5 + ) + assert ( + self.args.inference_every_n_steps > self.args.val_every_n_steps + and self.args.inference_every_n_steps % self.args.val_every_n_steps == 0 + ), "inference_every_n_steps should be divisible by val_every_n_steps, otherwise the code will not get a chance to run inference" def train(self): flag = True skip_flag = False data_start_time = time.time() - if self.progress['step'] >= self.total_step: + if self.progress["step"] >= self.total_step: if self.rank == 0: self.writer.close() self.wandb.finish() return while flag: - self.train_sampler.set_epoch(self.progress['epoch']) + self.train_sampler.set_epoch(self.progress["epoch"]) for i, batch in enumerate(self.train_loader): - if len(batch['y_lens']) < self.args.gradient_accumulation_steps: + if len(batch["y_lens"]) < self.args.gradient_accumulation_steps: continue data_end_time = time.time() self.model.train() - if self.progress['step'] >= getattr(self.args, "uniform_weight_start_step", 1e50): - if self.progress['step'] == getattr(self.args, "uniform_weight_start_step", 1e50) and self.rank == 0: - logging.info("NOTE: start using uniform weight from step: 
{}".format(self.progress['step'])) - self.args.codebook_weight = [2.5,2,1.5,0.6] - self.model.module.args.codebook_weight = [2.5,2,1.5,0.6] + if self.progress["step"] >= getattr( + self.args, "uniform_weight_start_step", 1e50 + ): + if ( + self.progress["step"] + == getattr(self.args, "uniform_weight_start_step", 1e50) + and self.rank == 0 + ): + logging.info( + "NOTE: start using uniform weight from step: {}".format( + self.progress["step"] + ) + ) + self.args.codebook_weight = [2.5, 2, 1.5, 0.6] + self.model.module.args.codebook_weight = [2.5, 2, 1.5, 0.6] - if self.progress['step'] >= self.total_step: + if self.progress["step"] >= self.total_step: dist.barrier() flag = False self.validate_and_save() @@ -93,19 +160,26 @@ def train(self): self.wandb.finish() break if isinstance(self.scheduler, Eden): - self.scheduler.step_epoch(self.progress['step']//self.args.pseudo_epoch_size + 1) + self.scheduler.step_epoch( + self.progress["step"] // self.args.pseudo_epoch_size + 1 + ) if self.args.optimizer_name == "ScaledAdam": cur_lr = self.scheduler.get_last_lr()[0] else: - lrs = [param_group['lr'] for param_group in self.optimizer.param_groups] + lrs = [ + param_group["lr"] for param_group in self.optimizer.param_groups + ] assert lrs[0] == lrs[1] cur_lr = lrs[0] - if self.rank == 0 and self.progress['step'] % self.args.tb_write_every_n_steps == 0: - self.writer.add_scalar("train/lr", cur_lr, self.progress['step']) - self.wandb.log({"train/lr": cur_lr}, step=self.progress['step']) + if ( + self.rank == 0 + and self.progress["step"] % self.args.tb_write_every_n_steps == 0 + ): + self.writer.add_scalar("train/lr", cur_lr, self.progress["step"]) + self.wandb.log({"train/lr": cur_lr}, step=self.progress["step"]) - all_inds = list(range(len(batch['y']))) + all_inds = list(range(len(batch["y"]))) sum_losses = 0 sum_top10acc = 0 sum_ntoken = 0 @@ -116,28 +190,45 @@ def train(self): # therefore we re-calculate graduent_accumulation_steps based on the effective batch size if self.args.neighbor_prompt_prob > 0: - effective_batch_size = self.args.max_num_tokens // self.args.gradient_accumulation_steps - total_batch_size = sum(batch['y_lens']).item() - cur_gradient_accumulation_steps = max(self.args.gradient_accumulation_steps, total_batch_size // effective_batch_size) - gas = torch.tensor(cur_gradient_accumulation_steps, dtype=torch.int, device=self.local_rank) + effective_batch_size = ( + self.args.max_num_tokens + // self.args.gradient_accumulation_steps + ) + total_batch_size = sum(batch["y_lens"]).item() + cur_gradient_accumulation_steps = max( + self.args.gradient_accumulation_steps, + total_batch_size // effective_batch_size, + ) + gas = torch.tensor( + cur_gradient_accumulation_steps, + dtype=torch.int, + device=self.local_rank, + ) dist.all_reduce(gas, op=dist.ReduceOp.MAX) cur_gradient_accumulation_steps = gas.item() - len_batch = torch.tensor(len(batch['y']), dtype=torch.int, device=self.local_rank) + len_batch = torch.tensor( + len(batch["y"]), dtype=torch.int, device=self.local_rank + ) dist.all_reduce(len_batch, op=dist.ReduceOp.MIN) len_batch = len_batch.item() - cur_gradient_accumulation_steps = min(cur_gradient_accumulation_steps, len_batch) + cur_gradient_accumulation_steps = min( + cur_gradient_accumulation_steps, len_batch + ) # for those that cur_gradient_accumulation_steps * effective_batch_size < total_batch_size, we only use the first cur_gradient_accumulation_steps * effective_batch_size samples cur_len = 0 final_all_inds = [] pointer = 0 - while cur_len < self.args.max_num_tokens 
and pointer < len(all_inds): - cur_len += batch['y_lens'][pointer] + while cur_len < self.args.max_num_tokens and pointer < len( + all_inds + ): + cur_len += batch["y_lens"][pointer] final_all_inds.append(all_inds[pointer]) pointer += 1 all_inds = final_all_inds else: - cur_gradient_accumulation_steps = self.args.gradient_accumulation_steps - + cur_gradient_accumulation_steps = ( + self.args.gradient_accumulation_steps + ) sum_losses_local = 0.0 sum_top10acc_local = 0.0 @@ -159,12 +250,12 @@ def train(self): else: precision_used = torch.float32 - with torch.amp.autocast('cuda', dtype=precision_used): + with torch.amp.autocast("cuda", dtype=precision_used): out = self.model(cur_batch, calc_loss=True) if out is None: continue - if torch.isnan(out['loss']).any(): + if torch.isnan(out["loss"]).any(): local_nan_flag = torch.tensor(1, device=self.local_rank) else: local_nan_flag = torch.tensor(0, device=self.local_rank) @@ -174,56 +265,70 @@ def train(self): global_nan_flag = local_nan_flag.item() if global_nan_flag > 0: # Now *all* ranks break at the same j - logging.info(f"rank: {self.rank}. Loss at micro-batch {j} in step {self.progress['step']} was NaN on at least one rank; skipping.") + logging.info( + f"rank: {self.rank}. Loss at micro-batch {j} in step {self.progress['step']} was NaN on at least one rank; skipping." + ) break # Accumulate local values - record_loss = out['loss'].detach() - top10acc = out['top10acc'].detach() - effective_ntoken = out['effective_ntoken'].detach() + record_loss = out["loss"].detach() + top10acc = out["top10acc"].detach() + effective_ntoken = out["effective_ntoken"].detach() sum_losses_local += record_loss.item() sum_top10acc_local += top10acc.item() sum_ntoken_local += effective_ntoken.item() # Optional losses - if 'entropy_loss' in out: - sum_entropy_loss_local += out['entropy_loss'].detach().item() - if 'ctc_loss' in out: - sum_ctc_loss_local += out['ctc_loss'].detach().item() + if "entropy_loss" in out: + sum_entropy_loss_local += out["entropy_loss"].detach().item() + if "ctc_loss" in out: + sum_ctc_loss_local += out["ctc_loss"].detach().item() # Codebook accuracy - if 'top10acc_by_codebook' in out: + if "top10acc_by_codebook" in out: for cb in range(self.args.n_codebooks): - sum_top10acc_cbi_local[cb] += out['top10acc_by_codebook'][cb].detach().item() + sum_top10acc_cbi_local[cb] += ( + out["top10acc_by_codebook"][cb].detach().item() + ) # Backprop on this micro-batch if self.args.optimizer_name == "ScaledAdam": - self.scaler.scale(out['loss']).backward() + self.scaler.scale(out["loss"]).backward() else: - self.scaler.scale(out['loss'] / out['effective_ntoken']).backward() + self.scaler.scale( + out["loss"] / out["effective_ntoken"] + ).backward() if global_nan_flag > 0: # If *any* rank had NaN, skip this step - logging.info(f"rank: {self.rank}. Loss at one micro-batch in step {self.progress['step']} was NaN on at least one rank; skipping.") - self.progress['step'] += 1 - self.progress['cur_step'] += 1 + logging.info( + f"rank: {self.rank}. Loss at one micro-batch in step {self.progress['step']} was NaN on at least one rank; skipping." 
+ ) + self.progress["step"] += 1 + self.progress["cur_step"] += 1 self.optimizer.zero_grad() continue # Otherwise, do one big reduce for the summed metrics - metrics_tensor = torch.tensor([ - sum_losses_local, - sum_top10acc_local, - sum_entropy_loss_local, - sum_ctc_loss_local, - sum_ntoken_local - ], device=self.local_rank, dtype=torch.float32) + metrics_tensor = torch.tensor( + [ + sum_losses_local, + sum_top10acc_local, + sum_entropy_loss_local, + sum_ctc_loss_local, + sum_ntoken_local, + ], + device=self.local_rank, + dtype=torch.float32, + ) dist.all_reduce(metrics_tensor, op=dist.ReduceOp.SUM) # Also reduce the codebook array in one shot if needed - codebook_tensor = torch.tensor(sum_top10acc_cbi_local, device=self.local_rank, dtype=torch.float32) + codebook_tensor = torch.tensor( + sum_top10acc_cbi_local, device=self.local_rank, dtype=torch.float32 + ) dist.all_reduce(codebook_tensor, op=dist.ReduceOp.SUM) # Convert them back to Python scalars @@ -237,166 +342,258 @@ def train(self): if self.args.optimizer_name != "ScaledAdam": self.scaler.unscale_(self.optimizer) - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip_val) + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.args.gradient_clip_val + ) self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() if self.args.optimizer_name == "ScaledAdam": - self.scheduler.step_batch(self.progress['step']) + self.scheduler.step_batch(self.progress["step"]) else: self.scheduler.step() - + # logging if self.rank == 0: average_loss = sum_losses / sum_ntoken average_top10acc = sum_top10acc / sum_ntoken - average_top10acc_cbi = [sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks for cb in range(self.args.n_codebooks)] - self.meters['train_loss'].update(average_loss, batch['x'].shape[0]*self.world_size) - self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) - self.meters['train_top10acc'].update(average_top10acc, batch['x'].shape[0]*self.world_size) + average_top10acc_cbi = [ + sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks + for cb in range(self.args.n_codebooks) + ] + self.meters["train_loss"].update( + average_loss, batch["x"].shape[0] * self.world_size + ) + self.meters["train_top10acc"].update( + average_top10acc, batch["x"].shape[0] * self.world_size + ) + self.meters["train_top10acc"].update( + average_top10acc, batch["x"].shape[0] * self.world_size + ) for cb in range(self.args.n_codebooks): - self.meters[f'train_top10acc_cb{cb+1}'].update(average_top10acc_cbi[cb], batch['x'].shape[0]*self.world_size) - self.meters['data_time'].update(data_end_time - data_start_time) - self.meters['train_time'].update(time.time() - data_end_time) + self.meters[f"train_top10acc_cb{cb+1}"].update( + average_top10acc_cbi[cb], + batch["x"].shape[0] * self.world_size, + ) + self.meters["data_time"].update(data_end_time - data_start_time) + self.meters["train_time"].update(time.time() - data_end_time) # log extra losses for key in sum_extra_losses: - if "train_"+key not in self.meters: - self.meters["train_"+key] = AverageMeter() - self.meters["train_"+key].update(sum(sum_extra_losses[key])/len(sum_extra_losses[key]), batch['x'].shape[0]*self.world_size) + if "train_" + key not in self.meters: + self.meters["train_" + key] = AverageMeter() + self.meters["train_" + key].update( + sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), + batch["x"].shape[0] * self.world_size, + ) - if self.progress['step'] % 
self.args.tb_write_every_n_steps == 0: - self.writer.add_scalar('train/loss', average_loss, self.progress['step']) - self.writer.add_scalar('train/top10acc', average_top10acc, self.progress['step']) - self.writer.add_scalar("train/ntokens", sum_ntoken, self.progress['step']) - self.wandb.log({"train/loss": average_loss, "train/top10acc": average_top10acc, "train/ntokens": sum_ntoken, "train/data_time": data_end_time - data_start_time, "train/train_time": time.time() - data_end_time}, step=self.progress['step']) + if self.progress["step"] % self.args.tb_write_every_n_steps == 0: + self.writer.add_scalar( + "train/loss", average_loss, self.progress["step"] + ) + self.writer.add_scalar( + "train/top10acc", average_top10acc, self.progress["step"] + ) + self.writer.add_scalar( + "train/ntokens", sum_ntoken, self.progress["step"] + ) + self.wandb.log( + { + "train/loss": average_loss, + "train/top10acc": average_top10acc, + "train/ntokens": sum_ntoken, + "train/data_time": data_end_time - data_start_time, + "train/train_time": time.time() - data_end_time, + }, + step=self.progress["step"], + ) for cb in range(self.args.n_codebooks): - self.writer.add_scalar(f'train/top10acc_cb{cb+1}', average_top10acc_cbi[cb], self.progress['step']) - self.wandb.log({f'train/top10acc_cb{cb+1}': average_top10acc_cbi[cb]}, step=self.progress['step']) - self.writer.add_scalar("train/data_time", data_end_time - data_start_time, self.progress['step']) - self.writer.add_scalar("train/train_time", time.time() - data_end_time, self.progress['step']) + self.writer.add_scalar( + f"train/top10acc_cb{cb+1}", + average_top10acc_cbi[cb], + self.progress["step"], + ) + self.wandb.log( + {f"train/top10acc_cb{cb+1}": average_top10acc_cbi[cb]}, + step=self.progress["step"], + ) + self.writer.add_scalar( + "train/data_time", + data_end_time - data_start_time, + self.progress["step"], + ) + self.writer.add_scalar( + "train/train_time", + time.time() - data_end_time, + self.progress["step"], + ) # write extra losses for key in sum_extra_losses: - self.writer.add_scalar(f"train/{key}", sum(sum_extra_losses[key])/len(sum_extra_losses[key]), self.progress['step']) - self.wandb.log({f"train/{key}": sum(sum_extra_losses[key])/len(sum_extra_losses[key])}, step=self.progress['step']) + self.writer.add_scalar( + f"train/{key}", + sum(sum_extra_losses[key]) / len(sum_extra_losses[key]), + self.progress["step"], + ) + self.wandb.log( + { + f"train/{key}": sum(sum_extra_losses[key]) + / len(sum_extra_losses[key]) + }, + step=self.progress["step"], + ) # logging.info(f"ntoken: {sum_ntoken}") - + # logging - if self.progress['step'] % self.args.print_every_n_steps == 0: + if self.progress["step"] % self.args.print_every_n_steps == 0: log_out = {} - log_out['cur_epoch'] = f"{self.progress['epoch']}/{self.args.num_epochs}" if self.args.num_epochs is not None else f"{self.progress['epoch']}" - log_out['cur_step'] = f"{int(self.progress['cur_step']+1)}" - log_out['total_step'] = f"{self.progress['step']}/{self.args.num_steps}" - log_out['lr'] = f"{cur_lr:.7f}" - log_out['ntokens'] = f"{sum_ntoken}" + log_out["cur_epoch"] = ( + f"{self.progress['epoch']}/{self.args.num_epochs}" + if self.args.num_epochs is not None + else f"{self.progress['epoch']}" + ) + log_out["cur_step"] = f"{int(self.progress['cur_step']+1)}" + log_out["total_step"] = ( + f"{self.progress['step']}/{self.args.num_steps}" + ) + log_out["lr"] = f"{cur_lr:.7f}" + log_out["ntokens"] = f"{sum_ntoken}" for key in self.meters: if self.meters[key].val != 0 or self.meters[key].avg != 0: - 
log_out[key] = f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" if isinstance(self.meters[key].val, float) else f"{self.meters[key].val}" + log_out[key] = ( + f"{self.meters[key].val:.4f} ({self.meters[key].avg:.4f})" + if isinstance(self.meters[key].val, float) + else f"{self.meters[key].val}" + ) logging.info(log_out) - if np.isnan(self.meters['train_loss'].avg): + if np.isnan(self.meters["train_loss"].avg): logging.warning("training diverged...") raise RuntimeError("training diverged...") # save the model only - if self.progress['step'] % self.args.save_every_n_steps == 0: + if self.progress["step"] % self.args.save_every_n_steps == 0: dist.barrier() if self.rank == 0: - save_path = os.path.join(self.args.exp_dir,f"bundle_step{self.progress['step']}.pth") + save_path = os.path.join( + self.args.exp_dir, f"bundle_step{self.progress['step']}.pth" + ) self.save_progress(name=f"step{self.progress['step']}") torch.save( { "model": self.model.module.state_dict(), "args": self.args, "phn2num": self.train_loader.dataset.phn2num, - "optimizer": self.optimizer.state_dict(), + "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.state_dict(), - },save_path + }, + save_path, + ) + logging.info( + f"save model, optimizer, scheduler and progress at {save_path} at global step {self.progress['step']}" ) - logging.info(f"save model, optimizer, scheduler and progress at {save_path} at global step {self.progress['step']}") dist.barrier() - + # validation and save models - if self.progress['step'] % self.args.val_every_n_steps == 0: + if self.progress["step"] % self.args.val_every_n_steps == 0: dist.barrier() continue_training = self.validate_and_save() # broadcast continue_training to all processes, so that all processes gets into generation stage - continue_training = torch.tensor(int(continue_training), dtype=torch.int, device=self.local_rank) + continue_training = torch.tensor( + int(continue_training), dtype=torch.int, device=self.local_rank + ) dist.broadcast(continue_training, src=0) continue_training = bool(continue_training.item()) - dist.barrier() # need this to ensure all processes get to the next line? - logging.info(f"rank: {self.rank}, continue_training: {continue_training}") - if not continue_training: + dist.barrier() # need this to ensure all processes get to the next line? 
+ logging.info( + f"rank: {self.rank}, continue_training: {continue_training}" + ) + if not continue_training: if self.rank == 0: self.writer.close() self.wandb.finish() flag = False break - self.progress['step'] += 1 - self.progress['cur_step'] += 1 + self.progress["step"] += 1 + self.progress["cur_step"] += 1 data_start_time = time.time() - self.progress['epoch'] += 1 - self.progress['cur_step'] = 0 # reset cur_step to be 0 + self.progress["epoch"] += 1 + self.progress["cur_step"] = 0 # reset cur_step to be 0 dist.destroy_process_group() def validate_and_save(self): self.model.eval() - + score = self.validate(self.valid_loader) if self.args.early_stop_threshold > 0: - if self.progress['best_score'] - score < self.args.early_stop_threshold: + if self.progress["best_score"] - score < self.args.early_stop_threshold: self.early_stop_accu_steps += self.args.val_every_n_steps - if self.early_stop_accu_steps >= self.args.early_stop_step-1: - logging.info(f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}") - logging.info(f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}") + if self.early_stop_accu_steps >= self.args.early_stop_step - 1: + logging.info( + f"early stop based on self.args.early_stop_threshold: {self.args.early_stop_threshold}, and self.args.early_stop_step: {self.args.early_stop_step}" + ) + logging.info( + f"best validation score at step: {self.progress['best_step']}, and the score is {self.progress['best_score']:.4f}" + ) return False else: self.early_stop_accu_steps = 0 if self.rank == 0: - save_path = os.path.join(self.args.exp_dir,"bundle.pth") + save_path = os.path.join(self.args.exp_dir, "bundle.pth") if os.path.isfile(save_path): os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}") torch.save( { "model": self.model.module.state_dict(), - "optimizer": self.optimizer.state_dict(), + "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.state_dict(), "args": self.args, - "phn2num": self.train_loader.dataset.phn2num - },save_path + "phn2num": self.train_loader.dataset.phn2num, + }, + save_path, ) self.save_progress() - logging.info(f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}") - if (score < self.progress['best_score']): - self.progress['best_step'] = self.progress['step'] - self.progress['best_score'] = score - save_path = os.path.join(self.args.exp_dir,"best_bundle.pth") + logging.info( + f"save models, indices, acc and other statistics at {save_path} and {self.args.exp_dir}/progress.pkl at global step {self.progress['step']}" + ) + if score < self.progress["best_score"]: + self.progress["best_step"] = self.progress["step"] + self.progress["best_score"] = score + save_path = os.path.join(self.args.exp_dir, "best_bundle.pth") if os.path.isfile(save_path): - os.system(f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}") + os.system( + f"mv {save_path} {save_path.replace('.pth', '_prev.pth')}" + ) torch.save( { "model": self.model.module.state_dict(), - "optimizer": self.optimizer.state_dict(), + "optimizer": self.optimizer.state_dict(), "scheduler": self.scheduler.state_dict(), "args": self.args, - "phn2num": self.train_loader.dataset.phn2num - },save_path + "phn2num": self.train_loader.dataset.phn2num, + }, + save_path, + ) + logging.info( + f"save *best* models at {save_path} at global 
step {self.progress['step']}" ) - logging.info(f"save *best* models at {save_path} at global step {self.progress['step']}") - + # sync best score and best step, so that all processes early stop at the same time - best_score_tensor = torch.tensor(self.progress['best_score'], device=self.local_rank) + best_score_tensor = torch.tensor( + self.progress["best_score"], device=self.local_rank + ) dist.broadcast(best_score_tensor, src=0) - self.progress['best_score'] = float(best_score_tensor.item()) - best_step_tensor = torch.tensor(self.progress['best_step'], device=self.local_rank) + self.progress["best_score"] = float(best_score_tensor.item()) + best_step_tensor = torch.tensor( + self.progress["best_step"], device=self.local_rank + ) dist.broadcast(best_step_tensor, src=0) - self.progress['best_step'] = int(best_step_tensor.item()) + self.progress["best_step"] = int(best_step_tensor.item()) dist.barrier() return True @@ -419,29 +616,30 @@ def validate(self, valid_loader=None, hide_progress=True): with torch.no_grad(): for i, batch in enumerate(tqdm(valid_loader, disable=hide_progress)): - out = self.model(batch, calc_loss=True) # no reduction is applied to loss - sum_losses += out['loss'] - sum_top10acc += out['top10acc'] - sum_ntoken += out['effective_ntoken'] + out = self.model( + batch, calc_loss=True + ) # no reduction is applied to loss + sum_losses += out["loss"] + sum_top10acc += out["top10acc"] + sum_ntoken += out["effective_ntoken"] if "dur_loss" in out: - sum_dur_loss += out['dur_loss'] - sum_dur_acc += out['dur_acc'] + sum_dur_loss += out["dur_loss"] + sum_dur_acc += out["dur_acc"] if "entropy_loss" in out: - sum_entropy_loss += out['entropy_loss'] + sum_entropy_loss += out["entropy_loss"] if "ctc_loss" in out: - sum_ctc_loss += out['ctc_loss'] + sum_ctc_loss += out["ctc_loss"] # logging.info(f"iter {i}::: {sum_losses}, {sum_top10acc}, {sum_ntoken}") - if 'top10acc_by_codebook' in out: + if "top10acc_by_codebook" in out: for cb in range(self.args.n_codebooks): - sum_top10acc_cbi[cb] += out['top10acc_by_codebook'][cb] - - if 'perplexity_by_codebook' in out: + sum_top10acc_cbi[cb] += out["top10acc_by_codebook"][cb] + + if "perplexity_by_codebook" in out: for cb in range(self.args.n_codebooks): - mean_perplexity_cbi[cb] += out['perplexity_by_codebook'][cb] + mean_perplexity_cbi[cb] += out["perplexity_by_codebook"][cb] # if i > 10: # break - dist.all_reduce(sum_losses, op=dist.ReduceOp.SUM) dist.all_reduce(sum_top10acc, op=dist.ReduceOp.SUM) dist.all_reduce(sum_ntoken, op=dist.ReduceOp.SUM) @@ -453,110 +651,180 @@ def validate(self, valid_loader=None, hide_progress=True): if "ctc_loss" in out: dist.all_reduce(sum_ctc_loss, op=dist.ReduceOp.SUM) - if 'top10acc_by_codebook' in out: + if "top10acc_by_codebook" in out: for cb in range(self.args.n_codebooks): dist.all_reduce(sum_top10acc_cbi[cb], op=dist.ReduceOp.SUM) - - if 'perplexity_by_codebook' in out: + + if "perplexity_by_codebook" in out: for cb in range(self.args.n_codebooks): dist.all_reduce(mean_perplexity_cbi[cb], op=dist.ReduceOp.SUM) - + val_loss = sum_losses / sum_ntoken val_top10acc = sum_top10acc / sum_ntoken - + if self.rank == 0: if "dur_loss" in out: val_dur_loss = sum_dur_loss / sum_ntoken val_dur_acc = sum_dur_acc / sum_ntoken - self.meters['val_dur_loss'].update(val_dur_loss) + self.meters["val_dur_loss"].update(val_dur_loss) logging.info(f"val dur_loss: {val_dur_loss:.5f}") - self.meters['val_dur_acc'].update(val_dur_acc) + self.meters["val_dur_acc"].update(val_dur_acc) logging.info(f"val dur_acc: {val_dur_acc:.5f}") 
- self.writer.add_scalar("val/dur_loss", val_dur_loss, self.progress['step']) - self.writer.add_scalar("val/dur_acc", val_dur_acc, self.progress['step']) - self.wandb.log({"val/dur_loss": val_dur_loss, "val/dur_acc": val_dur_acc}, step=self.progress['step']) + self.writer.add_scalar( + "val/dur_loss", val_dur_loss, self.progress["step"] + ) + self.writer.add_scalar( + "val/dur_acc", val_dur_acc, self.progress["step"] + ) + self.wandb.log( + {"val/dur_loss": val_dur_loss, "val/dur_acc": val_dur_acc}, + step=self.progress["step"], + ) # logging - self.meters['val_loss'].update(val_loss) + self.meters["val_loss"].update(val_loss) logging.info(f"val loss: {val_loss:.5f}") - self.writer.add_scalar("val/loss", val_loss, self.progress['step']) - self.wandb.log({"val/loss": val_loss}, step=self.progress['step']) + self.writer.add_scalar("val/loss", val_loss, self.progress["step"]) + self.wandb.log({"val/loss": val_loss}, step=self.progress["step"]) - self.meters['val_top10acc'].update(val_top10acc) + self.meters["val_top10acc"].update(val_top10acc) logging.info(f"val top10acc: {val_top10acc:.5f}") - self.writer.add_scalar("val/top10acc", val_top10acc, self.progress['step']) - self.wandb.log({"val/top10acc": val_top10acc}, step=self.progress['step']) + self.writer.add_scalar("val/top10acc", val_top10acc, self.progress["step"]) + self.wandb.log({"val/top10acc": val_top10acc}, step=self.progress["step"]) for cb in range(self.args.n_codebooks): - average_top10acc_cbi = sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks - self.meters[f'val_top10acc_cb{cb+1}'].update(average_top10acc_cbi) - self.writer.add_scalar(f'val/top10acc_cb{cb+1}', average_top10acc_cbi, self.progress['step']) - self.wandb.log({f'val/top10acc_cb{cb+1}': average_top10acc_cbi}, step=self.progress['step']) - - temp = mean_perplexity_cbi[cb]/len(valid_loader) - self.writer.add_scalar(f'val/perplexity_cb{cb+1}', temp, self.progress['step']) - self.wandb.log({f'val/perplexity_cb{cb+1}': temp}, step=self.progress['step']) - - average_perplexity = sum(mean_perplexity_cbi)/(self.args.n_codebooks*len(valid_loader)) - self.wandb.log({"val/average_perplexity": average_perplexity}, step=self.progress['step']) - self.writer.add_scalar('val/average_perplexity', average_perplexity, self.progress['step']) - + average_top10acc_cbi = ( + sum_top10acc_cbi[cb] / sum_ntoken * self.args.n_codebooks + ) + self.meters[f"val_top10acc_cb{cb+1}"].update(average_top10acc_cbi) + self.writer.add_scalar( + f"val/top10acc_cb{cb+1}", + average_top10acc_cbi, + self.progress["step"], + ) + self.wandb.log( + {f"val/top10acc_cb{cb+1}": average_top10acc_cbi}, + step=self.progress["step"], + ) + + temp = mean_perplexity_cbi[cb] / len(valid_loader) + self.writer.add_scalar( + f"val/perplexity_cb{cb+1}", temp, self.progress["step"] + ) + self.wandb.log( + {f"val/perplexity_cb{cb+1}": temp}, step=self.progress["step"] + ) + + average_perplexity = sum(mean_perplexity_cbi) / ( + self.args.n_codebooks * len(valid_loader) + ) + self.wandb.log( + {"val/average_perplexity": average_perplexity}, + step=self.progress["step"], + ) + self.writer.add_scalar( + "val/average_perplexity", average_perplexity, self.progress["step"] + ) + # log entropy and ctc loss if "entropy_loss" in out: - val_entropy_loss = sum_entropy_loss / ((i+1) * self.world_size) - self.meters['val_entropy_loss'].update(val_entropy_loss) + val_entropy_loss = sum_entropy_loss / ((i + 1) * self.world_size) + self.meters["val_entropy_loss"].update(val_entropy_loss) logging.info(f"val entropy_loss: 
{val_entropy_loss:.5f}") - self.writer.add_scalar("val/entropy_loss", val_entropy_loss, self.progress['step']) - self.wandb.log({"val/entropy_loss": val_entropy_loss}, step=self.progress['step']) + self.writer.add_scalar( + "val/entropy_loss", val_entropy_loss, self.progress["step"] + ) + self.wandb.log( + {"val/entropy_loss": val_entropy_loss}, step=self.progress["step"] + ) if "ctc_loss" in out: - val_ctc_loss = sum_ctc_loss / ((i+1) * self.world_size) - self.meters['val_ctc_loss'].update(val_ctc_loss) + val_ctc_loss = sum_ctc_loss / ((i + 1) * self.world_size) + self.meters["val_ctc_loss"].update(val_ctc_loss) logging.info(f"val ctc_loss: {val_ctc_loss:.5f}") - self.writer.add_scalar("val/ctc_loss", val_ctc_loss, self.progress['step']) - self.wandb.log({"val/ctc_loss": val_ctc_loss}, step=self.progress['step']) + self.writer.add_scalar( + "val/ctc_loss", val_ctc_loss, self.progress["step"] + ) + self.wandb.log( + {"val/ctc_loss": val_ctc_loss}, step=self.progress["step"] + ) logging.info(f"validation takes: {time.time() - start_val_time:.2f}s") - logging.info(f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}") + logging.info( + f"Step [{self.progress['step']}/{self.total_step}]\t Time elapsed {(time.time() - self.start_time)/3600.:.2f}h, Val Loss: {val_loss:.4f}, Val Top10Acc: {val_top10acc:.4f}" + ) return val_loss.item() def _setup_meters(self): meters = {} - meter_names = ['train_loss', 'val_loss', 'train_top10acc', 'val_top10acc', 'data_time', 'train_time'] - meter_names += ['train_dur_loss', 'train_dur_acc', 'val_dur_loss', 'val_dur_acc'] - meter_names += ['val_perplexity'] - meter_names += [f'train_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] - meter_names += [f'val_top10acc_cb{cb+1}' for cb in range(self.args.n_codebooks)] - meter_names += [f'val_perplexity_cb{cb+1}' for cb in range(self.args.n_codebooks)] + meter_names = [ + "train_loss", + "val_loss", + "train_top10acc", + "val_top10acc", + "data_time", + "train_time", + ] + meter_names += [ + "train_dur_loss", + "train_dur_acc", + "val_dur_loss", + "val_dur_acc", + ] + meter_names += ["val_perplexity"] + meter_names += [ + f"train_top10acc_cb{cb+1}" for cb in range(self.args.n_codebooks) + ] + meter_names += [f"val_top10acc_cb{cb+1}" for cb in range(self.args.n_codebooks)] + meter_names += [ + f"val_perplexity_cb{cb+1}" for cb in range(self.args.n_codebooks) + ] for name in meter_names: meters[name] = AverageMeter() return meters + def _setup_progress(self): """ Need to customize it """ progress = {} - progress['best_step'] = 1 - progress['best_score'] = np.inf # this records loss value - progress['step'] = 1 - progress['epoch'] = 1 - progress['cur_step'] = 0 # step in the current epoch, for resuming the sampler + progress["best_step"] = 1 + progress["best_score"] = np.inf # this records loss value + progress["step"] = 1 + progress["epoch"] = 1 + progress["cur_step"] = 0 # step in the current epoch, for resuming the sampler total_progress = [] # if self.args.resume or self.args.validate: if self.args.resume: progress_pkl = "%s/progress.pkl" % self.args.exp_dir with open(progress_pkl, "rb") as f: total_progress = pickle.load(f) - progress['best_step'], progress['best_score'], progress['step'], progress['epoch'], progress['cur_step'], _ = total_progress[-1] + ( + progress["best_step"], + progress["best_score"], + progress["step"], + progress["epoch"], + progress["cur_step"], + _, + ) = 
total_progress[-1] if self.rank == 0: logging.info("\nResume training from:") - logging.info(" epoch = %s" % progress['epoch']) - logging.info(" cur_step = %s" % progress['cur_step']) - logging.info(" step = %s" % progress['step']) - logging.info(" best_step = %s" % progress['best_step']) - logging.info(" best_score = %s" % progress['best_score']) + logging.info(" epoch = %s" % progress["epoch"]) + logging.info(" cur_step = %s" % progress["cur_step"]) + logging.info(" step = %s" % progress["step"]) + logging.info(" best_step = %s" % progress["best_step"]) + logging.info(" best_score = %s" % progress["best_score"]) return progress, total_progress - + def save_progress(self, name=None): - self.total_progress.append([self.progress['best_step'], self.progress['best_score'], int(self.progress['step']+1), self.progress['epoch'], int(self.progress['cur_step']+1), time.time() - self.start_time]) + self.total_progress.append( + [ + self.progress["best_step"], + self.progress["best_score"], + int(self.progress["step"] + 1), + self.progress["epoch"], + int(self.progress["cur_step"] + 1), + time.time() - self.start_time, + ] + ) if name is not None: progress_fn = f"{self.args.exp_dir}/progress_{name}.pkl" else: @@ -565,17 +833,60 @@ def save_progress(self, name=None): pickle.dump(self.total_progress, f) def _setup_dataloader(self): - train_dataset, val_dataset = combined_dataset.dataset(self.args, 'train'), combined_dataset.dataset(self.args, 'valid') # need to change 'train' to 'valid' in actual training + train_dataset, val_dataset = combined_dataset.dataset( + self.args, "train" + ), combined_dataset.dataset( + self.args, "valid" + ) # need to change 'train' to 'valid' in actual training if self.args.dynamic_batching: - train_sampler = DistributedDynamicBatchSampler(train_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=train_dataset.lengths_list, verbose=True, epoch=0) - valid_sampler = DistributedDynamicBatchSampler(val_dataset, self.args, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True, lengths_list=val_dataset.lengths_list, verbose=True, epoch=0) + train_sampler = DistributedDynamicBatchSampler( + train_dataset, + self.args, + num_replicas=self.world_size, + rank=self.rank, + shuffle=True, + seed=self.args.seed, + drop_last=True, + lengths_list=train_dataset.lengths_list, + verbose=True, + epoch=0, + ) + valid_sampler = DistributedDynamicBatchSampler( + val_dataset, + self.args, + num_replicas=self.world_size, + rank=self.rank, + shuffle=True, + seed=self.args.seed, + drop_last=True, + lengths_list=val_dataset.lengths_list, + verbose=True, + epoch=0, + ) else: - train_sampler = StatefulDistributedSampler(train_dataset, self.args.batch_size//self.world_size, num_replicas=self.world_size, rank=self.rank, shuffle=True, seed=self.args.seed, drop_last=True) - valid_sampler = DistributedSampler(val_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, seed=self.args.seed, drop_last=False) - - if self.progress['step'] > 1: - train_sampler.set_epoch_resume(self.progress['epoch'], self.progress['cur_step']) + train_sampler = StatefulDistributedSampler( + train_dataset, + self.args.batch_size // self.world_size, + num_replicas=self.world_size, + rank=self.rank, + shuffle=True, + seed=self.args.seed, + drop_last=True, + ) + valid_sampler = DistributedSampler( + val_dataset, + num_replicas=self.world_size, + rank=self.rank, + shuffle=False, + seed=self.args.seed, 
+ drop_last=False, + ) + + if self.progress["step"] > 1: + train_sampler.set_epoch_resume( + self.progress["epoch"], self.progress["cur_step"] + ) assert self.phn2num != None if self.phn2num != None: @@ -583,79 +894,109 @@ def _setup_dataloader(self): val_dataset.phn2num = self.phn2num if self.args.dynamic_batching: - train_loader = torch.utils.data.DataLoader(train_dataset, - batch_sampler=train_sampler, - num_workers=self.args.num_workers, - collate_fn=train_dataset.collate, persistent_workers=True - ) - valid_loader = torch.utils.data.DataLoader(val_dataset, - batch_sampler=valid_sampler, - num_workers=self.args.num_workers, - collate_fn=val_dataset.collate, persistent_workers=True - ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=self.args.num_workers, + collate_fn=train_dataset.collate, + persistent_workers=True, + ) + valid_loader = torch.utils.data.DataLoader( + val_dataset, + batch_sampler=valid_sampler, + num_workers=self.args.num_workers, + collate_fn=val_dataset.collate, + persistent_workers=True, + ) else: - train_loader = torch.utils.data.DataLoader(train_dataset, - batch_size=self.args.batch_size, sampler=train_sampler, num_workers=self.args.num_workers, - collate_fn=train_dataset.collate, persistent_workers=True - ) - valid_loader = torch.utils.data.DataLoader(val_dataset, - batch_size=self.args.batch_size, sampler=valid_sampler, - num_workers=self.args.num_workers, - collate_fn=val_dataset.collate, persistent_workers=True - ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=self.args.batch_size, + sampler=train_sampler, + num_workers=self.args.num_workers, + collate_fn=train_dataset.collate, + persistent_workers=True, + ) + valid_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=self.args.batch_size, + sampler=valid_sampler, + num_workers=self.args.num_workers, + collate_fn=val_dataset.collate, + persistent_workers=True, + ) return len(train_dataset), train_sampler, train_loader, valid_loader - - def _setup_models(self): - model = voice_star.VoiceStar(self.args) + model = voice_star.VoiceStarModel(self.args) if self.rank == 0: logging.info(model) logging.info("model parameters") print_model_info(model) - + phn2num = None optim_states = None scheduler_states = None - if self.progress['step'] > 1: - bundle = torch.load(os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu") - model.load_state_dict(bundle['model']) - optim_states = bundle['optimizer'] - scheduler_states = bundle['scheduler'] - phn2num = bundle['phn2num'] + if self.progress["step"] > 1: + bundle = torch.load( + os.path.join(self.args.exp_dir, "bundle.pth"), map_location="cpu" + ) + model.load_state_dict(bundle["model"]) + optim_states = bundle["optimizer"] + scheduler_states = bundle["scheduler"] + phn2num = bundle["phn2num"] if self.rank == 0: - logging.info("loaded parameters and data indices from epoch %d, global step %d" % (self.progress['epoch'], self.progress['step'])) - del bundle['model'] + logging.info( + "loaded parameters and data indices from epoch %d, global step %d" + % (self.progress["epoch"], self.progress["step"]) + ) + del bundle["model"] - if self.args.load_model_from != None and self.progress['step'] <= 1: + if self.args.load_model_from != None and self.progress["step"] <= 1: logging.info(f"load weights from {self.args.load_model_from}") sd = torch.load(self.args.load_model_from, map_location="cpu") if hasattr(model, "carefully_load_state_dict"): - 
model.carefully_load_state_dict(sd['model']) + model.carefully_load_state_dict(sd["model"]) else: - model.load_state_dict(sd['model']) - phn2num = sd['phn2num'] + model.load_state_dict(sd["model"]) + phn2num = sd["phn2num"] del sd - #### below operations is for getting params for optimizer, which is at wrapper level ### if self.args.optimizer_name == "ScaledAdam": trainables = [p for p in model.parameters() if p.requires_grad] else: - no_decay = [".bias", ".audio_embeddings.weight", ".text_embeddings.weight", ".norm.weight", ".norm1.weight", ".norm2.weight"] + no_decay = [ + ".bias", + ".audio_embeddings.weight", + ".text_embeddings.weight", + ".norm.weight", + ".norm1.weight", + ".norm2.weight", + ] optimizer_grouped_parameters = [ { - "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], + "params": [ + p + for n, p in model.named_parameters() + if not any(nd in n for nd in no_decay) and p.requires_grad + ], "weight_decay": self.args.weight_decay, }, { - "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], + "params": [ + p + for n, p in model.named_parameters() + if any(nd in n for nd in no_decay) and p.requires_grad + ], "weight_decay": 0.0, }, ] - if len(optimizer_grouped_parameters[1]['params']) == 0: - logging.info("there is no embedding weights, bias, and layernorm parameters in the model, which should be True, check model parameter names") + if len(optimizer_grouped_parameters[1]["params"]) == 0: + logging.info( + "there is no embedding weights, bias, and layernorm parameters in the model, which should be True, check model parameter names" + ) trainables = optimizer_grouped_parameters[0] else: trainables = optimizer_grouped_parameters @@ -664,12 +1005,17 @@ def _setup_models(self): return model, trainables, optim_states, scheduler_states, phn2num - def _setup_optimizer(self): if self.args.optimizer_name == "ScaledAdam": parameters_names = [] - _model = self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model - parameters_names.append([n for n,p in self.model.named_parameters() if p.requires_grad]) + _model = ( + self.model.module + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) + else self.model + ) + parameters_names.append( + [n for n, p in self.model.named_parameters() if p.requires_grad] + ) optimizer = ScaledAdam( self.trainables, lr=self.args.lr, @@ -679,22 +1025,30 @@ def _setup_optimizer(self): show_dominant_parameters=False, clipping_update_period=self.args.clipping_update_period, ) - scheduler = Eden(optimizer, self.args.reduce_lr_start_step, self.args.reduce_lr_start_epoch, warmup_batches=self.total_step * self.args.warmup_fraction) # NOTE: if using ScaledAdam, we will use the Eden scheduler! + scheduler = Eden( + optimizer, + self.args.reduce_lr_start_step, + self.args.reduce_lr_start_epoch, + warmup_batches=self.total_step * self.args.warmup_fraction, + ) # NOTE: if using ScaledAdam, we will use the Eden scheduler! 
else: optimizer = AdamW(self.trainables, lr=self.args.lr) warmup_steps = self.total_step * self.args.warmup_fraction + def lr_lambda(current_step: int): if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps)) return max( - 0.0, float(self.total_step - current_step) / float(max(1, self.total_step - warmup_steps)) + 0.0, + float(self.total_step - current_step) + / float(max(1, self.total_step - warmup_steps)), ) scheduler = LambdaLR(optimizer, lr_lambda, last_epoch=-1) - + # if resume - if self.progress['step'] > 1: + if self.progress["step"] > 1: optimizer.load_state_dict(self.optim_states) for state in optimizer.state.values(): for k, v in state.items(): @@ -706,12 +1060,12 @@ def lr_lambda(current_step: int): optimizer.zero_grad() return optimizer, scheduler - + def seed_everything(self, seed=1): - os.environ['PYTHONHASHSEED'] = str(seed) + os.environ["PYTHONHASHSEED"] = str(seed) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True \ No newline at end of file + torch.backends.cudnn.deterministic = True diff --git a/steps/trainer_utils.py b/steps/trainer_utils.py index 51162ec..349770d 100644 --- a/steps/trainer_utils.py +++ b/steps/trainer_utils.py @@ -1,3 +1,11 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" import torch import math @@ -9,8 +17,18 @@ from scipy.stats import lognorm import logging + class StatefulDistributedSampler(Sampler[int]): - def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffle = True, seed = 0, drop_last = False): + def __init__( + self, + dataset, + batch_size, + num_replicas=None, + rank=None, + shuffle=True, + seed=0, + drop_last=False, + ): if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -22,7 +40,8 @@ def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffl if rank >= num_replicas or rank < 0: raise ValueError( "Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, num_replicas - 1)) + " [0, {}]".format(rank, num_replicas - 1) + ) self.dataset = dataset self.batch_size = batch_size self.num_replicas = num_replicas @@ -45,6 +64,7 @@ def __init__(self, dataset, batch_size, num_replicas = None, rank = None, shuffl self.shuffle = shuffle self.seed = seed self.continue_flag = False + def __len__(self): return self.num_samples @@ -73,25 +93,27 @@ def set_epoch(self, epoch): if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] else: # remove tail of data to make it evenly divisible. 
- indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples self.indices = indices if self.continue_flag: - self.indices = self.indices[int(self.cur_step*self.batch_size):] + self.indices = self.indices[int(self.cur_step * self.batch_size) :] self.num_samples = len(self.indices) self.continue_flag = False - + def __iter__(self): for idx in self.indices: - yield idx + yield idx def set_epoch_resume(self, epoch, cur_step): self.epoch = epoch @@ -100,7 +122,9 @@ def set_epoch_resume(self, epoch, cur_step): class StatefulSampler(Sampler): - def __init__(self, data_source_length, batch_size, use_random=True, seed=1, epoch=0): + def __init__( + self, data_source_length, batch_size, use_random=True, seed=1, epoch=0 + ): self.use_random = use_random self.data_source_length = data_source_length self.num_samples = self.data_source_length @@ -129,10 +153,10 @@ def set_epoch(self, epoch): self.indices = list(range(self.data_source_length)) # type: ignore[arg-type] if self.continue_flag == True: self.continue_flag = False - self.indices = self.indices[int(self.cur_step*self.batch_size):] - + self.indices = self.indices[int(self.cur_step * self.batch_size) :] + self.num_samples = len(self.indices) - + def set_epoch_resume(self, epoch, cur_step): self.epoch = epoch self.cur_step = cur_step @@ -141,6 +165,7 @@ def set_epoch_resume(self, epoch, cur_step): class AverageMeter: """Computes and stores the average and current value""" + def __init__(self): self.reset() @@ -156,7 +181,8 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count -def print_model_info(model, print_model = False, print_params = True): + +def print_model_info(model, print_model=False, print_params=True): if print_model: logging.info(model) if print_params: @@ -284,11 +310,11 @@ def __init__( self, dataset, args, - num_replicas = None, - rank = None, - shuffle = True, - seed = 0, - drop_last = False, + num_replicas=None, + rank=None, + shuffle=True, + seed=0, + drop_last=False, length_func=lambda x: x["duration"], batch_ordering: str = "random", max_batch_ex: int = None, @@ -309,20 +335,26 @@ def __init__( if rank >= num_replicas or rank < 0: raise ValueError( "Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, num_replicas - 1)) + " [0, {}]".format(rank, num_replicas - 1) + ) self.num_replicas = num_replicas self.rank = rank - max_batch_length = self.args.max_num_tokens if dataset.split == "train" else self.args.val_max_num_tokens + max_batch_length = ( + self.args.max_num_tokens + if dataset.split == "train" + else self.args.val_max_num_tokens + ) if type(lengths_list[0]) == float: max_batch_length = float(max_batch_length) / self.args.encodec_sr - logging.info(f"max_num_seconds per GPU for {dataset.split} split: {max_batch_length} seconds") + logging.info( + f"max_num_seconds per GPU for {dataset.split} split: {max_batch_length} seconds" + ) else: - logging.info(f"max_num_tokens per GPU for {dataset.split} split: {max_batch_length}") + logging.info( + f"max_num_tokens per GPU for {dataset.split} split: {max_batch_length}" + ) num_buckets = self.args.num_buckets ############# - - - self._dataset = dataset self._ex_lengths = {} @@ -336,8 +368,14 @@ def __init__( "Check the docs, and/or the tutorial !" 
) assert lengths_list != None - max_len = int(self.args.audio_max_length * self.args.encodec_sr) if type(lengths_list[0]) == int else self.args.audio_max_length # if the length is in float, that means it's in seconds, otherwise it's in number of frames - lengths_list = [min(l, max_len) for l in lengths_list] # replace all utt whose length is longer than max_len to max_len, will also do this in __getitem__ in dataset + max_len = ( + int(self.args.audio_max_length * self.args.encodec_sr) + if type(lengths_list[0]) == int + else self.args.audio_max_length + ) # if the length is in float, that means it's in seconds, otherwise it's in number of frames + lengths_list = [ + min(l, max_len) for l in lengths_list + ] # replace all utt whose length is longer than max_len to max_len, will also do this in __getitem__ in dataset for indx in range(len(lengths_list)): self._ex_lengths[str(indx)] = lengths_list[indx] # if lengths_list is not None: @@ -361,9 +399,7 @@ def __init__( "All elements in bucket boundaries should be non-negative (>= 0)." ) if not len(set(bucket_boundaries)) == len(bucket_boundaries): - raise ValueError( - "Bucket_boundaries should not contain duplicates." - ) + raise ValueError("Bucket_boundaries should not contain duplicates.") np.testing.assert_array_equal( np.array(bucket_boundaries), np.array(sorted(bucket_boundaries)), @@ -399,21 +435,29 @@ def __init__( self._generate_batches() self.num_samples = int(math.floor(len(self._batches) / self.num_replicas)) self.total_size = int(self.num_samples * self.num_replicas) - self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas] - assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" - logging.info(f"len(self._batches): {len(self._batches)}") # 285773 - logging.info(f"self.total_size: {self.total_size}") # 285768 - logging.info(f"self.num_samples: {self.num_samples}") # 35721 - logging.info(f"len(self._replica_batches): {len(self._replica_batches)}") # 35721 - logging.info(f"self.num_replicas: {self.num_replicas}") # 8 - #logging.info(f"num of batches on each replica: {self.num_samples}") # 35721 + self._replica_batches = self._batches[ + self.rank : self.total_size : self.num_replicas + ] + assert ( + len(self._replica_batches) == self.num_samples + ), f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" + logging.info(f"len(self._batches): {len(self._batches)}") # 285773 + logging.info(f"self.total_size: {self.total_size}") # 285768 + logging.info(f"self.num_samples: {self.num_samples}") # 35721 + logging.info( + f"len(self._replica_batches): {len(self._replica_batches)}" + ) # 35721 + logging.info(f"self.num_replicas: {self.num_replicas}") # 8 + # logging.info(f"num of batches on each replica: {self.num_samples}") # 35721 def get_durations(self, batch): """Gets durations of the elements in the batch.""" return [self._ex_lengths[str(idx)] for idx in batch] def _get_boundaries_through_warping( - self, max_batch_length: int, num_quantiles: int, + self, + max_batch_length: int, + num_quantiles: int, ) -> List[int]: # NOTE: the following lines do not cover that there is only one example in the dataset @@ -423,7 +467,9 @@ def _get_boundaries_through_warping( num_boundaries = num_quantiles + 1 # create latent linearly 
equal spaced buckets latent_boundaries = np.linspace( - 1 / num_boundaries, num_quantiles / num_boundaries, num_quantiles, + 1 / num_boundaries, + num_quantiles / num_boundaries, + num_quantiles, ) # get quantiles using lognormal distribution quantiles = lognorm.ppf(latent_boundaries, 1) @@ -448,7 +494,9 @@ def _permute_batches(self): if self._batch_ordering == "random": # deterministically shuffle based on epoch and seed g = torch.Generator() - g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset + g.manual_seed( + self._seed + self._epoch + ) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset sampler = torch.randperm( len(self._batches), generator=g ).tolist() # type: ignore @@ -476,8 +524,10 @@ def _generate_batches(self): if self._shuffle_ex: # deterministically shuffle based on epoch and seed g = torch.Generator() - g.manual_seed(self._seed + self._epoch) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset - sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore + g.manual_seed( + self._seed + self._epoch + ) # since the random seed is based on self._seed and self._epoch, it should be the same for different processes when using DDP, and therefore the generated order should be the same across different process, this is important, because each replica will only take a portion of it, we want to make sure they take a non-overlapping portion, and all of them constitute the entire dataset + sampler = torch.randperm(len(self._dataset), generator=g).tolist() # type: ignore # pyp note: this is actually randomly permoted indices else: # take examples as they are: e.g. 
they have been sorted @@ -530,16 +580,23 @@ def _generate_batches(self): # put the 5 longest batches in the beginning # find 5 longest from self._ex_lengths, which is a dict of lengths, with key being index (in str), and value being length # sort the dict by value, and get the first 5 keys - self._batches = sorted( - self._batches, - key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), - reverse=True, - )[:5] + self._batches[5:] + self._batches = ( + sorted( + self._batches, + key=lambda x: max([self._ex_lengths[str(idx)] for idx in x]), + reverse=True, + )[:5] + + self._batches[5:] + ) if not dist.is_initialized() or dist.get_rank() == 0: - logging.info(f"replace the first 5 samples in the batch with the 5 longest samples: ") + logging.info( + f"replace the first 5 samples in the batch with the 5 longest samples: " + ) logging.info(f"their lengths: ") for i in range(5): - logging.info(f"{[self._ex_lengths[str(idx)] for idx in self._batches[i]]}") + logging.info( + f"{[self._ex_lengths[str(idx)] for idx in self._batches[i]]}" + ) # frames per batch & their padding remaining boundaries = [0] + self._bucket_boundaries.tolist() @@ -582,18 +639,11 @@ def _generate_batches(self): "pad_%": [], } for batch in self._batches: - tot_frames = sum( - [self._ex_lengths[str(idx)] for idx in batch] - ) + tot_frames = sum([self._ex_lengths[str(idx)] for idx in batch]) batch_stats["tot_frames"].append(tot_frames) - max_frames = max( - [self._ex_lengths[str(idx)] for idx in batch] - ) + max_frames = max([self._ex_lengths[str(idx)] for idx in batch]) tot_pad = sum( - [ - max_frames - self._ex_lengths[str(idx)] - for idx in batch - ] + [max_frames - self._ex_lengths[str(idx)] for idx in batch] ) batch_stats["tot_pad_frames"].append(tot_pad) batch_stats["pad_%"].append(tot_pad / tot_frames * 100) @@ -616,7 +666,6 @@ def __iter__(self): for batch in self._replica_batches: yield batch - # if self._shuffle_ex: # re-generate examples if ex_ordering == "random" # self._generate_batches() # if self._batch_ordering == "random": @@ -630,19 +679,22 @@ def set_epoch(self, epoch): """ self._epoch = epoch self._generate_batches() - self._replica_batches = self._batches[self.rank:self.total_size:self.num_replicas] + self._replica_batches = self._batches[ + self.rank : self.total_size : self.num_replicas + ] self.num_samples = int(math.floor(len(self._batches) / self.num_replicas)) - assert len(self._replica_batches) == self.num_samples, f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" + assert ( + len(self._replica_batches) == self.num_samples + ), f"len(self._batches): {len(self._batches)}, self.total_size: {self.total_size}, self.num_samples: {self.num_samples},len(self._replica_batches): {len(self._replica_batches)}" if self.continue_flag: self.continue_flag = False - self._replica_batches = self._replica_batches[self._cur_step:] + self._replica_batches = self._replica_batches[self._cur_step :] self.num_samples = len(self._replica_batches) - def __len__(self): return self.num_samples - + def set_epoch_resume(self, epoch, cur_step): self.continue_flag = True self._epoch = epoch diff --git a/main.py b/train/train.py similarity index 68% rename from main.py rename to train/train.py index 1e5bbe0..438bd84 100644 --- a/main.py +++ b/train/train.py @@ -8,7 +8,8 @@ import torch.distributed as dist from config import MyParser from steps import trainer -from copy_codebase import copy_codebase 
+from voicestar.scripts.copy_codebase import copy_codebase + def world_info_from_env(): local_rank = int(os.environ["LOCAL_RANK"]) @@ -16,23 +17,25 @@ def world_info_from_env(): world_size = int(os.environ["WORLD_SIZE"]) return local_rank, global_rank, world_size + if __name__ == "__main__": - formatter = ( - "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" - ) + formatter = "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - + torch.cuda.empty_cache() args = MyParser().parse_args() exp_dir = Path(args.exp_dir) exp_dir.mkdir(exist_ok=True, parents=True) logging.info(f"exp_dir: {str(exp_dir)}") - if args.resume and (os.path.exists("%s/bundle.pth" % args.exp_dir) or os.path.exists("%s/bundle_prev.pth" % args.exp_dir)): + if args.resume and ( + os.path.exists("%s/bundle.pth" % args.exp_dir) + or os.path.exists("%s/bundle_prev.pth" % args.exp_dir) + ): if not os.path.exists("%s/bundle.pth" % args.exp_dir): os.system(f"cp {args.exp_dir}/bundle_prev.pth {args.exp_dir}/bundle.pth") resume = args.resume - assert(bool(args.exp_dir)) + assert bool(args.exp_dir) with open("%s/args.pkl" % args.exp_dir, "rb") as f: old_args = pickle.load(f) new_args = vars(args) @@ -46,15 +49,17 @@ def world_info_from_env(): args.resume = False with open("%s/args.pkl" % args.exp_dir, "wb") as f: pickle.dump(args, f) - + # make timeout longer (for generation) timeout = datetime.timedelta(seconds=7200) # 60 minutes if args.multinodes: _local_rank, _, _ = world_info_from_env() - dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout) + dist.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, timeout=timeout + ) else: - dist.init_process_group(backend='nccl', init_method='env://', timeout=timeout) + dist.init_process_group(backend="nccl", init_method="env://", timeout=timeout) if args.local_wandb: os.environ["WANDB_MODE"] = "offline" @@ -66,8 +71,10 @@ def world_info_from_env(): world_size = dist.get_world_size() local_rank = int(_local_rank) if args.multinodes else rank - num_devices= torch.cuda.device_count() - logging.info(f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}") + num_devices = torch.cuda.device_count() + logging.info( + f"{local_rank=}, {rank=}, {world_size=}, {type(local_rank)=}, {type(rank)=}, {type(world_size)=}" + ) for device_idx in range(num_devices): device_name = torch.cuda.get_device_name(device_idx) logging.info(f"Device {device_idx}: {device_name}") @@ -76,7 +83,12 @@ def world_info_from_env(): if rank == 0: user_dir = os.path.expanduser("~") codebase_name = "VoiceStar" - now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - copy_codebase(os.path.join(user_dir, codebase_name), os.path.join(exp_dir, f"{codebase_name}_{now}"), max_size_mb=5, gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore")) + now = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + copy_codebase( + os.path.join(user_dir, codebase_name), + os.path.join(exp_dir, f"{codebase_name}_{now}"), + max_size_mb=5, + gitignore_path=os.path.join(user_dir, codebase_name, ".gitignore"), + ) my_trainer = trainer.Trainer(args, world_size, rank, local_rank) - my_trainer.train() \ No newline at end of file + my_trainer.train() diff --git a/voicestar/__init__.py b/voicestar/__init__.py new file mode 100644 index 0000000..f9ca486 --- /dev/null +++ b/voicestar/__init__.py @@ -0,0 +1,218 @@ +""" +VoiceStar: Robust, 
Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + +__version__ = "0.1.0" + + +class VoiceStar: + """ + VoiceStar API - Easy-to-use Python API for VoiceStar model. + + This class provides an easy-to-use interface to the VoiceStar TTS model, + allowing you to generate speech in the voice of a reference audio sample. + + Example: + ```python + from voicestar import VoiceStar + + # Initialize the model (downloads from HF Hub if needed) + tts = VoiceStar(model_name="VoiceStar_840M_30s") + + # Generate speech in the voice of the reference audio + audio = tts.generate( + reference_speech="path/to/reference.wav", + text="This is the text I want to synthesize.", + target_duration=5.0 # Optional: specify desired duration in seconds + ) + + # Save the generated audio + audio.save("output.wav") + ``` + """ + + def __init__( + self, + model_name="VoiceStar_840M_30s", + device=None, + top_k=10, + top_p=1.0, + temperature=1.0, + repeat_prompt=1, + ): + """ + Initialize the VoiceStar TTS model. + + Args: + model_name (str): Model name to use. Options: + - "VoiceStar_840M_30s" (default): 840M parameter model that can generate up to 30s + - "VoiceStar_840M_40s": 840M parameter model that can generate up to 40s + device (str, optional): Device to run inference on. If None, will use CUDA if available, + then MPS (for Apple Silicon), then CPU. + top_k (int): Top-k sampling parameter. Higher values = more diversity. + top_p (float): Top-p sampling parameter (nucleus sampling). + temperature (float): Sampling temperature. Higher values = more diversity. + repeat_prompt (int): Number of times to repeat the prompt to improve speaker similarity. + """ + import os + import torch + from huggingface_hub import hf_hub_download + from argparse import Namespace + import voicestar.voicestar as voice_star + from data.tokenizer import AudioTokenizer, TextTokenizer + from voicestar.api import VoiceStarAPI + + # Set device + if device is None: + self.device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + else: + self.device = device + + # Download and load model + torch.serialization.add_safe_globals([Namespace]) + ckpt_fn = hf_hub_download( + repo_id="pyp1/VoiceStar", filename=f"{model_name}.pth" + ) + + bundle = torch.load(ckpt_fn, map_location=self.device, weights_only=True) + self.args = bundle["args"] + self.phn2num = bundle["phn2num"] + self.model = voice_star.VoiceStarModel(self.args) + self.model.load_state_dict(bundle["model"]) + self.model.to(self.device) + self.model.eval() + + # Load tokenizers + if self.args.n_codebooks == 4: + signature = hf_hub_download( + repo_id="pyp1/VoiceCraft", filename="encodec_4cb2048_giga.th" + ) + elif self.args.n_codebooks == 8: + signature = hf_hub_download( + repo_id="pyp1/VoiceCraft", filename="encodec_8cb1024_giga.th" + ) + else: + raise ValueError(f"Invalid number of codebooks: {self.args.n_codebooks}") + + self.audio_tokenizer = AudioTokenizer(signature=signature) + self.text_tokenizer = TextTokenizer(backend="espeak") + + # Create API instance + self.api = VoiceStarAPI( + model=self.model, + model_args=self.args, + phn2num=self.phn2num, + text_tokenizer=self.text_tokenizer, + audio_tokenizer=self.audio_tokenizer, + device=self.device, + top_k=top_k, + top_p=top_p, + temperature=temperature, + ) + + # Store parameters + self.repeat_prompt = repeat_prompt + + def generate( + self, + reference_speech, + text, + 
reference_text=None, + target_duration=None, + output_path=None, + ): + """ + Generate speech in the voice of the reference audio. + + Args: + reference_speech (str): Path to reference speech audio file + text (str): Text to synthesize + reference_text (str, optional): Reference text transcript. If None, will use + Whisper to automatically transcribe the reference speech. + target_duration (float, optional): Target duration in seconds. If None, + will estimate based on reference speech and text length. + output_path (str, optional): If provided, saves the generated audio to this path + + Returns: + AudioSegment: The generated audio (can be saved with .save() method) + """ + import os + import torch + import torchaudio + import whisper + from voicestar.utils import estimate_duration + + # Transcribe reference speech if needed + if reference_text is None: + print( + "[Info] No reference_text provided, transcribing reference_speech with Whisper." + ) + wh_model = whisper.load_model("large-v3-turbo") + result = wh_model.transcribe(reference_speech) + reference_text = result["text"] + print(f"[Info] Whisper transcribed text: {reference_text}") + + # Estimate duration if not provided + if target_duration is None: + target_duration = estimate_duration(reference_speech, text) + print(f"[Info] Estimated target duration: {target_duration:.2f} seconds") + + # Get audio info for prompt_end_frame + info = torchaudio.info(reference_speech) + prompt_end_frame = int(100 * info.sample_rate) # 100 seconds max + + # Set delay pattern increment + delay_pattern_increment = self.args.n_codebooks + 1 + + # Generate audio + _, generated_audio = self.api.generate( + audio_fn=reference_speech, + target_text=text, + prompt_end_frame=prompt_end_frame, + target_generation_length=target_duration, + delay_pattern_increment=delay_pattern_increment, + prefix_transcript=reference_text, + repeat_prompt=self.repeat_prompt, + ) + + # Save if output_path provided + if output_path: + os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) + torchaudio.save(output_path, generated_audio[0].cpu(), 16000) + print(f"[Success] Generated audio saved to {output_path}") + + # Return audio segment for further manipulation + from dataclasses import dataclass + + @dataclass + class AudioSegment: + waveform: torch.Tensor + sample_rate: int = 16000 + + def save(self, path): + """Save audio to file""" + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + torchaudio.save(path, self.waveform.cpu(), self.sample_rate) + return path + + def play(self): + """Play audio (if in notebook environment)""" + try: + from IPython.display import Audio, display + + display(Audio(self.waveform.cpu().numpy().T, rate=self.sample_rate)) + except ImportError: + print( + "Audio playback requires IPython. Use .save() method instead." 
+ ) + + return AudioSegment(generated_audio[0].cpu()) diff --git a/voicestar/api.py b/voicestar/api.py new file mode 100644 index 0000000..9ca7067 --- /dev/null +++ b/voicestar/api.py @@ -0,0 +1,399 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + +import argparse, pickle +import logging +import os, random +import numpy as np +import torch +import torchaudio + +from data.tokenizer import AudioTokenizer, TextTokenizer, tokenize_audio, tokenize_text +import argparse, time, tqdm + + +class VoiceStarAPI: + def __init__( + self, + model, + model_args, + phn2num, + text_tokenizer, + audio_tokenizer, + device="cuda", + codec_sr=50, + top_k=0, + top_p=0.8, + min_p=0.0, + temperature=1.0, + stop_repetition=-1, + kvcache=1, + silence_tokens="[1388,1898,131]", + quiet=False, + ): + """ + Initialize the VoiceStar API with model and configuration. + + Args: + model: The VoiceStar model + model_args: Model arguments + phn2num: Phoneme to number mapping + text_tokenizer: Text tokenizer + audio_tokenizer: Audio tokenizer + device: Device to run inference on + codec_sr: Codec sample rate + top_k, top_p, min_p, temperature: Sampling parameters + stop_repetition: Stop generation when token repeats this many times + kvcache: Whether to use KV cache for faster inference + silence_tokens: Tokens representing silence + quiet: Whether to suppress logging + """ + self.model = model + self.model_args = model_args + self.phn2num = phn2num + self.text_tokenizer = text_tokenizer + self.audio_tokenizer = audio_tokenizer + self.device = device + self.quiet = quiet + + # Default decode config + self.decode_config = { + "codec_sr": codec_sr, + "top_k": top_k, + "top_p": top_p, + "min_p": min_p, + "temperature": temperature, + "stop_repetition": stop_repetition, + "kvcache": kvcache, + "silence_tokens": silence_tokens, + "sample_batch_size": 1, + } + + @torch.no_grad() + def generate( + self, + audio_fn, + target_text, + prompt_end_frame, + target_generation_length, + delay_pattern_increment, + prefix_transcript=None, + repeat_prompt=0, + decode_config=None, + multi_trial=[], + ): + """ + Generate speech for the given text using the reference audio. 
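+
+        Illustrative call on an already-constructed instance (all values below are
+        placeholders; `sample_rate` and `model_args` stand for the reference audio's
+        sample rate and the loaded model args):
+
+            concat_sample, gen_sample = api.generate(
+                audio_fn="ref.wav",
+                target_text="Hello world.",
+                prompt_end_frame=int(100 * sample_rate),
+                target_generation_length=3.0,
+                delay_pattern_increment=model_args.n_codebooks + 1,
+                prefix_transcript="transcript of ref.wav",
+            )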
+ + Args: + audio_fn: Path to reference audio file + target_text: Text to synthesize + prompt_end_frame: Number of frames to use from reference audio + target_generation_length: Target length of generated audio in seconds + delay_pattern_increment: Delay pattern increment + prefix_transcript: Optional transcript of the reference audio + repeat_prompt: Number of times to repeat the prompt (or "max") + decode_config: Optional custom decoding configuration + multi_trial: List for multi-trial inference (usually empty) + + Returns: + tuple: (concatenated_sample, generated_sample) + """ + # Use custom decode config if provided, otherwise use default + if decode_config is None: + decode_config = self.decode_config + + # encode audio + encoded_frames = tokenize_audio( + self.audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame + ) + single_encoded_frames = encoded_frames + + if isinstance(repeat_prompt, int) and repeat_prompt > 0: + cur_repeat_prompt = repeat_prompt + while cur_repeat_prompt > 0: + encoded_frames = torch.cat( + [encoded_frames, single_encoded_frames], dim=2 + ) + cur_repeat_prompt -= 1 + elif isinstance(repeat_prompt, str) and repeat_prompt.lower() == "max": + repeat_prompt = 0 + while ( + encoded_frames.shape[2] + + decode_config["codec_sr"] * target_generation_length + + delay_pattern_increment + + single_encoded_frames.shape[2] + < self.model_args.audio_max_length * decode_config["codec_sr"] + ): + encoded_frames = torch.cat( + [encoded_frames, single_encoded_frames], dim=2 + ) + repeat_prompt += 1 + if getattr(self.model_args, "y_sep_token", None) != None: + encoded_frames = torch.cat( + [ + encoded_frames, + torch.LongTensor( + [self.model_args.y_sep_token] * self.model_args.n_codebooks + ) + .unsqueeze(0) + .unsqueeze(2) + .to(encoded_frames.device), + ], + dim=2, + ) + original_audio = encoded_frames.transpose(2, 1) # [1,T,K] + assert ( + original_audio.ndim == 3 + and original_audio.shape[0] == 1 + and original_audio.shape[2] == self.model_args.n_codebooks + ), original_audio.shape + + # phonemize + if isinstance(target_text, list): + text_tokens = [ + self.phn2num[phn] for phn in target_text if phn in self.phn2num + ] + else: + text_tokens = [ + self.phn2num[phn] + for phn in tokenize_text(self.text_tokenizer, text=target_text.strip()) + if phn in self.phn2num + ] + if getattr(self.model_args, "x_sep_token", None) != None: + assert ( + prefix_transcript != None + ), "prefix_transcript must be provided if x_sep_token is not None" + if prefix_transcript is not None: + if isinstance(prefix_transcript, list): + prefix_tokens = [ + self.phn2num[phn] + for phn in prefix_transcript + if phn in self.phn2num + ] + else: + prefix_tokens = [ + self.phn2num[phn] + for phn in tokenize_text( + self.text_tokenizer, text=prefix_transcript.strip() + ) + if phn in self.phn2num + ] + single_prefix_tokens = prefix_tokens + repeat_prompt_count = repeat_prompt + while repeat_prompt_count > 0: + prefix_tokens = prefix_tokens + single_prefix_tokens + repeat_prompt_count -= 1 + if getattr(self.model_args, "x_sep_token", None) != None: + text_tokens = ( + prefix_tokens + + [getattr(self.model_args, "x_sep_token", None)] + + text_tokens + ) + else: + text_tokens = prefix_tokens + text_tokens + if getattr(self.model_args, "add_eos_to_text", 0) != 0: + text_tokens.append(self.model_args.add_eos_to_text) + if getattr(self.model_args, "add_bos_to_text", 0) != 0: + text_tokens = [self.model_args.add_bos_to_text] + text_tokens + text_tokens = torch.LongTensor(text_tokens).unsqueeze(0) + 
text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]]) + + if not self.quiet: + logging.info( + f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec." + ) + + if getattr(self.model_args, "parallel_pattern", 0) != 0: + tgt_y_lens = torch.LongTensor( + [ + int( + original_audio.shape[1] + + decode_config["codec_sr"] * target_generation_length + + 2 + ) + ] + ) # parallel pattern, therefore only add the empty_token (i.e. the sos token) and eos (i.e. 2 more tokens). Note that the delayed pattern between, both sos and eos is counted (sos is counted in the n_codebooks, eos is counted in the 1) + else: + tgt_y_lens = torch.LongTensor( + [ + int( + original_audio.shape[1] + + decode_config["codec_sr"] * target_generation_length + + delay_pattern_increment + ) + ] + ) # delay pattern increment has accounted for the added eos + + # forward + assert decode_config["sample_batch_size"] <= 1 + stime = time.time() + assert multi_trial == [] + if not self.quiet: + logging.info(f"running inference with batch size 1") + concat_frames, gen_frames = self.model.inference_tts( + text_tokens.to(self.device), + text_tokens_lens.to(self.device), + original_audio[..., : self.model_args.n_codebooks].to( + self.device + ), # [1,T,8] + tgt_y_lens=tgt_y_lens.to(self.device), + top_k=decode_config["top_k"], + top_p=decode_config["top_p"], + min_p=decode_config["min_p"], + temperature=decode_config["temperature"], + stop_repetition=decode_config["stop_repetition"], + kvcache=decode_config["kvcache"], + silence_tokens=( + eval(decode_config["silence_tokens"]) + if type(decode_config["silence_tokens"]) == str + else decode_config["silence_tokens"] + ), + ) # output is [1,K,T] + if not self.quiet: + logging.info( + f"inference on one sample take: {time.time() - stime:.4f} sec." + ) + logging.info( + f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec." 
+ ) + + if getattr(self.model_args, "y_sep_token", None) != None: + concat_frames = torch.cat( + [ + concat_frames[:, :, : original_audio.shape[1] - 1], + concat_frames[:, :, original_audio.shape[1] :], + ], + dim=2, + ) + # Handle MPS device compatibility + if self.device == "mps": + # Move tensors to CPU before decoding to avoid MPS placeholder storage error + concat_frames = concat_frames.cpu() + gen_frames = gen_frames.cpu() + concat_sample = self.audio_tokenizer.decode(concat_frames) # [1,8,T] + gen_sample = self.audio_tokenizer.decode(gen_frames) + + # Empty cuda cache between runs + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return concat_sample, gen_sample + + +# this script only works for the musicgen architecture +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file") + parser.add_argument("--audio_root", type=str, default="path/to/audio_folder") + parser.add_argument("--exp_dir", type=str, default="path/to/model_folder") + parser.add_argument("--seed", type=int, default=1) + parser.add_argument( + "--codec_audio_sr", + type=int, + default=16000, + help="the sample rate of audio that the codec is trained for", + ) + parser.add_argument( + "--codec_sr", type=int, default=50, help="the sample rate of the codec codes" + ) + parser.add_argument("--top_k", type=int, default=0, help="sampling param") + parser.add_argument("--top_p", type=float, default=0.8, help="sampling param") + parser.add_argument("--temperature", type=float, default=1.0, help="sampling param") + parser.add_argument("--output_dir", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument( + "--signature", type=str, default=None, help="path to the encodec model" + ) + parser.add_argument("--crop_concat", type=int, default=0) + parser.add_argument( + "--stop_repetition", + type=int, + default=-1, + help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it", + ) + parser.add_argument( + "--kvcache", + type=int, + default=1, + help="if true, use kv cache, which is 4-8x faster than without", + ) + parser.add_argument( + "--sample_batch_size", + type=int, + default=1, + help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation", + ) + parser.add_argument( + "--silence_tokens", + type=str, + default="[1388,1898,131]", + help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default", + ) + return parser.parse_args() + + +@torch.no_grad() +def inference_one_sample( + model, + model_args, + phn2num, + text_tokenizer, + audio_tokenizer, + audio_fn, + target_text, + device, + decode_config, + prompt_end_frame, + target_generation_length, + delay_pattern_increment, + prefix_transcript=None, + quiet=False, + repeat_prompt=0, + multi_trial=[], +): + """ + Backward compatibility function that uses the VoiceStarAPI class. + + This function has the same signature as the original inference_one_sample + but internally uses the new VoiceStarAPI class. 
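+
+    Sketch of a typical call, mirroring the CLI defaults (all values are
+    placeholders, not a prescribed configuration):
+
+        info = torchaudio.info("ref.wav")
+        decode_config = {
+            "top_k": 10, "top_p": 1.0, "min_p": 1.0, "temperature": 1.0,
+            "stop_repetition": 3, "kvcache": 1, "codec_audio_sr": 16000,
+            "codec_sr": 50, "silence_tokens": [], "sample_batch_size": 1,
+        }
+        concat_audio, gen_audio = inference_one_sample(
+            model, args, phn2num, text_tokenizer, audio_tokenizer,
+            "ref.wav", "Hello world.", device, decode_config,
+            prompt_end_frame=int(100 * info.sample_rate),
+            target_generation_length=8.0,
+            delay_pattern_increment=args.n_codebooks + 1,
+            prefix_transcript="transcript of ref.wav",
+        )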
+ """ + # Create API instance + api = VoiceStarAPI( + model=model, + model_args=model_args, + phn2num=phn2num, + text_tokenizer=text_tokenizer, + audio_tokenizer=audio_tokenizer, + device=device, + codec_sr=decode_config.get("codec_sr", 50), + top_k=decode_config.get("top_k", 0), + top_p=decode_config.get("top_p", 0.8), + min_p=decode_config.get("min_p", 0.0), + temperature=decode_config.get("temperature", 1.0), + stop_repetition=decode_config.get("stop_repetition", -1), + kvcache=decode_config.get("kvcache", 1), + silence_tokens=decode_config.get("silence_tokens", "[1388,1898,131]"), + quiet=quiet, + ) + + # Call the generate method + return api.generate( + audio_fn=audio_fn, + target_text=target_text, + prompt_end_frame=prompt_end_frame, + target_generation_length=target_generation_length, + delay_pattern_increment=delay_pattern_increment, + prefix_transcript=prefix_transcript, + repeat_prompt=repeat_prompt, + decode_config=decode_config, + multi_trial=multi_trial, + ) diff --git a/voicestar/cli.py b/voicestar/cli.py new file mode 100644 index 0000000..ab80402 --- /dev/null +++ b/voicestar/cli.py @@ -0,0 +1,236 @@ +import os +import torch +import torchaudio +import numpy as np +import random +import whisper +import click +from argparse import Namespace + +from data.tokenizer import ( + AudioTokenizer, + TextTokenizer, +) + +import voicestar.voicestar as voice_star +from voicestar.api import inference_one_sample +from huggingface_hub import hf_hub_download +from transformers import pipeline + +from voicestar.utils import seed_everything, estimate_duration + +############################################################ +# Main Inference Function +############################################################ + + +@click.command() +@click.option( + "--reference-speech", + default="./demo/5895_34622_000026_000002.wav", + help="Path to reference speech audio file", +) +@click.option( + "--target-text", + default="I cannot believe that the same model can also do text to speech synthesis too! And you know what? 
this audio is 8 seconds long.", + help="Text to synthesize", +) +@click.option( + "--model-name", + default="VoiceStar_840M_30s", + help="Model name (VoiceStar_840M_30s or VoiceStar_840M_40s)", +) +@click.option( + "--reference-text", + default=None, + help="Reference text (if None, will use Whisper to transcribe)", +) +@click.option( + "--target-duration", + default=None, + type=float, + help="Target duration in seconds (if None, will estimate)", +) +@click.option( + "--codec-audio-sr", default=16000, help="Codec audio sample rate (do not change)" +) +@click.option("--codec-sr", default=50, help="Codec sample rate (do not change)") +@click.option( + "--top-k", default=10, help="Top-k sampling parameter (try 10, 20, 30, 40)" +) +@click.option("--top-p", default=1.0, help="Top-p sampling parameter (do not change)") +@click.option("--min-p", default=1.0, help="Min-p sampling parameter (do not change)") +@click.option("--temperature", default=1.0, help="Sampling temperature") +@click.option("--kvcache", default=1, help="Use KV cache (set to 0 if OOM)") +@click.option( + "--repeat-prompt", default=1, help="Repeat prompt to improve speaker similarity" +) +@click.option( + "--stop-repetition", default=3, help="Stop repetition parameter (will not use it)" +) +@click.option( + "--sample-batch-size", default=1, help="Sample batch size (do not change)" +) +@click.option("--seed", default=1, help="Random seed") +@click.option("--output-dir", default="./generated_tts", help="Output directory") +@click.option("--cut-off-sec", default=100, help="Cut-off seconds (do not adjust)") +def run_inference( + reference_speech, + target_text, + model_name, + reference_text, + target_duration, + codec_audio_sr, + codec_sr, + top_k, + top_p, + min_p, + temperature, + kvcache, + repeat_prompt, + stop_repetition, + sample_batch_size, + seed, + output_dir, + cut_off_sec, +): + """ + VoiceStar TTS inference CLI. + + Example: + voicestar --reference-speech "./demo/5895_34622_000026_000002.wav" \ + --target-text "I cannot believe ... this audio is 10 seconds long." \ + --reference-text "Optional text to use as prefix" \ + --target-duration 10.0 + """ + # Default values for parameters not exposed in click options + silence_tokens = None + multi_trial = None + + # Seed everything + seed_everything(seed) + + # Load model, phn2num, and args + torch.serialization.add_safe_globals([Namespace]) + device = ( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) # MPS support + ckpt_fn = hf_hub_download(repo_id="pyp1/VoiceStar", filename=f"{model_name}.pth") + + bundle = torch.load(ckpt_fn, map_location=device, weights_only=True) + args = bundle["args"] + phn2num = bundle["phn2num"] + model = voice_star.VoiceStarModel(args) + model.load_state_dict(bundle["model"]) + model.to(device) + model.eval() + + # If reference_text not provided, use whisper large-v3-turbo + if reference_text is None: + print( + "[Info] No reference_text provided, transcribing reference_speech with Whisper." 
+ ) + wh_model = whisper.load_model("large-v3-turbo") + result = wh_model.transcribe(reference_speech) + prefix_transcript = result["text"] + print(f"[Info] Whisper transcribed text: {prefix_transcript}") + else: + prefix_transcript = reference_text + + # If target_duration not provided, estimate from reference speech + target_text + if target_duration is None: + target_generation_length = estimate_duration(reference_speech, target_text) + print( + f"[Info] target_duration not provided, estimated as {target_generation_length:.2f} seconds. If not desired, please provide a target_duration." + ) + else: + target_generation_length = float(target_duration) + + # signature from snippet + if args.n_codebooks == 4: + # signature = "./pretrained/encodec_6f79c6a8.th" + signature = hf_hub_download( + repo_id="pyp1/VoiceCraft", filename="encodec_4cb2048_giga.th" + ) # not sure if this is the right signature + elif args.n_codebooks == 8: + signature = hf_hub_download( + repo_id="pyp1/VoiceCraft", filename="encodec_8cb1024_giga.th" + ) + else: + # fallback, just use the 6-f79c6a8 + raise ValueError(f"Invalid number of codebooks: {args.n_codebooks}") + # not sure where to download 6-f79c6a8 from + # signature = "./pretrained/encodec_6f79c6a8.th" + + if silence_tokens is None: + # default from snippet + silence_tokens = [] + + if multi_trial is None: + # default from snippet + multi_trial = [] + + delay_pattern_increment = args.n_codebooks + 1 # from snippet + + # We can compute prompt_end_frame if we want, from snippet + info = torchaudio.info(reference_speech) + prompt_end_frame = int(cut_off_sec * info.sample_rate) + + # Prepare tokenizers + audio_tokenizer = AudioTokenizer(signature=signature) + text_tokenizer = TextTokenizer(backend="espeak") + + # decode_config from snippet + decode_config = { + "top_k": top_k, + "top_p": top_p, + "min_p": min_p, + "temperature": temperature, + "stop_repetition": stop_repetition, + "kvcache": kvcache, + "codec_audio_sr": codec_audio_sr, + "codec_sr": codec_sr, + "silence_tokens": silence_tokens, + "sample_batch_size": sample_batch_size, + } + + # Run inference + print("[Info] Running TTS inference...") + concated_audio, gen_audio = inference_one_sample( + model, + args, + phn2num, + text_tokenizer, + audio_tokenizer, + reference_speech, + target_text, + device, + decode_config, + prompt_end_frame=prompt_end_frame, + target_generation_length=target_generation_length, + delay_pattern_increment=delay_pattern_increment, + prefix_transcript=prefix_transcript, + multi_trial=multi_trial, + repeat_prompt=repeat_prompt, + ) + + # The model returns a list of waveforms, pick the first + concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu() + + # Save the audio (just the generated portion, as the snippet does) + os.makedirs(output_dir, exist_ok=True) + out_filename = "generated.wav" + out_path = os.path.join(output_dir, out_filename) + torchaudio.save(out_path, gen_audio, codec_audio_sr) + + print(f"[Success] Generated audio saved to {out_path}") + + +def main(): + run_inference() + + +if __name__ == "__main__": + main() diff --git a/models/modules/__init__.py b/voicestar/data/__init__.py similarity index 100% rename from models/modules/__init__.py rename to voicestar/data/__init__.py diff --git a/voicestar/data/encodec.py b/voicestar/data/encodec.py new file mode 100644 index 0000000..fc74ecc --- /dev/null +++ b/voicestar/data/encodec.py @@ -0,0 +1,1852 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Compression models or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +import logging +import math +from pathlib import Path +import typing as tp + +import numpy as np +import torch +from torch import nn +from torch import einsum +import torch.nn.functional as F +from torch.nn.utils import spectral_norm, weight_norm + +import logging +import warnings +from einops import rearrange, repeat +import omegaconf + +# import flashy + +CONV_NORMALIZATIONS = frozenset( + ["none", "weight_norm", "spectral_norm", "time_group_norm"] +) + + +def dict_from_config(cfg: omegaconf.DictConfig) -> dict: + """Convenience function to map an omegaconf configuration to a dictionary. + + Args: + cfg (omegaconf.DictConfig): Original configuration to map to dict. + Returns: + dict: Config as dictionary object. + """ + dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) + assert isinstance(dct, dict) + return dct + + +@dataclass +class QuantizedResult: + x: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class BaseQuantizer(nn.Module): + """Base class for quantizers.""" + + def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: + """ + Given input tensor x, returns first the quantized (or approximately quantized) + representation along with quantized codes, bandwidth, and any penalty term for the loss. + Finally, this returns a dict of metrics to update logging etc. + Frame rate must be passed so that the bandwidth is properly computed. + """ + raise NotImplementedError() + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth.""" + raise NotImplementedError() + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + raise NotImplementedError() + + @property + def total_codebooks(self): + """Total number of codebooks.""" + raise NotImplementedError() + + @property + def num_codebooks(self): + """Number of active codebooks.""" + raise NotImplementedError() + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks.""" + raise NotImplementedError() + + +class CompressionModel(ABC, nn.Module): + """Base API for all compression model that aim at being used as audio tokenizers + with a language model. + """ + + @abstractmethod + def forward(self, x: torch.Tensor) -> QuantizedResult: ... + + @abstractmethod + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """See `EncodecModel.encode`.""" + ... + + @abstractmethod + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """See `EncodecModel.decode`.""" + ... + + @abstractmethod + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + ... + + @property + @abstractmethod + def channels(self) -> int: ... + + @property + @abstractmethod + def frame_rate(self) -> float: ... + + @property + @abstractmethod + def sample_rate(self) -> int: ... + + @property + @abstractmethod + def cardinality(self) -> int: ... 
+ + @property + @abstractmethod + def num_codebooks(self) -> int: ... + + @property + @abstractmethod + def total_codebooks(self) -> int: ... + + @abstractmethod + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + ... + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none"): + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return weight_norm(module) + elif norm == "spectral_norm": + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +): + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d( + x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0 +): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d( + x: torch.Tensor, + paddings: tp.Tuple[int, int], + mode: str = "constant", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. 
Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class StreamableConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn( + "StreamableConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." 
+ ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + kernel_size = ( + kernel_size - 1 + ) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + return self.conv(x) + + +class StreamableConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +class StreamableLSTM(nn.Module): + """LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + + Args: + dim (int): Dimension of the input/output. + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. 
+ activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection. + """ + + def __init__( + self, + dim: int, + kernel_sizes: tp.List[int] = [3, 1], + dilations: tp.List[int] = [1, 1], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + StreamableConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = StreamableConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order. We use the decoder order as some models may only employ the decoder. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the encoder, it corresponds to the N first blocks. 
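+
+    Note: the total downsampling factor (hop length) is the product of ratios; with
+    the default [8, 5, 4, 2] this is 320, so 16 kHz audio is encoded at
+    16000 / 320 = 50 latent frames per second (the codec_sr value used elsewhere in
+    this repository).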
+ """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." + ) + + act = getattr(nn, activation) + mult = 1 + model: tp.List[nn.Module] = [ + StreamableConv1d( + channels, + mult * n_filters, + kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=block_norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + act(**activation_params), + StreamableConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + StreamableConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function. + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function. + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. 
+ kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple. + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm. + For the decoder, it corresponds to the N last blocks. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 3, + ratios: tp.List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: tp.Optional[str] = None, + final_activation_params: tp.Optional[dict] = None, + norm: str = "none", + norm_params: tp.Dict[str, tp.Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = True, + compress: int = 2, + lstm: int = 0, + disable_norm_outer_blocks: int = 0, + trim_right_ratio: float = 1.0, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks + self.disable_norm_outer_blocks = disable_norm_outer_blocks + assert ( + self.disable_norm_outer_blocks >= 0 + and self.disable_norm_outer_blocks <= self.n_blocks + ), ( + "Number of blocks for which to disable norm is invalid." + "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
+ ) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: tp.List[nn.Module] = [ + StreamableConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=( + "none" if self.disable_norm_outer_blocks == self.n_blocks else norm + ), + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [StreamableLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + block_norm = ( + "none" + if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) + else norm + ) + # Add upsampling layers + model += [ + act(**activation_params), + StreamableConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=block_norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=block_norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + StreamableConv1d( + n_filters, + channels, + last_kernel_size, + norm="none" if self.disable_norm_outer_blocks >= 1 else norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def exists(val: tp.Optional[tp.Any]) -> bool: + return val is not None + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if exists(val) else d + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +def orthogonal_loss_fn(t): + # eq 
(2) from https://arxiv.org/abs/2112.00384 + n = t.shape[0] + normed_codes = l2norm(t) + identity = torch.eye(n, device=t.device) + cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes) + return ((cosine_sim - identity) ** 2).sum() / (n**2) + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.8, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + flashy.distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + flashy.distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) 
d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + raise NotImplementedError() + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): + channels_last (bool): Channels are the last dimension in the input tensors. + commitment_weight (float): Weight for commitment loss. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider + for orthogonal regularization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.8, + epsilon: float = 1e-5, + kmeans_init: bool = False, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + channels_last: bool = False, + commitment_weight: float = 1.0, + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + self.channels_last = channels_last + + @property + def codebook(self): + return self._codebook.embed + + @property + def inited(self): + return self._codebook.inited + + def _preprocess(self, x): + if not self.channels_last: + x = rearrange(x, "b d n -> b n d") + return x + + def _postprocess(self, quantize): + if not self.channels_last: + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def encode(self, x): + x = self._preprocess(x) + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + return quantize + + def forward(self, x): + device = x.device + x = self._preprocess(x) + + x = self.project_in(x) + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + if self.orthogonal_reg_weight > 0: + codebook = self.codebook + + if self.orthogonal_reg_active_codes_only: + # only calculate orthogonal loss for the activated codes for this batch + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[unique_code_ids] + + num_codes = codebook.shape[0] + if ( + exists(self.orthogonal_reg_max_codes) + and num_codes > self.orthogonal_reg_max_codes + ): + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] + codebook = codebook[rand_ids] + + orthogonal_reg_loss = orthogonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + quantize = self.project_out(quantize) + quantize = self._postprocess(quantize) + + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + codebook_size = kwargs.pop("codebook_size", None) + if codebook_size is None: + raise ValueError("codebook_size must be provided in kwargs") + if type(codebook_size) != list: + codebook_size = [codebook_size] * num_quantizers + self.layers = nn.ModuleList( + [ + VectorQuantization(codebook_size=cur_codebook_size, **kwargs) + for _, cur_codebook_size in zip(range(num_quantizers), codebook_size) + ] + ) + + # self.layers = nn.ModuleList( + # [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + # ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for i, layer in enumerate(self.layers[:n_q]): + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + # the original code is below + # since quantize has the gradient of residual, according to line 321 + # quantize = x + (quantize - x).detach() + # the code below will make commitment loss to be 0 for all codebooks except for codebook1 + # https://github.com/facebookresearch/encodec/issues/25 + # therefore we change it + + residual = residual - quantized + # residual = residual - quantized.detach() + # since commitment loss is averaged, the scale of the loss won't get change (not as said in the issue above) + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out + + +class ResidualVectorQuantizer(BaseQuantizer): + """Residual Vector Quantizer. + + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + q_dropout (bool): Random quantizer drop out at train time. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + orthogonal_reg_weight (float): Orthogonal regularization weights. + orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. + orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. + for orthogonal regularization. 
+ """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + q_dropout: bool = False, + bins: tp.Union[int, tp.List[int]] = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 10, + threshold_ema_dead_code: int = 2, + orthogonal_reg_weight: float = 0.0, + orthogonal_reg_active_codes_only: bool = False, + orthogonal_reg_max_codes: tp.Optional[int] = None, + ): + super().__init__() + self.max_n_q = n_q + self.n_q = n_q + self.q_dropout = q_dropout + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + orthogonal_reg_weight=self.orthogonal_reg_weight, + orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, + orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, + channels_last=False, + ) + + def forward(self, x: torch.Tensor, frame_rate: int): + n_q = self.n_q + if self.training and self.q_dropout: + n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) + if type(self.bins) == list: + bins = self.bins + else: + bins = [self.bins] * self.n_q + bw_per_q = [math.log2(bin) * frame_rate / 1000 for bin in bins] + bw = torch.tensor(sum(bw_per_q)).to(x) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified frame rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.n_q + codes = self.vq.encode(x, n_q=n_q) + codes = codes.transpose(0, 1) + # codes is [B, K, T], with T frames, K nb of codebooks. + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. + codes = codes.transpose(0, 1) + quantized = self.vq.decode(codes) + return quantized + + @property + def total_codebooks(self): + return self.max_n_q + + @property + def num_codebooks(self): + return self.n_q + + def set_num_codebooks(self, n: int): + assert n > 0 and n <= self.max_n_q + self.n_q = n + + +class DummyQuantizer(BaseQuantizer): + """Fake quantizer that actually does not perform any quantization.""" + + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, frame_rate: int): + q = x.unsqueeze(1) + return QuantizedResult( + x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x) + ) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. 
+ """ + return x.unsqueeze(1) + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation. + In the case of the DummyQuantizer, the codes are actually identical + to the input and resulting quantized representation as no quantization is done. + """ + return codes.squeeze(1) + + @property + def total_codebooks(self): + """Total number of codebooks.""" + return 1 + + @property + def num_codebooks(self): + """Total number of codebooks.""" + return self.total_codebooks + + def set_num_codebooks(self, n: int): + """Set the number of active codebooks.""" + raise AttributeError( + "Cannot override the number of codebooks for the dummy quantizer" + ) + + +class EncodecModel(CompressionModel): + """Encodec model operating on the raw waveform. + + Args: + encoder (nn.Module): Encoder network. + decoder (nn.Module): Decoder network. + quantizer (BaseQuantizer): Quantizer network. + frame_rate (int): Frame rate for the latent representation. + sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. + causal (bool): Whether to use a causal version of the model. + renormalize (bool): Whether to renormalize the audio before running the model. + """ + + # we need assignment to override the property in the abstract class, + # I couldn't find a better way... + frame_rate: float = 0 + sample_rate: int = 0 + channels: int = 0 + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.quantizer = quantizer + self.frame_rate = frame_rate + self.sample_rate = sample_rate + self.channels = channels + self.renormalize = renormalize + self.causal = causal + if self.causal: + # we force disabling here to avoid handling linear overlap of segments + # as supported in original EnCodec codebase. 
+ assert not self.renormalize, "Causal model does not support renormalize" + + @property + def total_codebooks(self): + """Total number of quantizer codebooks available.""" + return self.quantizer.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer.""" + return self.quantizer.num_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + self.quantizer.set_num_codebooks(n) + + @property + def cardinality(self): + """Cardinality of each codebook.""" + return self.quantizer.bins + + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + scale: tp.Optional[torch.Tensor] + if self.renormalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + return x, scale + + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: + if scale is not None: + assert self.renormalize + x = x * scale.view(-1, 1, 1) + return x + + def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: + if encode: + return self.encode(x) + else: + raise NotImplementedError("model forward and training is not supported.") + assert x.dim() == 3 + length = x.shape[-1] + x, scale = self.preprocess(x) + + emb = self.encoder(x) + q_res = self.quantizer(emb, self.frame_rate) + out = self.decoder(q_res.x) + + # remove extra padding added by the encoder and decoder + assert out.shape[-1] >= length, (out.shape[-1], length) + out = out[..., :length] + + q_res.x = self.postprocess(out, scale) + + return q_res + + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Encode the given input tensor to quantized representation along with scale parameter. + + Args: + x (torch.Tensor): Float tensor of shape [B, C, T] + + Returns: + codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: + codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. + scale a float tensor containing the scale for audio renormalizealization. + """ + assert x.dim() == 3 + x, scale = self.preprocess(x) + emb = self.encoder(x) + codes = self.quantizer.encode(emb) + return codes, scale + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """Decode the given codes to a reconstructed representation, using the scale to perform + audio denormalization if needed. + + Args: + codes (torch.Tensor): Int tensor of shape [B, K, T] + scale (torch.Tensor, optional): Float tensor containing the scale value. + + Returns: + out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. + """ + emb = self.decode_latent(codes) + out = self.decoder(emb) + out = self.postprocess(out, scale) + # out contains extra padding added by the encoder and decoder + return out + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + return self.quantizer.decode(codes) + + +class EncodecModel_encode_only(CompressionModel): + """Encodec model operating on the raw waveform. Encode only, so no decoder + + Args: + encoder (nn.Module): Encoder network. + quantizer (BaseQuantizer): Quantizer network. + frame_rate (int): Frame rate for the latent representation. + sample_rate (int): Audio sample rate. + channels (int): Number of audio channels. 
+ causal (bool): Whether to use a causal version of the model. + renormalize (bool): Whether to renormalize the audio before running the model. + """ + + # we need assignment to override the property in the abstract class, + # I couldn't find a better way... + frame_rate: float = 0 + sample_rate: int = 0 + channels: int = 0 + + def __init__( + self, + encoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: int, + sample_rate: int, + channels: int, + causal: bool = False, + renormalize: bool = False, + ): + super().__init__() + self.encoder = encoder + self.quantizer = quantizer + self.frame_rate = frame_rate + self.sample_rate = sample_rate + self.channels = channels + self.renormalize = renormalize + self.causal = causal + if self.causal: + # we force disabling here to avoid handling linear overlap of segments + # as supported in original EnCodec codebase. + assert not self.renormalize, "Causal model does not support renormalize" + + @property + def total_codebooks(self): + """Total number of quantizer codebooks available.""" + return self.quantizer.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer.""" + return self.quantizer.num_codebooks + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + self.quantizer.set_num_codebooks(n) + + @property + def cardinality(self): + """Cardinality of each codebook.""" + return self.quantizer.bins + + def preprocess( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + scale: tp.Optional[torch.Tensor] + if self.renormalize: + mono = x.mean(dim=1, keepdim=True) + volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() + scale = 1e-8 + volume + x = x / scale + scale = scale.view(-1, 1) + else: + scale = None + return x, scale + + def postprocess( + self, x: torch.Tensor, scale: tp.Optional[torch.Tensor] = None + ) -> torch.Tensor: + if scale is not None: + assert self.renormalize + x = x * scale.view(-1, 1, 1) + return x + + def forward(self, x: torch.Tensor, encode=False) -> QuantizedResult: + if encode: + return self.encode(x) + else: + raise NotImplementedError("model forward and training is not supported.") + assert x.dim() == 3 + length = x.shape[-1] + x, scale = self.preprocess(x) + + emb = self.encoder(x) + q_res = self.quantizer(emb, self.frame_rate) + out = self.decoder(q_res.x) + + # remove extra padding added by the encoder and decoder + assert out.shape[-1] >= length, (out.shape[-1], length) + out = out[..., :length] + + q_res.x = self.postprocess(out, scale) + + return q_res + + def encode( + self, x: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Encode the given input tensor to quantized representation along with scale parameter. + + Args: + x (torch.Tensor): Float tensor of shape [B, C, T] + + Returns: + codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of: + codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. + scale a float tensor containing the scale for audio renormalizealization. + """ + assert x.dim() == 3 + x, scale = self.preprocess(x) + emb = self.encoder(x) + codes = self.quantizer.encode(emb) + return codes, scale + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """Decode the given codes to a reconstructed representation, using the scale to perform + audio denormalization if needed. 
+ + Args: + codes (torch.Tensor): Int tensor of shape [B, K, T] + scale (torch.Tensor, optional): Float tensor containing the scale value. + + Returns: + out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. + """ + raise NotImplementedError("Decode is not supported for encode only model") + emb = self.decode_latent(codes) + out = self.decoder(emb) + out = self.postprocess(out, scale) + # out contains extra padding added by the encoder and decoder + return out + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + raise NotImplementedError("Decode is not supported for encode only model") + return self.quantizer.decode(codes) + + +def get_quantizer( + quantizer: str, cfg: omegaconf.DictConfig, dimension: int +) -> BaseQuantizer: + klass = {"no_quant": DummyQuantizer, "rvq": ResidualVectorQuantizer}[quantizer] + kwargs = dict_from_config(getattr(cfg, quantizer)) + if quantizer != "no_quant": + kwargs["dimension"] = dimension + return klass(**kwargs) + + +def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): + if encoder_name == "seanet": + kwargs = dict_from_config(getattr(cfg, "seanet")) + encoder_override_kwargs = kwargs.pop("encoder") + decoder_override_kwargs = kwargs.pop("decoder") + encoder_kwargs = {**kwargs, **encoder_override_kwargs} + decoder_kwargs = {**kwargs, **decoder_override_kwargs} + encoder = SEANetEncoder(**encoder_kwargs) + decoder = SEANetDecoder(**decoder_kwargs) + return encoder, decoder + else: + raise KeyError(f"Unexpected compression model {cfg.compression_model}") + + +def get_compression_model(ckpt_fn, encode_only=False, device="cpu") -> CompressionModel: + """Instantiate a compression model.""" + if device == None: + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + state = torch.load( + ckpt_fn, map_location="cpu", weights_only=False + ) # TODO: Convert to SafeTensors + cfg = state["xp.cfg"] + cfg.device = str(device) + weights = state["best_state"]["model"] + assert ( + cfg.compression_model == "encodec" + ), "Only Encodec model is supported for now." 
+ if encode_only: + all_keys = list(weights.keys()) + for key in all_keys: + if key.startswith("decoder"): + del weights[key] + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") + encoder, _ = get_encodec_autoencoder(encoder_name, cfg) + quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) + # deprecated params + kwargs.pop("renorm", None) + compression_model = EncodecModel_encode_only( + encoder, quantizer, frame_rate=frame_rate, renormalize=renormalize, **kwargs + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" + compression_model.load_state_dict(weights) + compression_model.eval() + return compression_model + + else: + kwargs = dict_from_config(getattr(cfg, "encodec")) + encoder_name = kwargs.pop("autoencoder") + quantizer_name = kwargs.pop("quantizer") + encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) + quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) + frame_rate = kwargs["sample_rate"] // encoder.hop_length + renormalize = kwargs.pop("renormalize", False) + # deprecated params + kwargs.pop("renorm", None) + compression_model = EncodecModel( + encoder, + decoder, + quantizer, + frame_rate=frame_rate, + renormalize=renormalize, + **kwargs, + ).to(cfg.device) + assert ( + compression_model.sample_rate == cfg.sample_rate + ), "Compression model sample rate should match" + compression_model.load_state_dict(weights) + compression_model.eval() + return compression_model + + +if __name__ == "__main__": + import torchaudio + + ckpt_fn = "/home/pyp/BoostedVoiceEditor/pretrained/encodec_6f79c6a8.th" + audio_in_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley.wav", + ] + audio_out_fns = [ + "/home/pyp/BoostedVoiceEditor/demo/pam_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/ray_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/84_121550_000074_000000_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/caribbean_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/bible_encodecTest.wav", + "/home/pyp/BoostedVoiceEditor/demo/miley_encodecTest.wav", + ] + device = torch.device( + "cuda" + if torch.cuda.is_available() + else "mps" if torch.backends.mps.is_available() else "cpu" + ) + model = get_compression_model(ckpt_fn, device=device) + + for audio_in_fn, audio_out_fn in zip(audio_in_fns, audio_out_fns): + audio_in, sr = torchaudio.load(audio_in_fn) + if sr != model.sample_rate: + audio_in = torchaudio.transforms.Resample(sr, model.sample_rate)(audio_in) + if audio_in.shape[0] == 2: + audio_in = audio_in.mean(dim=0, keepdim=True) + audio_in = audio_in.unsqueeze(0) + audio_in = audio_in.to(torch.float32).to(device) + codes = model.encode(audio_in)[0] + audio_out = model.decode(codes)[0].cpu() + torchaudio.save(audio_out_fn, audio_out, model.sample_rate) diff --git a/pretrained/.gitkeep b/voicestar/modules/__init__.py similarity index 100% rename from pretrained/.gitkeep rename to voicestar/modules/__init__.py diff --git a/models/modules/activation.py b/voicestar/modules/activation.py similarity index 83% 
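As a quick orientation on the compression model above: the frame rate and bandwidth reported by `ResidualVectorQuantizer.forward` follow from simple arithmetic over the SEANet ratios. A minimal sketch, assuming the default ratios `[8, 5, 4, 2]`, 1024-entry codebooks, 4 active quantizers, and a 16 kHz sample rate (the real values come from `cfg` inside the checkpoint and may differ):

```python
import math

# hop_length is the product of the SEANet downsampling ratios
ratios = [8, 5, 4, 2]
hop_length = math.prod(ratios)            # 320 samples per latent frame

sample_rate = 16_000                      # assumed; read cfg.sample_rate for the real value
frame_rate = sample_rate // hop_length    # 50 latent frames per second

# each codebook with `bins` entries costs log2(bins) bits per frame
bins, n_q = 1024, 4
kbps_per_codebook = math.log2(bins) * frame_rate / 1000    # 0.5 kbps
total_kbps = n_q * kbps_per_codebook                       # 2.0 kbps with 4 codebooks
print(frame_rate, kbps_per_codebook, total_kbps)
```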
rename from models/modules/activation.py rename to voicestar/modules/activation.py index 46d6074..de2fcfe 100644 --- a/models/modules/activation.py +++ b/voicestar/modules/activation.py @@ -1,4 +1,4 @@ -# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py +# From https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py (Apache 2.0 license) from typing import Optional, Tuple import torch @@ -11,19 +11,21 @@ import logging from typing import Callable, List, Optional, Tuple, Union from typing import TYPE_CHECKING + if TYPE_CHECKING: from torch.types import _dtype as DType else: # The JIT doesn't understand Union, nor torch.dtype here DType = int + def _canonical_mask( - mask: Optional[Tensor], - mask_name: str, - other_type: Optional[DType], - other_name: str, - target_type: DType, - check_other: bool = True, + mask: Optional[Tensor], + mask_name: str, + other_type: Optional[DType], + other_name: str, + target_type: DType, + check_other: bool = True, ) -> Optional[Tensor]: if mask is not None: @@ -31,7 +33,8 @@ def _canonical_mask( _mask_is_float = torch.is_floating_point(mask) if _mask_dtype != torch.bool and not _mask_is_float: raise AssertionError( - f"only bool and floating types of {mask_name} are supported") + f"only bool and floating types of {mask_name} are supported" + ) if check_other and other_type is not None: if _mask_dtype != other_type: warnings.warn( @@ -39,12 +42,12 @@ def _canonical_mask( "is deprecated. Use same type for both instead." ) if not _mask_is_float: - mask = ( - torch.zeros_like(mask, dtype=target_type) - .masked_fill_(mask, float("-inf")) + mask = torch.zeros_like(mask, dtype=target_type).masked_fill_( + mask, float("-inf") ) return mask + def _in_projection_packed( q: Tensor, k: Tensor, @@ -85,7 +88,13 @@ def _in_projection_packed( # self-attention proj = F.linear(q, w, b) # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk() - proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + proj = ( + proj.unflatten(-1, (3, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) return proj[0], proj[1], proj[2] else: # encoder-decoder attention @@ -97,7 +106,13 @@ def _in_projection_packed( q_proj = F.linear(q, w_q, b_q) kv_proj = F.linear(k, w_kv, b_kv) # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk() - kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous() + kv_proj = ( + kv_proj.unflatten(-1, (2, E)) + .unsqueeze(0) + .transpose(0, -2) + .squeeze(-2) + .contiguous() + ) return (q_proj, kv_proj[0], kv_proj[1]) else: w_q, w_k, w_v = w.chunk(3) @@ -107,6 +122,7 @@ def _in_projection_packed( b_q, b_k, b_v = b.chunk(3) return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) + def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: if input is None: return None @@ -114,19 +130,35 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]: return input.dtype raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor") + def rotate_half(x): - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat([-x2, x1], dim=-1) -def apply_rotary_pos_emb(q, k, q_sinu=None, k_sinu=None, sinu=None, unsqueeze_dim=1, args=None, q_offset=0): + +def apply_rotary_pos_emb( + q, k, q_sinu=None, 
k_sinu=None, sinu=None, unsqueeze_dim=1, args=None, q_offset=0 +): if sinu is not None: - q_emb = q * sinu['cos'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(q) * sinu['sin'][:, q_offset:q_offset+q.shape[2]].unsqueeze(unsqueeze_dim) - k_emb = k * sinu['cos'][:, :k.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half(k) * sinu['sin'][:, :k.shape[2]].unsqueeze(unsqueeze_dim) + q_emb = q * sinu["cos"][:, q_offset : q_offset + q.shape[2]].unsqueeze( + unsqueeze_dim + ) + rotate_half(q) * sinu["sin"][:, q_offset : q_offset + q.shape[2]].unsqueeze( + unsqueeze_dim + ) + k_emb = k * sinu["cos"][:, : k.shape[2]].unsqueeze(unsqueeze_dim) + rotate_half( + k + ) * sinu["sin"][:, : k.shape[2]].unsqueeze(unsqueeze_dim) if q_sinu is not None: assert sinu is None, "sinu must be None" - q_emb = q * q_sinu['cos'][:, :, q_offset:q_offset+q.shape[2]] + rotate_half(q) * q_sinu['sin'][:, :, q_offset:q_offset+q.shape[2]] - k_emb = k * k_sinu['cos'][:, :, :k.shape[2]] + rotate_half(k) * k_sinu['sin'][:, :, :k.shape[2]] + q_emb = ( + q * q_sinu["cos"][:, :, q_offset : q_offset + q.shape[2]] + + rotate_half(q) * q_sinu["sin"][:, :, q_offset : q_offset + q.shape[2]] + ) + k_emb = ( + k * k_sinu["cos"][:, :, : k.shape[2]] + + rotate_half(k) * k_sinu["sin"][:, :, : k.shape[2]] + ) # else: # assert freqs is not None, "freqs must be provided" # assert key_lens is not None, "key_lens must be provided" @@ -243,6 +275,7 @@ class MultiheadAttention(Module): >>> attn_output, attn_output_weights = multihead_attn(query, key, value) """ + __constants__ = ["batch_first"] bias_k: Optional[torch.Tensor] bias_v: Optional[torch.Tensor] @@ -268,9 +301,7 @@ def __init__( self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = ( - self.kdim == embed_dim and self.vdim == embed_dim - ) + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim self.num_heads = num_heads self.dropout = dropout @@ -281,12 +312,8 @@ def __init__( ), "embed_dim must be divisible by num_heads" if add_bias_kv: - self.bias_k = Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) - self.bias_v = Parameter( - torch.empty((1, 1, embed_dim), **factory_kwargs) - ) + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) else: self.bias_k = self.bias_v = None @@ -311,7 +338,7 @@ def __init__( self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) - if bias: # True by default + if bias: # True by default self.in_proj_bias = Parameter( torch.empty(3 * embed_dim, **factory_kwargs) ) @@ -385,11 +412,11 @@ def forward( attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True, past: Optional[Tensor] = None, - q_sinu = None, - k_sinu = None, - sinu = None, - args = None, - q_offset = 0, + q_sinu=None, + k_sinu=None, + sinu=None, + args=None, + q_offset=0, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: @@ -453,20 +480,18 @@ def forward( ) why_not_fast_path = "" if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) elif query is not key or key is not value: # When lifting this restriction, don't forget to either # enforce that the dtypes all match or test cases where # they don't! 
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif ( - self.in_proj_bias is not None - and query.dtype != self.in_proj_bias.dtype - ): + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" elif ( - self.in_proj_weight is not None - and query.dtype != self.in_proj_weight.dtype + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype ): # this case will fail anyway, but at least they'll get a useful error message. why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" @@ -515,9 +540,7 @@ def forward( for x in tensor_args ] ): - why_not_fast_path = ( - "some Tensor argument is neither CUDA nor CPU" - ) + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any( [x is not None and x.requires_grad for x in tensor_args] ): @@ -536,16 +559,14 @@ def forward( self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, - key_padding_mask - if key_padding_mask is not None - else attn_mask, + key_padding_mask if key_padding_mask is not None else attn_mask, need_weights, average_attn_weights, - 1 - if key_padding_mask is not None - else 0 - if attn_mask is not None - else None, + ( + 1 + if key_padding_mask is not None + else 0 if attn_mask is not None else None + ), ) any_nested = query.is_nested or key.is_nested or value.is_nested @@ -563,9 +584,7 @@ def forward( query, key = [x.transpose(1, 0) for x in (query, key)] value = key else: - query, key, value = [ - x.transpose(1, 0) for x in (query, key, value) - ] + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] if not self._qkv_same_embed_dim: attn_output, attn_output_weights = F.multi_head_attention_forward( @@ -624,62 +643,78 @@ def forward( mask_name="key_padding_mask", other_type=_none_or_dtype(attn_mask), other_name="attn_mask", - target_type=query.dtype + target_type=query.dtype, ) attn_mask = _canonical_mask( - mask=attn_mask, - mask_name="attn_mask", - other_type=None, - other_name="", - target_type=query.dtype, - check_other=False, - ) + mask=attn_mask, + mask_name="attn_mask", + other_type=None, + other_name="", + target_type=query.dtype, + check_other=False, + ) head_dim = self.embed_dim // self.num_heads - assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}" - assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" - q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias) + assert ( + head_dim * self.num_heads == self.embed_dim + ), f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}" + assert ( + key.shape == value.shape + ), f"key shape {key.shape} does not match value shape {value.shape}" + q, k, v = _in_projection_packed( + query, key, value, self.in_proj_weight, self.in_proj_bias + ) # k_present, v_present = k, v - + # # reshape q, k, v for multihead attention and make em batch first # - - q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) - v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim) + v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose( + 0, 1 + ) # (bsz * num_heads, 
src_len, head_dim) src_len = k.size(1) if past is not None and past.ndim > 2: expected_src_len = src_len + past[0].shape[-2] else: expected_src_len = src_len - # ensure attn_mask's dim is 3 if attn_mask is not None: if attn_mask.dim() == 2: correct_2d_size = (tgt_len, expected_src_len) if attn_mask.shape != correct_2d_size: - raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") + raise RuntimeError( + f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}." + ) attn_mask = attn_mask.unsqueeze(0) elif attn_mask.dim() == 3: correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len) if attn_mask.shape != correct_3d_size: - raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") + raise RuntimeError( + f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}." + ) else: - raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") - + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + if key_padding_mask is not None: - assert key_padding_mask.shape == (bsz, expected_src_len), \ - f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}" - key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \ - expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len) + assert key_padding_mask.shape == ( + bsz, + expected_src_len, + ), f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}" + key_padding_mask = ( + key_padding_mask.view(bsz, 1, 1, expected_src_len) + .expand(-1, num_heads, -1, -1) + .reshape(bsz * num_heads, 1, expected_src_len) + ) if attn_mask is None: attn_mask = key_padding_mask else: attn_mask = attn_mask + key_padding_mask - + if not self.training: dropout_p = 0.0 else: @@ -731,8 +766,12 @@ def forward( v = v.view(bsz, num_heads, src_len, head_dim) # logging.info(f"shape of past: {past.shape}") if past is not None: - present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim) - if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache + present = torch.stack( + [k, v], dim=0 + ) # (2, bsz, num_heads, src_len, head_dim) + if ( + past.ndim > 2 + ): # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache pk, pv = past k = torch.cat([pk, k], dim=-2) v = torch.cat([pv, v], dim=-2) @@ -742,23 +781,40 @@ def forward( # here we assume that this kvcache is only used in self-attention, and therefore k and q always have the same seq_len # rope positional encoding if sinu is not None: - # direct rotary + # direct rotary # logging.info("perform rotary positional encoding") - q, k = apply_rotary_pos_emb(q, k, sinu=sinu, args = args, q_offset=q_offset) + q, k = apply_rotary_pos_emb( + q, k, sinu=sinu, args=args, q_offset=q_offset + ) if q_sinu is not None: assert sinu is None, "sinu and q_sinu cannot be used together" assert k_sinu is not None, "k_sinu must be provided" - q, k = apply_rotary_pos_emb(q, k, q_sinu=q_sinu, k_sinu=k_sinu, args = args, q_offset=q_offset) - + q, k = apply_rotary_pos_emb( + q, k, q_sinu=q_sinu, k_sinu=k_sinu, args=args, q_offset=q_offset + ) + # if self.training and it's cross attention, will get attention_weights - if args != None and self.training and getattr(args, "attention_alignment_loss", 0) and not (query 
is key): + if ( + args != None + and self.training + and getattr(args, "attention_alignment_loss", 0) + and not (query is key) + ): attention_weights = q @ k.transpose(-1, -2) else: attention_weights = None - attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False) - attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask, dropout_p, is_causal=False + ) + attn_output = ( + attn_output.permute(2, 0, 1, 3) + .contiguous() + .view(bsz * tgt_len, embed_dim) + ) - attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias) + attn_output = F.linear( + attn_output, self.out_proj.weight, self.out_proj.bias + ) attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1)) if not is_batched: # squeeze the output if input was unbatched @@ -769,13 +825,18 @@ def forward( # return (attn_output, present), None # harded coded, the code do not support returning attn weigths yet - attn_output_weights=None + attn_output_weights = None if self.batch_first and is_batched: if attention_weights != None: - return {"attn_output": attn_output.transpose(1, 0), "attention_weights": attention_weights}, present + return { + "attn_output": attn_output.transpose(1, 0), + "attention_weights": attention_weights, + }, present return attn_output.transpose(1, 0), present else: if attention_weights != None: - return {"attn_output": attn_output, "attention_weights": attention_weights}, present + return { + "attn_output": attn_output, + "attention_weights": attention_weights, + }, present return attn_output, present - diff --git a/models/modules/embedding.py b/voicestar/modules/embedding.py similarity index 79% rename from models/modules/embedding.py rename to voicestar/modules/embedding.py index 92e36da..25f2507 100644 --- a/models/modules/embedding.py +++ b/voicestar/modules/embedding.py @@ -1,5 +1,5 @@ -# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py -# Copyright 2023 (authors: Feiteng Li) +# From https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py (Apache 2.0 license) +# Copyright 2023 (authors: Feiteng Li) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
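To make the rotary position step in `voicestar/modules/activation.py` above concrete: `apply_rotary_pos_emb(q, k, sinu=sinu)` expects `sinu["cos"]` and `sinu["sin"]` shaped `[1, max_len, head_dim]` and rotates each attention head with them. The table construction below (half-and-half frequency layout, base 10000) is a common convention and is assumed here; only the shapes and the rotation formula mirror the code above.

```python
import torch

def rotate_half(x):
    # same split-and-negate trick as rotate_half() in activation.py
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat([-x2, x1], dim=-1)

bsz, heads, seq_len, head_dim = 2, 4, 16, 64
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)

# precompute the sinusoid table in the layout the attention module indexes:
# [1, max_len, head_dim], later unsqueezed at dim 1 to broadcast over heads
base = 10_000.0
inv_freq = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)    # [head_dim // 2]
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]      # [seq_len, head_dim // 2]
angles = torch.cat([angles, angles], dim=-1)                             # [seq_len, head_dim]
sinu = {"cos": angles.cos()[None], "sin": angles.sin()[None]}            # [1, seq_len, head_dim]

# the same arithmetic apply_rotary_pos_emb() performs with sinu=... and q_offset=0
cos = sinu["cos"][:, :seq_len].unsqueeze(1)   # [1, 1, seq_len, head_dim]
sin = sinu["sin"][:, :seq_len].unsqueeze(1)
q_rot = q * cos + rotate_half(q) * sin
k_rot = k * cos + rotate_half(k) * sin
print(q_rot.shape, k_rot.shape)               # shapes unchanged: [2, 4, 16, 64]
```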
@@ -79,9 +79,7 @@ def extend_pe(self, x): x.size(1) - 1, -1, -1.0, dtype=torch.float32 ).unsqueeze(1) else: - position = torch.arange( - 0, x.size(1), dtype=torch.float32 - ).unsqueeze(1) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.dim_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.dim_model) @@ -105,7 +103,7 @@ def __init__( dropout: float = 0.0, scale: bool = False, alpha: bool = False, - args = None + args=None, ): super().__init__() self.args = args @@ -115,10 +113,14 @@ def __init__( self.dropout = torch.nn.Dropout(p=dropout) self.reverse = False - self.div_term = torch.exp( - torch.arange(0, self.dim_model, 2, dtype=torch.float32) - * -(math.log(args.sinusoidal_base) / self.dim_model) - ).unsqueeze(0).unsqueeze(0) # [1, 1, dim_model//2] + self.div_term = ( + torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(args.sinusoidal_base) / self.dim_model) + ) + .unsqueeze(0) + .unsqueeze(0) + ) # [1, 1, dim_model//2] self.position = None self.extend_position(torch.tensor(0.0).expand(1, 10000)) self.progress_scale = getattr(args, "progress_scale", 1.0) @@ -133,26 +135,34 @@ def extend_position(self, x): self.position = self.position.to(dtype=x.dtype, device=x.device) return if self.reverse: - self.position = torch.arange( - x.size(1) - 1, -1, -1.0, dtype=torch.float32 - ).unsqueeze(0).unsqueeze(2).to(x) + self.position = ( + torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(2) + .to(x) + ) else: - self.position = torch.arange( - 0, x.size(1), dtype=torch.float32 - ).unsqueeze(0).unsqueeze(2).to(x) # [1, seq_len, 1] + self.position = ( + torch.arange(0, x.size(1), dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(2) + .to(x) + ) # [1, seq_len, 1] def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: assert x.ndim == 3, x.shape self.extend_position(x) - x_lens = x_lens.unsqueeze(1).unsqueeze(2) # [B, 1, 1] + x_lens = x_lens.unsqueeze(1).unsqueeze(2) # [B, 1, 1] multiple = x_lens / (x_lens - 1) - progress = self.position[:, :x.shape[1]] * multiple / x_lens * self.progress_scale + progress = ( + self.position[:, : x.shape[1]] * multiple / x_lens * self.progress_scale + ) # torch.set_printoptions(edgeitems=100) # for i in range(x_lens.shape[0]): # logging.info(f"{progress[i, :x_lens[i,0,0], 0]}") - invfreq = self.div_term * progress # might want to use a scale term here + invfreq = self.div_term * progress # might want to use a scale term here pe = torch.zeros_like(x) pe[..., 0::2] = torch.sin(invfreq) pe[..., 1::2] = torch.cos(invfreq) output = x * self.x_scale + self.alpha * pe - return self.dropout(output) \ No newline at end of file + return self.dropout(output) diff --git a/models/modules/sampling.py b/voicestar/modules/sampling.py similarity index 88% rename from models/modules/sampling.py rename to voicestar/modules/sampling.py index 7acdcd4..db6c165 100644 --- a/models/modules/sampling.py +++ b/voicestar/modules/sampling.py @@ -1,6 +1,16 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + import torch import torch.nn.functional as F + def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 ): @@ -14,18 +24,14 @@ def top_k_top_p_filtering( From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 """ if top_k > 0: - top_k = min( - 
max(top_k, min_tokens_to_keep), logits.size(-1) - ) # Safety check + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum( - F.softmax(sorted_logits, dim=-1), dim=-1 - ) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p @@ -33,9 +39,7 @@ def top_k_top_p_filtering( # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing @@ -44,7 +48,8 @@ def top_k_top_p_filtering( ) logits[indices_to_remove] = filter_value return logits - + + def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): # temperature: (`optional`) float # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. @@ -60,4 +65,4 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) # Sample token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) - return token \ No newline at end of file + return token diff --git a/models/modules/scaling.py b/voicestar/modules/scaling.py similarity index 94% rename from models/modules/scaling.py rename to voicestar/modules/scaling.py index cd245ea..d5d0c9a 100644 --- a/models/modules/scaling.py +++ b/voicestar/modules/scaling.py @@ -1,5 +1,5 @@ -# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) +# From https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py (Apache 2.0 license) +# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -32,12 +32,14 @@ # from valle.utils import Transpose + class Transpose(nn.Identity): """(N, T, D) -> (N, D, T)""" def forward(self, input: torch.Tensor) -> torch.Tensor: return input.transpose(1, 2) - + + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( @@ -97,9 +99,9 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ( - (min_abs - x_abs_mean) * (gain_factor / min_abs) - ).clamp(min=0, max=max_factor) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( + min=0, max=max_factor + ) above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( min=0, max=max_factor @@ -135,8 +137,7 @@ def _compute_sign_factor( # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. 
factor2 = ( - (proportion_positive - max_positive) - * (gain_factor / (1.0 - max_positive)) + (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) ).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: @@ -204,9 +205,7 @@ def forward( return ans @staticmethod - def backward( - ctx, ans_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None]: + def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: (is_same,) = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect @@ -255,9 +254,7 @@ def forward(ctx, x: Tensor, min_abs: float) -> Tensor: def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: return ( - random_cast_to_half( - ans_grad.to(torch.float32), min_abs=ctx.min_abs - ), + random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), None, ) else: @@ -275,11 +272,7 @@ def __init__(self, min_abs: float = 5.0e-06): self.min_abs = min_abs def forward(self, x: Tensor): - if ( - torch.jit.is_scripting() - or not self.training - or torch.jit.is_tracing() - ): + if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): return x else: return RandomGradFunction.apply(x, self.min_abs) @@ -346,9 +339,9 @@ def backward(ctx, x_grad, *args): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) @@ -424,7 +417,7 @@ def forward(self, x: Tensor) -> Tensor: # gradients to allow the parameter to get back into the allowed region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales @@ -448,9 +441,7 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_( - ans.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -479,9 +470,7 @@ def ScaledConv1d( with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_( - ans.bias, -0.1 * initial_scale, 0.1 * initial_scale - ) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -713,11 +702,7 @@ def __init__( self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or not x.requires_grad - or torch.jit.is_tracing() - ): + if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): return _no_op(x) count = self.cpu_count @@ -835,11 +820,9 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). 
- x_covarsq_mean_diag = (x_covar ** 2).sum() / ( - num_groups * channels_per_group - ) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric @@ -877,8 +860,7 @@ def backward(ctx, x_grad: Tensor): (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() - / (penalty_grad.norm() + 1.0e-20) + x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None @@ -943,11 +925,7 @@ def forward(self, x: Tensor) -> Tensor: you use the returned value, or the graph will be freed and nothing will happen in backprop. """ - if ( - not x.requires_grad - or random.random() > self.prob - or self.grad_scale == 0 - ): + if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: if hasattr(self, "min_prob") and random.random() < 0.25: @@ -1069,27 +1047,21 @@ def forward(self, x: Tensor) -> Tensor: orig_x = x x = x.to(torch.float32) with torch.no_grad(): - x = x.transpose(self.channel_dim, -1).reshape( - -1, self.num_channels - ) + x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) new_direction, coeffs = self._find_direction_coeffs( x, self.max_eig_direction ) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / ( - x_var + 1.0e-20 - ) + variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) # ensure new direction is nonzero even if x == 0, by including `direction`. - self._set_direction( - 0.1 * self.max_eig_direction + new_direction - ) + self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": logging.info( @@ -1101,9 +1073,7 @@ def forward(self, x: Tensor) -> Tensor: # reach here, only near the beginning of training if we are # starting to diverge, should this constraint be active. cur_prob = self.cur_prob - self.cur_prob = ( - 1.0 # next time, do the update with probability 1.0. - ) + self.cur_prob = 1.0 # next time, do the update with probability 1.0. return MaxEigLimiterFunction.apply( orig_x, coeffs, new_direction, self.channel_dim, self.scale ) @@ -1152,9 +1122,7 @@ def _find_direction_coeffs( # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ( - (coeffs ** 2).sum() + 1.0e-20 - ) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) return cur_direction, coeffs @@ -1194,9 +1162,9 @@ def forward(ctx, x: Tensor) -> Tensor: # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * ( - 255.0 / (ceil - floor) - ) + torch.rand_like(deriv) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. 
assert d_scaled.min() >= 0.0 @@ -1299,9 +1267,7 @@ def _test_whiten(): def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 - x = 1.0 * ( - (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0 - ) + x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1325,9 +1291,7 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1359,8 +1323,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x ** 2).mean().sqrt() - y_rms = (y ** 2).mean().sqrt() + x_rms = (x**2).mean().sqrt() + y_rms = (y**2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1403,4 +1367,4 @@ def _test_softmax(): _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() - _test_double_swish_deriv() \ No newline at end of file + _test_double_swish_deriv() diff --git a/models/modules/transformer.py b/voicestar/modules/transformer.py similarity index 75% rename from models/modules/transformer.py rename to voicestar/modules/transformer.py index 3cbc70b..f152b48 100644 --- a/models/modules/transformer.py +++ b/voicestar/modules/transformer.py @@ -1,3 +1,12 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + import copy, logging import numbers from functools import partial @@ -242,13 +251,9 @@ def __init__( norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if layer_norm_cls == IdentityNorm: - norm2 = BalancedBasicNorm( - d_model, eps=layer_norm_eps, **factory_kwargs - ) + norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) else: - norm2 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if adaptive_layer_norm: self.norm1 = AdaptiveLayerNorm(d_model, norm1) @@ -292,7 +297,7 @@ def forward( if isinstance(src, tuple): x, stage_embedding = src is_src_tuple = True - + if src_key_padding_mask is not None: _skpm_dtype = src_key_padding_mask.dtype if _skpm_dtype != torch.bool and not torch.is_floating_point( @@ -308,41 +313,52 @@ def forward( self.norm1(x, stage_embedding), src_mask, src_key_padding_mask, - past, sinu = sinu + past, + sinu=sinu, ) - out, present = out # present is the kvcache of the present timestep + out, present = out # present is the kvcache of the present timestep x = x + out x = x + self._ff_block(self.norm2(x, stage_embedding)) else: - out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past, sinu = sinu) - out, present = out # present is the kvcache of the present timestep + out, attn = self._sa_block_attn( + x, src_mask, src_key_padding_mask, past, sinu=sinu + ) + out, present = out # present is the kvcache of the present timestep x = self.norm1( x + out, stage_embedding, ) x = self.norm2(x + self._ff_block(x), stage_embedding) assert not is_src_tuple - # return (x, stage_embedding) + # return (x, stage_embedding) return (x, attn) else: if self.norm_first: out = self._sa_block( self.norm1(x, stage_embedding), src_mask, - 
src_key_padding_mask, past, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'] + src_key_padding_mask, + past, + sinu=sinu, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["q"], ) - out, present = out # present is the kvcache of the present timestep + out, present = out # present is the kvcache of the present timestep x = x + out x = x + self._ff_block(self.norm2(x, stage_embedding)) else: - out = self._sa_block(x, src_mask, src_key_padding_mask, sinu = sinu, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q']) - out, present = out # present is the kvcache of the present timestep - x = self.norm1( - x + out, - stage_embedding, past + out = self._sa_block( + x, + src_mask, + src_key_padding_mask, + sinu=sinu, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["q"], ) + out, present = out # present is the kvcache of the present timestep + x = self.norm1(x + out, stage_embedding, past) x = self.norm2(x + self._ff_block(x), stage_embedding) - + if is_src_tuple: x = (x, stage_embedding) if present != None: @@ -356,9 +372,9 @@ def _sa_block( attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], past: Optional[Tensor] = None, - sinu = None, - q_sinu = None, - k_sinu = None + sinu=None, + q_sinu=None, + k_sinu=None, ) -> Tensor: x = self.self_attn( x, @@ -368,9 +384,9 @@ def _sa_block( key_padding_mask=key_padding_mask, need_weights=False, past=past, - sinu = sinu, - q_sinu = q_sinu, - k_sinu = k_sinu + sinu=sinu, + q_sinu=q_sinu, + k_sinu=k_sinu, ) x, present = x return self.dropout1(x), present @@ -390,7 +406,7 @@ def _sa_block_attn( attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, - past=past + past=past, ) x, present = x return (self.dropout1(x), present), attn @@ -400,22 +416,37 @@ def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) -def pre_compute_sinusoidal(dim, base, max_len = 10000): # 4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1] - inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] - freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] - freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] + +def pre_compute_sinusoidal( + dim, base, max_len=10000 +): # 4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze( + 1 + ) # [x_len_max, 1] + inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] + freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] + freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] return {"sin": freqs.sin(), "cos": freqs.cos()} -def pre_compute_freqs(dim, base, max_len = 10000): # 4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) - position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [x_len_max, 1] - inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] - freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] - freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] + +def pre_compute_freqs( + dim, base, max_len=10000 +): # 
4000 max length equivalent of mimi code is 320s, as mimi is 12.5hz + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + position_ids_expanded = torch.arange(0, max_len, dtype=torch.float32).unsqueeze( + 1 + ) # [x_len_max, 1] + inv_freq_expanded = inv_freq.unsqueeze(0).float() # [1, d//2] + freqs = position_ids_expanded @ inv_freq_expanded # [x_len_max, d//2] + freqs = torch.cat((freqs, freqs), dim=-1).unsqueeze(0) # [1, x_len_max, d] return freqs + class TransformerEncoder(nn.Module): r"""TransformerEncoder is a stack of N encoder layers. Users can build the BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. @@ -434,9 +465,19 @@ class TransformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = transformer_encoder(src) """ + __constants__ = ["norm"] - def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model=None, nhead=None, args=None): + def __init__( + self, + encoder_layer, + num_layers, + norm=None, + rope_base=None, + d_model=None, + nhead=None, + args=None, + ): super(TransformerEncoder, self).__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers @@ -450,10 +491,10 @@ def __init__(self, encoder_layer, num_layers, norm=None, rope_base=None, d_model if rope_base is not None: if self.progress_no_multiple: - self.pm_freqs = pre_compute_freqs(d_model//nhead, rope_base) + self.pm_freqs = pre_compute_freqs(d_model // nhead, rope_base) self.sinu = None else: - self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base) + self.sinu = pre_compute_sinusoidal(d_model / nhead, rope_base) self.pm_freqs = None # logging.info(f"get precomputed sinusoidal for {rope_base=}: {self.sinu=}") else: @@ -466,7 +507,7 @@ def forward( mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, return_layer_states: bool = False, - need_weights:Optional[bool] = False, + need_weights: Optional[bool] = False, past: Optional[Tensor] = None, ) -> Tensor: r"""Pass the input through the encoder layers in turn. 
@@ -490,7 +531,7 @@ def forward( output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, - past=past + past=past, ) layer_states.append(output[0]) @@ -509,7 +550,7 @@ def forward( src_mask=mask, src_key_padding_mask=src_key_padding_mask, need_weights=True, - past=past + past=past, ) layer_attn.append(output[1]) @@ -517,7 +558,7 @@ def forward( output = self.norm(output) return layer_attn, output - + output = src all_present = [] if self.sinu is not None: @@ -531,13 +572,15 @@ def forward( if src_key_padding_mask != None: query_lens = (~src_key_padding_mask).int().sum(-1).to(output.device) else: - query_lens = torch.tensor([output.shape[1]]*output.shape[0]).to(output.device) - assert query_lens.ndim==1, query_lens - q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] + query_lens = torch.tensor([output.shape[1]] * output.shape[0]).to( + output.device + ) + assert query_lens.ndim == 1, query_lens + q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1) - q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] + q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] q_emb = q_emb / q_lens_expanded * self.progress_scale - q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead + q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead q_sin = q_emb.sin().unsqueeze(1) self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}} else: @@ -546,7 +589,10 @@ def forward( output = {"input": output, "sinu": self.sinu, "pm_sinu": self.pm_sinu} for n_layer, mod in enumerate(self.layers): output = mod( - output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer] + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + past=None if past is None else past[n_layer], ) if isinstance(output, list): output, present = output @@ -558,7 +604,9 @@ def forward( if self.norm is not None: output = self.norm(output) if all_present != []: - all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim) + all_present = torch.stack( + all_present, dim=0 + ) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim) output = [output, all_present] return output @@ -630,26 +678,16 @@ def __init__( self.activation = activation if adaptive_layer_norm: - norm1 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) - norm2 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) - norm3 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + norm3 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) self.norm1 = AdaptiveLayerNorm(d_model, norm1) self.norm2 = AdaptiveLayerNorm(d_model, norm2) self.norm3 = AdaptiveLayerNorm(d_model, norm3) else: - self.norm1 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) - self.norm2 = layer_norm_cls( - d_model, eps=layer_norm_eps, **factory_kwargs - ) + self.norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) if layer_norm_cls == IdentityNorm: self.norm3 = BalancedBasicNorm( d_model, eps=layer_norm_eps, **factory_kwargs @@ -667,8 +705,12 @@ def forward( memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = 
None, memory_key_padding_mask: Optional[Tensor] = None, - tgt_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used - memory_is_causal: Optional[bool] = False, # for compatibility with the nn.TransformerDecoder, not used + tgt_is_causal: Optional[ + bool + ] = False, # for compatibility with the nn.TransformerDecoder, not used + memory_is_causal: Optional[ + bool + ] = False, # for compatibility with the nn.TransformerDecoder, not used past: Optional[Tensor] = None, ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer. @@ -706,14 +748,23 @@ def forward( # past stores the kvcache for self-attention, and it can also be used to infer q_offset if past is not None and past.ndim > 2: - q_offset = past[0].shape[-2] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q + q_offset = past[0].shape[ + -2 + ] # past is (2, batch_size, num_heads, seq_len, head_dim), 2 contains [k, v], these are for self-attn, therefore also reflect the length of q else: q_offset = 0 - if self.norm_first: temp = self._sa_block( - self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset + self.norm1(x, stage_embedding), + tgt_mask, + tgt_key_padding_mask, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["q"], + sinu=sinu, + args=args, + past=past, + q_offset=q_offset, ) present = temp[1] x = x + temp[0] @@ -721,7 +772,12 @@ def forward( self.norm2(x, stage_embedding), memory, memory_mask, - memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args = args, q_offset=q_offset + memory_key_padding_mask, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["k"], + sinu=sinu, + args=args, + q_offset=q_offset, ) if isinstance(cross_out, dict): attention_weights = cross_out["attention_weights"] @@ -731,27 +787,44 @@ def forward( x = x + cross_out x = x + self._ff_block(self.norm3(x, stage_embedding)) else: - temp = self._sa_block(x, tgt_mask, tgt_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['q'], sinu=sinu, args = args, past=past, q_offset=q_offset) + temp = self._sa_block( + x, + tgt_mask, + tgt_key_padding_mask, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["q"], + sinu=sinu, + args=args, + past=past, + q_offset=q_offset, + ) present = temp[1] x = self.norm1( x + temp[0], stage_embedding, ) cross_out = self._mha_block( - x, memory, memory_mask, memory_key_padding_mask, q_sinu=pm_sinu['q'], k_sinu=pm_sinu['k'], sinu=sinu, args=args, q_offset=q_offset - ) + x, + memory, + memory_mask, + memory_key_padding_mask, + q_sinu=pm_sinu["q"], + k_sinu=pm_sinu["k"], + sinu=sinu, + args=args, + q_offset=q_offset, + ) if isinstance(cross_out, dict): attention_weights = cross_out["attention_weights"] cross_out = cross_out["x"] else: - attention_weights = None + attention_weights = None x = self.norm2( - x - + cross_out, + x + cross_out, stage_embedding, ) x = self.norm3(x + self._ff_block(x), stage_embedding) - + if attention_weights is not None: x = {"x": x, "attention_weights": attention_weights} if tgt_is_tuple: @@ -768,10 +841,10 @@ def _sa_block( key_padding_mask: Optional[Tensor], q_sinu=None, k_sinu=None, - sinu = None, - args = None, - past = None, - q_offset = 0 + sinu=None, + args=None, + past=None, + q_offset=0, ) -> Tensor: # if past is not None and past.ndim > 2: # print(f"self-attn, k len: {past[0].shape[-2] + x.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}") @@ -784,11 
+857,11 @@ def _sa_block( attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, - q_sinu = q_sinu, - k_sinu = k_sinu, - sinu = sinu, - past = past, - q_offset = q_offset + q_sinu=q_sinu, + k_sinu=k_sinu, + sinu=sinu, + past=past, + q_offset=q_offset, ) x, present = x return self.dropout1(x), present @@ -800,11 +873,11 @@ def _mha_block( mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], - q_sinu = None, - k_sinu = None, - sinu = None, - args = None, - q_offset = 0 + q_sinu=None, + k_sinu=None, + sinu=None, + args=None, + q_offset=0, ) -> Tensor: # print(f"cross-attn, k len: {mem.shape[-2]}, q len: {x.shape[-2]} q_offset: {q_offset}") x = self.multihead_attn( @@ -814,16 +887,16 @@ def _mha_block( attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, - q_sinu = q_sinu, - k_sinu = k_sinu, - sinu = sinu, - args = args, - q_offset = q_offset + q_sinu=q_sinu, + k_sinu=k_sinu, + sinu=sinu, + args=args, + q_offset=q_offset, ) if len(x) == 2 and isinstance(x[0], dict) and "attention_weights" in x[0]: x, present = x - attention_weights = x['attention_weights'] - x = x['attn_output'] + attention_weights = x["attention_weights"] + x = x["attn_output"] return {"x": self.dropout2(x), "attention_weights": attention_weights} elif len(x) == 2: x = x[0] @@ -845,30 +918,29 @@ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: elif activation == "gelu": return F.gelu - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + def _generate_square_subsequent_mask( - sz: int, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + sz: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ) -> Tensor: r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). """ if device is None: - device = torch.device('cpu') + device = torch.device("cpu") if dtype is None: dtype = torch.float32 return torch.triu( - torch.full((sz, sz), float('-inf'), dtype=dtype, device=device), + torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), diagonal=1, ) -def _get_seq_len( - src: Tensor, - batch_first: bool -) -> Optional[int]: + + +def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: if src.is_nested: return None @@ -882,10 +954,11 @@ def _get_seq_len( seq_len_pos = 1 if batch_first else 0 return src_size[seq_len_pos] + def _detect_is_causal_mask( - mask: Optional[Tensor], - is_causal: Optional[bool] = None, - size: Optional[int] = None, + mask: Optional[Tensor], + is_causal: Optional[bool] = None, + size: Optional[int] = None, ) -> bool: """Return whether the given attention mask is causal. @@ -907,12 +980,13 @@ def _detect_is_causal_mask( Otherwise, checks for any causal mask. """ # Prevent type refinement - make_causal = (is_causal is True) + make_causal = is_causal is True if is_causal is None and mask is not None: sz = size if size is not None else mask.size(-2) causal_comparison = _generate_square_subsequent_mask( - sz, device=mask.device, dtype=mask.dtype) + sz, device=mask.device, dtype=mask.dtype + ) # Do not use `torch.equal` so we handle batched masks by # broadcasting the comparison. @@ -923,6 +997,7 @@ def _detect_is_causal_mask( return make_causal + class TransformerDecoder(nn.Module): r"""TransformerDecoder is a stack of N decoder layers. 
@@ -939,17 +1014,17 @@ class TransformerDecoder(nn.Module): >>> out = transformer_decoder(tgt, memory) """ - __constants__ = ['norm'] + __constants__ = ["norm"] def __init__( self, decoder_layer: "TransformerDecoderLayer", num_layers: int, norm: Optional[nn.Module] = None, - rope_base=None, - d_model=None, + rope_base=None, + d_model=None, nhead=None, - args=None + args=None, ) -> None: super().__init__() torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") @@ -957,22 +1032,32 @@ def __init__( self.num_layers = num_layers self.norm = norm self.args = args - if getattr(self.args, 'decoder_regular_rope', False): - self.sinu = pre_compute_sinusoidal(d_model/nhead, rope_base) + if getattr(self.args, "decoder_regular_rope", False): + self.sinu = pre_compute_sinusoidal(d_model / nhead, rope_base) self.pm_freqs = None else: self.sinu = None if rope_base is not None: - self.pm_freqs = pre_compute_freqs(d_model/nhead, rope_base) + self.pm_freqs = pre_compute_freqs(d_model / nhead, rope_base) # logging.info(f"get precomputed freqs for {rope_base=}: {self.freqs=}") else: self.pm_freqs = None self.progress_scale = getattr(self.args, "progress_scale", 1.0) - def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, tgt_is_causal: Optional[bool] = None, - memory_is_causal: bool = False, query_lens: Optional[Tensor] = None, key_lens: Optional[Tensor] = None, past: Optional[Tensor] = None) -> Tensor: + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + tgt_is_causal: Optional[bool] = None, + memory_is_causal: bool = False, + query_lens: Optional[Tensor] = None, + key_lens: Optional[Tensor] = None, + past: Optional[Tensor] = None, + ) -> Tensor: r"""Pass the inputs (and mask) through the decoder layer in turn. 
Args: @@ -1011,21 +1096,42 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None self.sinu[key] = self.sinu[key].to(output.device) if self.pm_freqs is not None: assert self.sinu is None - if not self.training and hasattr(self, "pm_sinu") and past is not None and past[0].ndim > 2: # inference mode, will use cached sinu for the same example - assert self.pm_sinu['q'] is not None and self.pm_sinu['k'] is not None + if ( + not self.training + and hasattr(self, "pm_sinu") + and past is not None + and past[0].ndim > 2 + ): # inference mode, will use cached sinu for the same example + assert self.pm_sinu["q"] is not None and self.pm_sinu["k"] is not None # check batch size, need to modify the batch size if we use multi_trial during inference - if self.pm_sinu['q']['cos'].shape[0] != tgt.shape[0]: - if self.pm_sinu['q']['cos'].shape[0] > tgt.shape[0]: - self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'][:tgt.shape[0]] - self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'][:tgt.shape[0]] - self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'][:tgt.shape[0]] - self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'][:tgt.shape[0]] + if self.pm_sinu["q"]["cos"].shape[0] != tgt.shape[0]: + if self.pm_sinu["q"]["cos"].shape[0] > tgt.shape[0]: + self.pm_sinu["q"]["cos"] = self.pm_sinu["q"]["cos"][ + : tgt.shape[0] + ] + self.pm_sinu["q"]["sin"] = self.pm_sinu["q"]["sin"][ + : tgt.shape[0] + ] + self.pm_sinu["k"]["cos"] = self.pm_sinu["k"]["cos"][ + : tgt.shape[0] + ] + self.pm_sinu["k"]["sin"] = self.pm_sinu["k"]["sin"][ + : tgt.shape[0] + ] else: - assert self.pm_sinu['q']['cos'].shape[0] == 1 - self.pm_sinu['q']['cos'] = self.pm_sinu['q']['cos'].repeat(tgt.shape[0], 1, 1, 1) - self.pm_sinu['q']['sin'] = self.pm_sinu['q']['sin'].repeat(tgt.shape[0], 1, 1, 1) - self.pm_sinu['k']['cos'] = self.pm_sinu['k']['cos'].repeat(tgt.shape[0], 1, 1, 1) - self.pm_sinu['k']['sin'] = self.pm_sinu['k']['sin'].repeat(tgt.shape[0], 1, 1, 1) + assert self.pm_sinu["q"]["cos"].shape[0] == 1 + self.pm_sinu["q"]["cos"] = self.pm_sinu["q"]["cos"].repeat( + tgt.shape[0], 1, 1, 1 + ) + self.pm_sinu["q"]["sin"] = self.pm_sinu["q"]["sin"].repeat( + tgt.shape[0], 1, 1, 1 + ) + self.pm_sinu["k"]["cos"] = self.pm_sinu["k"]["cos"].repeat( + tgt.shape[0], 1, 1, 1 + ) + self.pm_sinu["k"]["sin"] = self.pm_sinu["k"]["sin"].repeat( + tgt.shape[0], 1, 1, 1 + ) pass else: self.pm_freqs = self.pm_freqs.to(output.device) @@ -1033,38 +1139,51 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None query_lens = (~tgt_key_padding_mask).int().sum(-1).to(tgt.device) if key_lens is None: key_lens = (~memory_key_padding_mask).int().sum(-1).to(tgt.device) - assert key_lens.ndim==1, key_lens - assert query_lens.ndim==1, query_lens - q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] - k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] + assert key_lens.ndim == 1, key_lens + assert query_lens.ndim == 1, query_lens + q_lens_expanded = query_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] + k_lens_expanded = key_lens.unsqueeze(-1).unsqueeze(-1) # [B, 1, 1] query_ids_multiple = q_lens_expanded / (q_lens_expanded - 1) key_ids_multiple = k_lens_expanded / (k_lens_expanded - 1) - q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] - k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d] + q_emb = self.pm_freqs * query_ids_multiple # [B, q_len_max, d] + k_emb = self.pm_freqs * key_ids_multiple # [B, k_len_max, d] q_emb = q_emb / q_lens_expanded * 
self.progress_scale k_emb = k_emb / k_lens_expanded * self.progress_scale - q_cos = q_emb.cos().unsqueeze(1) # [B, 1, q_len_max, d] # 1 is for nhead + q_cos = q_emb.cos().unsqueeze( + 1 + ) # [B, 1, q_len_max, d] # 1 is for nhead q_sin = q_emb.sin().unsqueeze(1) k_cos = k_emb.cos().unsqueeze(1) k_sin = k_emb.sin().unsqueeze(1) - self.pm_sinu = {"q": {"cos": q_cos, "sin": q_sin}, "k": {"cos": k_cos, "sin": k_sin}} + self.pm_sinu = { + "q": {"cos": q_cos, "sin": q_sin}, + "k": {"cos": k_cos, "sin": k_sin}, + } else: self.pm_sinu = {"q": None, "k": None} - output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args} + output = { + "input": output, + "pm_sinu": self.pm_sinu, + "sinu": self.sinu, + "args": self.args, + } if past != None: all_present = [] if self.training and getattr(self.args, "attention_alignment_loss", 0): all_attn_weights = [] for i, mod in enumerate(self.layers): - output = mod(output, memory, tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - past=past[i] if past != None else None - # tgt_is_causal=tgt_is_causal, - # memory_is_causal=memory_is_causal - ) + output = mod( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + past=past[i] if past != None else None, + # tgt_is_causal=tgt_is_causal, + # memory_is_causal=memory_is_causal + ) if past != None: output, cur_present = output all_present.append(cur_present) @@ -1073,17 +1192,24 @@ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None all_attn_weights.append(current_attn_weights) output = output["x"] if self.sinu is not None or self.pm_sinu is not None: - output = {"input": output, "pm_sinu": self.pm_sinu, "sinu": self.sinu, "args": self.args} + output = { + "input": output, + "pm_sinu": self.pm_sinu, + "sinu": self.sinu, + "args": self.args, + } if self.pm_sinu is not None or self.sinu is not None: output = output["input"] if self.norm is not None: output = self.norm(output) if self.training and getattr(self.args, "attention_alignment_loss", 0): - assert len(all_attn_weights) == self.num_layers, f"{len(all_attn_weights)=}, {self.num_layers=}" + assert ( + len(all_attn_weights) == self.num_layers + ), f"{len(all_attn_weights)=}, {self.num_layers=}" output = {"output": output, "attention_weights": all_attn_weights} if past != None: all_present = torch.stack(all_present, dim=0) output = [output, all_present] else: output = [output, None] - return output \ No newline at end of file + return output diff --git a/models/modules/utils.py b/voicestar/modules/utils.py similarity index 79% rename from models/modules/utils.py rename to voicestar/modules/utils.py index c8e8788..f3dab55 100644 --- a/models/modules/utils.py +++ b/voicestar/modules/utils.py @@ -1,3 +1,12 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + import torch @@ -28,9 +37,12 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: return expaned_lengths >= lengths.unsqueeze(-1) + def generate_partial_autoregressive_mask(sz, start, end): mask = torch.zeros(sz, sz).bool() - mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1) + mask[start:end, start:end] = torch.triu( + torch.ones(end - start, end - start, 
dtype=torch.bool), diagonal=1 + ) mask[:start, start:end] = True mask[end:, start:end] = True return mask diff --git a/models/modules/visualizer.py b/voicestar/modules/visualizer.py similarity index 88% rename from models/modules/visualizer.py rename to voicestar/modules/visualizer.py index f8ac5bc..cab289c 100644 --- a/models/modules/visualizer.py +++ b/voicestar/modules/visualizer.py @@ -1,6 +1,5 @@ -# cp from https://github.com/lifeiteng/vall-e/blob/main/valle/models/visualizer.py -#!/usr/bin/env python3 -# Copyright 2023 (authors: Feiteng Li) +# From https://github.com/lifeiteng/vall-e/blob/main/valle/models/visualizer.py (Apache 2.0 license) +# Copyright 2023 (authors: Feiteng Li) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -33,9 +32,7 @@ def visualize( text_tokens = batch["text_tokens"].to("cpu").detach().numpy() text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy() - audio_features_lens = ( - batch["audio_features_lens"].to("cpu").detach().numpy() - ) + audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() assert text_tokens.ndim == 2 utt_ids, texts = batch["utt_id"], batch["text"] @@ -44,9 +41,7 @@ def visualize( decoder_outputs = predicts[1] if isinstance(decoder_outputs, list): decoder_outputs = decoder_outputs[-1] - decoder_outputs = ( - decoder_outputs.to("cpu").type(torch.float32).detach().numpy() - ) + decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() vmin, vmax = 0, 1024 # Encodec if decoder_outputs.dtype == np.float32: @@ -104,4 +99,4 @@ def visualize( plt.colorbar() plt.savefig(f"{output_dir}/{utt_id}.png") - plt.close() \ No newline at end of file + plt.close() diff --git a/voicestar/scripts/__init__.py b/voicestar/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/voicestar/scripts/copy_codebase.py b/voicestar/scripts/copy_codebase.py new file mode 100644 index 0000000..ea2fcc5 --- /dev/null +++ b/voicestar/scripts/copy_codebase.py @@ -0,0 +1,57 @@ +import os +import shutil +import fnmatch + + +def parse_gitignore(gitignore_path): + """Parse a .gitignore file and return a list of patterns.""" + patterns = [] + with open(gitignore_path, "r") as f: + for line in f: + # Ignore comments and blank lines + line = line.strip() + if not line or line.startswith("#"): + continue + # Handle wildcards and directory separators + patterns.append(line) + return patterns + + +def file_matches_patterns(file_path, patterns): + """Check if a file matches any of the patterns in .gitignore.""" + for pattern in patterns: + if fnmatch.fnmatch(file_path, pattern): + return True + return False + + +def copy_codebase(src, dst, max_size_mb=5, gitignore_path=None): + """Copy files from src to dst, skipping files larger than max_size_mb and matching .gitignore patterns.""" + if gitignore_path and os.path.exists(gitignore_path): + patterns = parse_gitignore(gitignore_path) + else: + patterns = [] + print("patterns to ignore: ", patterns) + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + for file in files: + file_path = os.path.join(root, file) + relative_path = os.path.relpath(file_path, src) + dst_path = os.path.join(dst, relative_path) + # ignore .git because of permission issues + if "/.git/" in file_path: + continue + + # Check .gitignore patterns + if file_matches_patterns(file_path, patterns): + # print(f"Skipping {file_path} because it matches a pattern in .gitignore") 
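+                # NOTE: fnmatch-style matching only approximates .gitignore semantics
+                # (negation patterns and directory-anchored rules are not handled).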
+ continue + + # Check file size + if os.path.getsize(file_path) > max_size_mb * 1024 * 1024: + print(f"Skipping {file_path} because it's larger than {max_size_mb}MB") + continue + + # Make sure the destination directory exists + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + shutil.copy(file_path, dst_path) diff --git a/voicestar/utils.py b/voicestar/utils.py new file mode 100644 index 0000000..bb3f8bd --- /dev/null +++ b/voicestar/utils.py @@ -0,0 +1,35 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + +import os +import random +import numpy as np +import torch +import torchaudio + + +def seed_everything(seed=1): + os.environ["PYTHONHASHSEED"] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def estimate_duration(ref_audio_path, text): + """ + Estimate duration based on seconds per character from the reference audio. + """ + info = torchaudio.info(ref_audio_path) + audio_duration = info.num_frames / info.sample_rate + length_text = max(len(text), 1) + spc = audio_duration / length_text # seconds per character + return len(text) * spc diff --git a/voicestar/voicestar.py b/voicestar/voicestar.py new file mode 100644 index 0000000..207d4ff --- /dev/null +++ b/voicestar/voicestar.py @@ -0,0 +1,1043 @@ +""" +VoiceStar: Robust, Duration-Controllable TTS that can Extrapolate + +GitHub: https://github.com/jasonppy/VoiceStar +License: MIT + +Copyright (c) 2025 Puyuan Peng +""" + +import random, os, copy +from typing import Dict, Iterator, List, Tuple, Union +import logging +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchmetrics.classification import MulticlassAccuracy +import torch.distributed as dist + +from voicestar.modules.utils import make_pad_mask, generate_partial_autoregressive_mask + +from voicestar.modules.embedding import ( + SinePositionalEmbedding, + TokenEmbedding, + SinePositionalEmbedding_progress, +) +from voicestar.modules.transformer import ( + AdaptiveLayerNorm, + LayerNorm, + TransformerDecoderLayer, + TransformerDecoder, + TransformerEncoder, + TransformerEncoderLayer, +) + + +def top_k_top_p_filtering( + logits, + top_k=0, + top_p=1.0, + min_p=1.0, + filter_value=-float("Inf"), + min_tokens_to_keep=1, +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if min_p < 1.0: + probs = F.softmax(logits, dim=-1) + indices_to_remove = probs < min_p + if not torch.any(indices_to_remove.sum(-1) == logits.size(-1)): + logits[indices_to_remove] = filter_value + top_k = 0 + top_p = 1.0 + # else will use other types of sampling, or no filtering + + # If top_k is a single integer + if isinstance(top_k, int) and top_k > 0: + # Safety check to ensure we don't ask for more than available + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) + + # Remove all tokens with a probability less than the last token of the top-k + threshold = torch.topk(logits, top_k, dim=-1)[0][..., -1, None] + indices_to_remove = logits < threshold + logits[indices_to_remove] = filter_value + + # If top_k is a list, assume it has the same length as M + elif isinstance(top_k, list): + # Ensure the length matches the first dimension + assert len(top_k) == logits.size( + 0 + ), f"top_k list length ({len(top_k)}) must match logits.size(0) ({logits.size(0)})" + + for i in range(logits.size(0)): + k_i = top_k[i] + if k_i > 0: + # Safety check + k_i = min(max(k_i, min_tokens_to_keep), logits.size(-1)) + row_threshold = torch.topk(logits[i], k_i, dim=-1)[0][-1] + indices_to_remove_i = logits[i] < row_threshold + logits[i, indices_to_remove_i] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + return logits + + +def topk_sampling(logits, top_k=10, top_p=1.0, min_p=1.0, temperature=1.0): + # temperature: (`optional`) float + # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + # top_k: (`optional`) int + # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + # top_p: (`optional`) float + # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. 
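+    # min_p: (`optional`) float
+    #     Tokens whose softmax probability falls below min_p are filtered out (see
+    #     top_k_top_p_filtering above); the default of 1.0 disables this filter.
+    # Illustrative usage (values are arbitrary):
+    #     token = topk_sampling(logits, top_k=40, top_p=0.9, temperature=0.8)
+    #     returns one sampled token id per row of `logits`, shape [B, 1].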
+ + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p, min_p=min_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token + + +class VoiceStarModel(nn.Module): + def __init__(self, args): + super().__init__() + self.args = args + assert ( + self.args.enc_dec ^ self.args.dec + ), f"self.args.enc_dec: {self.args.enc_dec}, self.args.dec: {self.args.dec}" + if not getattr(self.args, "special_first", False): + self.args.special_first = 0 + if not getattr(self.args, "n_special", False): + self.args.n_special = 3 + self.args.eos = getattr(self.args, "eos", -1) + self.eog = nn.Parameter( + torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), + requires_grad=False, + ) # [K 1] + if self.args.eos > 0: + assert ( + self.args.eos != self.args.audio_pad_token + and self.args.eos != self.args.empty_token + ), self.args.eos + self.eos = nn.Parameter( + torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), + requires_grad=False, + ) # [K 1] + if type(self.args.audio_vocab_size) == str: + self.args.audio_vocab_size = eval(self.args.audio_vocab_size) + if type(self.args.audio_vocab_size) == list: # otherwise they are all lists + assert self.args.special_first + + self.n_text_tokens = self.args.text_vocab_size + 1 + assert ( + self.args.text_pad_token == self.args.text_vocab_size + ), f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}" + + if self.args.special_first and type(self.args.audio_vocab_size) == list: + self.n_audio_tokens = [ + tok + self.args.n_special for tok in self.args.audio_vocab_size + ] # special tokens: empty token, EOG token, audio pad token + assert self.args.empty_token == 0, self.args.empty_token + assert self.args.eog == 1, self.args.eog + assert self.args.audio_pad_token == 2, self.args.audio_pad_token + else: + self.n_audio_tokens = [ + self.args.audio_vocab_size + self.args.n_special + ] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token + assert ( + self.args.audio_vocab_size == self.args.empty_token + ), self.args.empty_token + assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog + assert ( + self.args.audio_pad_token == self.args.audio_vocab_size + 2 + ), self.args.audio_pad_token + + self.text_embedding = TokenEmbedding( + dim_model=self.args.d_model, + vocab_size=self.n_text_tokens, + dropout=self.args.text_embedding_dropout, + ) + + self.audio_embedding = nn.ModuleList( + [ + TokenEmbedding( + dim_model=self.args.audio_embedding_dim, + vocab_size=self.n_audio_tokens[k], + dropout=self.args.audio_embedding_dropout, + ) + for k in range(self.args.n_codebooks) + ] + ) + + rope_base = getattr(self.args, "rope_base", None) + use_sinusoidal = getattr(self.args, "use_sinusoidal", False) + use_sinusoidal_progress = getattr(self.args, "use_sinusoidal_progress", False) + logging.info(f"rope_base: {rope_base}, use_sinusoidal: {use_sinusoidal}") + if use_sinusoidal: + self.text_positional_embedding = SinePositionalEmbedding( + self.args.d_model, + dropout=self.args.text_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + ) + self.audio_positional_embedding = SinePositionalEmbedding( + self.args.d_model, + 
dropout=self.args.audio_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + ) + elif use_sinusoidal_progress: + self.text_positional_embedding = SinePositionalEmbedding_progress( + self.args.d_model, + dropout=self.args.text_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + args=self.args, + ) + self.audio_positional_embedding = SinePositionalEmbedding_progress( + self.args.d_model, + dropout=self.args.audio_positional_embedding_dropout, + scale=False, + alpha=True, # learnable scaler, scale the volume of positional embedding + args=self.args, + ) + + else: + + class NoOp: + def __init__(self): + pass + + def __call__(self, *args, **kwargs): + return args[0] + + self.text_positional_embedding = NoOp() + self.audio_positional_embedding = NoOp() + + if self.args.enc_dec: + enc_layer = TransformerEncoderLayer( + d_model=self.args.d_model, + nhead=self.args.nhead, + dim_feedforward=self.args.d_model * 4, + dropout=self.args.trm_dropout, + batch_first=True, + norm_first=True, + layer_norm_cls=LayerNorm, + ) # use the pre-norm arch + + self.encoder = TransformerEncoder( + encoder_layer=enc_layer, + num_layers=self.args.num_encoder_layers, + norm=LayerNorm(self.args.d_model), + rope_base=self.args.rope_base, + d_model=self.args.d_model, + nhead=self.args.nhead, + args=self.args, + ) # use the pre-norm arch + + dec_layer = TransformerDecoderLayer( + d_model=self.args.d_model, + nhead=self.args.nhead, + dim_feedforward=self.args.d_model * 4, + dropout=self.args.trm_dropout, + batch_first=True, + norm_first=True, + layer_norm_cls=LayerNorm, + ) + + self.decoder = TransformerDecoder( + decoder_layer=dec_layer, + num_layers=self.args.num_decoder_layers, + norm=LayerNorm(self.args.d_model), + rope_base=self.args.rope_base, + d_model=self.args.d_model, + nhead=self.args.nhead, + args=self.args, + ) # NOTE: this one I use torch.nn native implementation, as it's not implemented in .modules + + else: + dec_layer = TransformerEncoderLayer( + self.args.d_model, + self.args.nhead, + dim_feedforward=self.args.d_model * 4, + dropout=self.args.trm_dropout, + batch_first=True, + norm_first=True, + layer_norm_cls=LayerNorm, + ) + self.decoder = TransformerEncoder( + dec_layer, + num_layers=self.args.num_decoder_layers, + norm=LayerNorm(self.args.d_model), + ) + + if type(self.args.audio_vocab_size) == int: + self.predict_layer = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(self.args.d_model, self.args.audio_vocab_size // 2), + nn.GELU(), + nn.Linear( + self.args.audio_vocab_size // 2, self.n_audio_tokens[k] + ), + ) + for k in range(self.args.n_codebooks) + ] + ) + else: + self.predict_layer = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(self.args.d_model, self.args.d_model // 2), + nn.GELU(), + nn.Linear(self.args.d_model // 2, self.n_audio_tokens[k]), + ) + for k in range(self.args.n_codebooks) + ] + ) + + self.accuracy_metrics = nn.ModuleList( + [ + MulticlassAccuracy( + self.n_audio_tokens[k], + top_k=10, + average="micro", + multidim_average="global", + ignore_index=None, + ) + for k in range(self.args.n_codebooks) + ] + ) + + if self.args.eog_weight != 1: + raise NotImplementedError( + "now have different vocab_size for different codebooks, therefore currently don't support eog_weight" + ) + self.class_weight = nn.Parameter( + torch.ones(self.n_audio_tokens), requires_grad=False + ) + self.class_weight.data[self.args.eog] = self.args.eog_weight + + def 
dec_forward( + self, + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + need_weights=False, + past=None, + last_3_tokens=False, + ): + x_attn_mask = F.pad( + x_attention_mask, + (0, new_y_lens.max()), + value=True, + ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper + y_attn_mask = F.pad( + y_attention_mask, + (x_lens.max(), 0), # y is padded at the front + value=False, + ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + + # merge key padding and attention masks + bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max() + xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1) + _xy_padding_mask = ( + xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.args.nhead, -1, -1) + .reshape(bsz * self.args.nhead, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + new_attn_mask = torch.zeros_like(xy_attn_mask) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + + xy_input = torch.cat([x_input, y_input], dim=1) + if need_weights: + raise NotImplementedError("not implemented yet") + out, layer_attn_weights = self.decoder( + (xy_input, None), mask=xy_attn_mask, need_weights=True + ) + return layer_attn_weights + + if past == None: # do not use kvcache + out, _ = self.decoder((xy_input, None), mask=xy_attn_mask) + return out[:, x_lens.max() :], None + else: # use kvcache + if ( + past.ndim > 3 + ): # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet + if last_3_tokens: + xy_input = xy_input[:, -3:] + xy_attn_mask = xy_attn_mask[:, -3:] + else: + xy_input = xy_input[:, -1:] + xy_attn_mask = xy_attn_mask[:, -1:] + + out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past) + if isinstance(out, tuple): # get rid of stage_embedding + out = out[0] + + if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet + return out[:, x_lens.max() :], present + else: # used kvcache + return out, present + + def enc_dec_forward( + self, + xa, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + tgt_y_lens=None, + need_weights=False, + past=None, + last_3_tokens=False, + ): + assert not need_weights + if past != None and past.ndim > 3: + y_input = y_input[:, -1:] + y_attention_mask = y_attention_mask[-1:] + yhat, present = self.decoder( + tgt=y_input, + memory=xa, + tgt_mask=y_attention_mask, + tgt_key_padding_mask=y_padding_mask, + memory_key_padding_mask=x_padding_mask, + query_lens=tgt_y_lens, + past=past, + ) + return yhat, present + + def forward(self, batch, calc_loss=False): + """ + Args: + x: + A 2-D tensor of shape (N, S). + x_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (N, K, T). + where K is the number of codebooks + y_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. 
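+        Returns:
+            A dict with keys "loss", "perplexity_by_codebook", "top10acc",
+            "top10acc_by_codebook", and "effective_ntoken" (assembled at the end
+            of this method), or None if the batch is empty.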
+ """ + x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] + if len(x) == 0: + return None + x = x[ + :, : x_lens.max() + ] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x + y = y[..., : y_lens.max()] + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape + assert y_lens.ndim == 1, y_lens.shape + x_padding_mask = make_pad_mask(x_lens).to(x.device) + x_attention_mask = ( + torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1) + .bool() + .to(x_padding_mask.device) + ) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input, x_lens) + y_with_eos = [ + torch.cat([item[:, : y_lens[i]], self.eos], dim=-1) + for i, item in enumerate(y) + ] + targets = y_with_eos + # apply delayed stacking on y + shifted_y = [] + patterns = [] + new_y_lens = [] + if getattr(self, "empty_tokens", None) == None: + self.empty_tokens = torch.full( + (self.args.n_codebooks, self.args.n_codebooks), + self.args.empty_token, + dtype=torch.long, + ).to( + y.device + ) # [K, K] + for i in range(len(y)): + tmp = torch.cat( + [y_with_eos[i], self.empty_tokens], dim=-1 + ) # [K, T+n_codebooks] + for ii in range(self.args.n_codebooks): + tmp[ii] = torch.roll(tmp[ii], shifts=ii + 1, dims=0) + shifted_y.append( + tmp.transpose(1, 0) + ) # [K, T+n_codebooks] -> [T+n_codebooks, K] + new_y_lens.append(y_with_eos[i].shape[1] + self.empty_tokens.shape[1]) + + new_y_lens = torch.LongTensor(new_y_lens).to(y.device) + + cated_y = torch.nn.utils.rnn.pad_sequence( + shifted_y, batch_first=False, padding_value=self.args.audio_pad_token + ) + assert cated_y.shape == torch.Size( + [max(new_y_lens), len(y), self.args.n_codebooks] + ), cated_y.shape + cated_y = cated_y.permute(2, 0, 1) # [T,B,K]->[K,T,B] + stacked_embedded_y = torch.stack( + [self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], + dim=0, + ) # [K, T, B, D] + assert ( + stacked_embedded_y.shape[0] == self.args.n_codebooks + and stacked_embedded_y.shape[2] == len(y) + and stacked_embedded_y.shape[-1] == self.args.d_model + ), stacked_embedded_y.shape + embedded_y = stacked_embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D] + embedded_y = embedded_y.transpose(1, 0) # [T,B,D]->[B,T,D] + assert embedded_y.shape[1:] == torch.Size( + [max(new_y_lens), self.args.d_model] + ), embedded_y.shape + y_input = self.audio_positional_embedding(embedded_y, new_y_lens) + y_padding_mask = make_pad_mask(new_y_lens).to(y.device) + y_attention_mask = ( + torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1) + .bool() + .to(y_padding_mask.device) + ) + if self.args.dec: + y_out = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + ) + else: + xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask) + y_out = self.enc_dec_forward( + xa, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + ) + y_out = y_out[0] # no kv-caching during training + assert ( + y_out.shape == y_input.shape + ), f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D] + logits = torch.stack( + [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1 + ) # [B K S card] + assert ( + logits.shape[1] == self.args.n_codebooks + and logits.shape[3] == self.n_audio_tokens[0] + ), logits.shape + 
logits_use = [ + logit[:, : new_y_lens[i]] for i, logit in enumerate(logits) + ] # each of shape [K, T, card] + logits_final = [] + for i, logit in enumerate(logits_use): + logit_copy = logit.clone() + for ii in range(self.args.n_codebooks): + logit_copy[ii] = torch.roll(logit_copy[ii], shifts=-ii, dims=0) + logit = logit_copy[ + :, : -self.args.n_codebooks + ] # [K, T, card] -> [K, T-n_codebooks, card] + logits_final.append(logit) + if self.args.no_loss_on_prefix: + assert ( + "y_sep_token_position" in batch + ), f"y_sep_token_position should be in batch, but it's not" + logit_temp = [] + target_temp = [] + for jj, (logit, target) in enumerate(zip(logits_final, targets)): + # TODO already taken into consideration in depth transformer + logit_temp.append(logit[:, batch["y_sep_token_position"][jj] :]) + target_temp.append(target[:, batch["y_sep_token_position"][jj] :]) + logits_final = logit_temp + targets = target_temp + logits = torch.cat(logits_final, dim=1) # [K, T1+T2+T3+..., card] + targets = torch.cat(targets, dim=1) # [K, T1+T2+T3+...] + + assert targets.shape[:2] == logits.shape[:2], f"{targets.shape}, {logits.shape}" + loss = [] + ntokens = [] + top10acc = [] + for k, (logit, target) in enumerate( + zip(logits, targets) + ): # even though the loss and top10acc is calculated in a loop (loop through n_codebooks), validation is still taking a lot of mem, need to optimize this a little more + loss.append( + F.cross_entropy( + logit, + target, + reduction="mean", + weight=( + self.class_weight.data if self.args.eog_weight != 1 else None + ), + ignore_index=( + self.args.y_sep_token if self.args.y_sep_token != None else -100 + ), + ) + ) # ignore audio sep token as it's unpredictable (like the random early stop bug happened in 2023) + # NOTE have to ignore the sep token in the loss calculation + top10acc.append(self.accuracy_metrics[k](logit.detach(), target)) + ntokens.append(len(logit)) + + all_ntokens = sum(ntokens) + if self.args.codebook_weight != None: + codebook_weight = ( + eval(self.args.codebook_weight) + if isinstance(self.args.codebook_weight, str) + else self.args.codebook_weight + ) + else: + codebook_weight = [1.0] * self.args.n_codebooks + perplexity_by_codebook = [torch.exp(l) for l in loss] + loss = sum([l * nt * cw for l, nt, cw in zip(loss, ntokens, codebook_weight)]) + + top10acc_by_codebook = [t10a * nt for t10a, nt in zip(top10acc, ntokens)] + top10acc = sum(top10acc_by_codebook) + + ntokens = torch.tensor(all_ntokens).to(logits.device) + + ret = { + "loss": loss, + "perplexity_by_codebook": perplexity_by_codebook, + "top10acc": top10acc, + "top10acc_by_codebook": top10acc_by_codebook, + "effective_ntoken": ntokens, + } + + return ret + + def inference_tts( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + tgt_y_lens: torch.Tensor, # + top_k: Union[int, list[int]] = -100, + top_p: float = 1.0, + min_p: float = 1.0, + temperature: float = 1.0, + stop_repetition: int = 3, + kvcache: int = 1, + silence_tokens: list[int] = [], + multi_trial: list[int] = [], + *kargs, + ) -> torch.Tensor: + """ + This implementation uses kvcache, which should have significant speed up + Args: + x: + A 2-D tensor of shape (1, L). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, K). + tgt_y_lens: + *new arg* this specify the target length of y + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. 
+ top_p: (`optional`) float + For Neucleus sampling + min_p: (`optional`) float + For min_p filtered sampling + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + multi_trial: (`optional`) list[int] + If not empty, it will be [n_trials, beam_size, trial_interval] + from the start and begining trial_interval, we duplicate the current sample by beam_size, + at the end of every trial_interval, we choose the sample with the highest log likelihood to keep and throw away the rest + """ + eog_inference = self.args.eos if self.args.eos > 0 else self.args.eog + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + if self.args.special_first: + y = y + int(self.args.n_special) + y = y.transpose(2, 1) # [1,T,K] -> [1,K,T] + assert ( + y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks + ), y.shape # there is no padding + + # make x attention mask and x_input + x_attention_mask = ( + torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1) + .bool() + .to(x.device) + ) + # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device) + x_input = self.text_embedding(x) + x_input = self.text_positional_embedding(x_input, x_lens) + + y_len = y.shape[2] + y_lens = torch.LongTensor([y_len]).to(y.device) + + # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario + rearranged_y = [[y[0]]] + assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][ + 0 + ].shape + + # # shift y to create the delayed pattern + if getattr(self, "empty_tokens", None) == None: + self.empty_tokens = torch.full( + (self.args.n_codebooks, self.args.n_codebooks), + self.args.empty_token, + dtype=torch.long, + ).to( + y.device + ) # [K, K] + temp = rearranged_y[0][0] + assert temp.ndim == 2 and temp.shape[0] == self.args.n_codebooks, temp.shape + temp = torch.cat([temp, self.empty_tokens], dim=-1) # [K, T+n_codebooks] + for ii in range(self.args.n_codebooks): + temp[ii] = torch.roll(temp[ii], shifts=ii + 1, dims=0) + shifted_y = [[temp]] + + # below is different from forward or inference + # where we cut this shifted part + shifted_y[0][0] = shifted_y[0][0][:, : -(self.args.n_codebooks - 1)] + assert ( + not ( + shifted_y[0][0][self.args.n_codebooks :] == self.args.empty_token + ).any() + and not (shifted_y[0][0][self.args.n_codebooks :] == self.args.eog).any() + ), shifted_y[0][0] + # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that + # next section is concate tensors of each sample to one tensor, which we also don't need + cated_y = shifted_y[0][0].unsqueeze(-1) # [K,S]->[K,S,B] + new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device) + assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1)) + assert not (cated_y == self.args.audio_pad_token).any(), cated_y + + # replace tokens in y with the embeddings, add sum codebooks up + embedded_y = torch.stack( + [self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], + dim=0, + ) # [K, S, B, D] + assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape + assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape + embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D] + embedded_y = embedded_y.transpose(1, 0) # [S,B,D]->[B,S,D] + + # positional embedding + y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) + + # make 
attention mask and padding mask + y_attention_mask = ( + torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1) + .bool() + .to(y.device) + ) + + x_padding_mask = torch.full((1, x_lens[0]), False).to(x.device) + y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device) + + # entering the generation stage + # starting from line 708 + codebook_eog = [False] * self.args.n_codebooks + generated = [] # doesn't contain any empty token, contain eog + cur_generated = [] + # say 0 is empty, 4 is eog + # tensor([[ 1, 2, 3, 4, 0, 0], + # [ 0, 1, 2, 3, 4, 0], + # [ 0, 0, 1, 2, 3, 4]]) + num_gen = [] + cur_num_gen = 0 + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + # silence_tokens = [1388,1898,131] # [1388, 2045, 2041, 1996] + # silence_tokens = [] + consec_silence_count = 0 + prev_token = None + ##################### silence repetition handling ##################### + ##################### silence repetition handling ##################### + + def sample_helper( + n_eog, + logits, + codebook_eog, + top_k, + top_p, + min_p, + temperature, + prev_token, + consec_silence_count, + stop_repetition, + silence_tokens, + cur_num_gen, + ): + if n_eog == 0: + logits_adjust = logits + for jj in range(1, self.args.n_codebooks): + logits_adjust[jj][eog_inference] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + if ( + cur_num_gen <= self.args.encodec_sr // 5 + ): # this shouldn't happen, but just in case the model stopped too early + logits_adjust[0][eog_inference] = -10000 + ##################### silence repetition handling ##################### + if ( + stop_repetition > 0 + and prev_token in silence_tokens + and consec_silence_count > stop_repetition + ): + if logits_adjust[0, prev_token] < 0: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * ( + consec_silence_count - (stop_repetition - 1) + ) + else: + logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / ( + consec_silence_count - (stop_repetition - 1) + ) + ##################### silence repetition handling ##################### + samples = topk_sampling( + logits_adjust, + top_k=top_k, + top_p=top_p, + min_p=min_p, + temperature=temperature, + ) # [K, 1] + assert samples.shape == torch.Size( + (self.args.n_codebooks, 1) + ), f"samples.shape: {samples.shape}" + if cur_num_gen < self.args.n_codebooks - 1: + for jj in range(1, self.args.n_codebooks - cur_num_gen): + samples[-jj, 0] = self.args.empty_token + + if ( + ( + samples[0, 0] == eog_inference + or torch.argmax(logits[0], dim=-1) == eog_inference + or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr // 4) + ) + or self.args.rope_base is not None + and not self.args.decoder_regular_rope + and self.args.progress_no_multiple + and cur_num_gen + > ( + tgt_y_lens[0] + + self.args.encodec_sr * getattr(self.args, "extra_cutoff", 5) + ) + ): + # last one condition in the first bracket means y is already too long, shouldn't happen, but put it here + # the second bracket means we are using progress-monitoring RoPE, but the model is generating excessively long sequence (5 seconds more than specified), in which case we terminate the generation + samples[0, 0] = eog_inference + codebook_eog[0] = True + ##################### silence repetition handling ##################### + if samples[0, 0] in silence_tokens and samples[0, 0] == prev_token: + consec_silence_count += 1 + else: + consec_silence_count = 0 + prev_token = samples[0, 0] + 
##################### silence repetition handling ##################### + return samples, codebook_eog, prev_token, consec_silence_count + else: + assert ( + sum(codebook_eog[i] for i in range(n_eog)) == n_eog + ), f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}" + logits_adjust = logits + for jj in range(n_eog + 1, self.args.n_codebooks): + logits_adjust[jj][eog_inference] = -10000 + logits_adjust[jj][self.args.empty_token] = -10000 + samples = topk_sampling( + logits_adjust, + top_k=top_k, + top_p=top_p, + min_p=min_p, + temperature=temperature, + ) # [K, 1] + for jj in range(n_eog): + samples[jj, 0] = self.args.empty_token + samples[n_eog, 0] = eog_inference + codebook_eog[n_eog] = True + return samples, codebook_eog, prev_token, consec_silence_count + + # prepare the cache placeholder + # n_layers, 2, bsz, num_heads, src_len, head_dim, 2 means [key, value] + past = ( + torch.ones( + [self.args.num_decoder_layers, 2, x.shape[0]], + device=x.device, + dtype=torch.float32, + ) + if kvcache + else None + ) + if self.args.enc_dec: + xa = self.encoder(src=x_input, src_key_padding_mask=x_padding_mask) + while True: + if self.args.dec: + y_out, present = self.dec_forward( + x_input, + x_lens, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + past=past, + ) + else: + y_out, present = self.enc_dec_forward( + xa, + x_attention_mask, + x_padding_mask, + y_input, + new_y_lens, + y_attention_mask, + y_padding_mask, + tgt_y_lens=tgt_y_lens, + past=past, + ) + if past != None: + past = ( + torch.cat([past, present.to(past.dtype)], dim=-2) + if past.ndim > 3 + else present.to(past.dtype) + ) + + y_out = y_out[:, -1:] # only take the last token + + logits = torch.stack( + [self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], + dim=1, + ) # [B K S card], B==S==1, so [1 K 1 card] + logits = logits.squeeze(0).squeeze(1) # [K card] + assert logits.shape == torch.Size( + (self.args.n_codebooks, self.n_audio_tokens[0]) + ), f"{logits.shape}" + + n_eog = sum(codebook_eog) + assert n_eog < self.args.n_codebooks + if ( + self.args.eos > 0 + ): # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans + for jj in range(self.args.n_codebooks): + logits[jj][self.args.eog] = -10000.0 + + samples, codebook_eog, prev_token, consec_silence_count = sample_helper( + n_eog, + logits, + codebook_eog, + top_k, + top_p, + min_p, + temperature, + prev_token, + consec_silence_count, + stop_repetition, + silence_tokens, + cur_num_gen, + ) + # samples.shape is [K,1] + # ge samples_emb + samples_emb = torch.stack( + [ + self.audio_embedding[k](samples[k]) + for k in range(self.args.n_codebooks) + ], + dim=0, + ) # [K,1,D] + samples_emb = samples_emb.sum(dim=0, keepdim=True) # [1,1,D] + + cur_num_gen += 1 + cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K] + + if ( + sum(codebook_eog) == self.args.n_codebooks + ): # generation for the current span is done + codebook_eog = [False] * self.args.n_codebooks + num_gen.append(cur_num_gen) + cur_num_gen = 0 + generated.append(cur_generated) + cur_generated = [] + break + else: + assert samples_emb.shape == torch.Size( + (1, 1, self.args.d_model) + ), f"samples_emb.shape: {samples_emb.shape}" + + embedded_y = torch.cat([embedded_y, samples_emb], dim=1) + new_y_lens = torch.LongTensor([embedded_y.shape[1]]).to(y.device) + y_input = self.audio_positional_embedding(embedded_y, tgt_y_lens) # [B T D] + # make attention mask and padding mask + 
+            y_attention_mask = (
+                torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1)
+                .bool()
+                .to(y.device)
+            )
+            y_padding_mask = torch.full((1, new_y_lens[0]), False).to(y.device)
+
+        assert len(generated) == 1, f"len(generated): {len(generated)}"
+
+        # revert the delay pattern
+        flatten_gen = []
+        for l, orig_span in enumerate(generated):
+            span = torch.stack(orig_span, dim=0)  # [T, K]
+            span = span.transpose(1, 0)  # [K, T]
+            assert span.shape[0] == self.args.n_codebooks, span.shape
+            unshifted_span = []
+            for j, s in enumerate(span):
+                start_from = j
+                end_at = -(self.args.n_codebooks - start_from)
+                unshifted_span.append(s[start_from:end_at])
+            unshifted_span = torch.stack(unshifted_span, dim=0)
+
+            assert (
+                unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks
+            ), f"unshifted_span.shape[1]: {unshifted_span.shape[1]}, num_gen[l]: {num_gen[l]}"
+
+            flatten_gen.append(unshifted_span)
+        assert len(flatten_gen) == 1, len(flatten_gen)
+
+        # combine the prompt y with the newly generated codes
+        res = [y[0], flatten_gen[0]]
+        res = torch.cat(res, dim=1).unsqueeze(0)  # [K, new_T] -> [1, K, new_T]
+        expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
+        assert res.shape == torch.Size(
+            (1, self.args.n_codebooks, expected_y_len)
+        ), f"res.shape: {res.shape}, expected_y_len: {expected_y_len} (y_len {y_len} + generated {sum([item - self.args.n_codebooks for item in num_gen])})"
+
+        if self.args.special_first:
+            res = res - int(self.args.n_special)
+            # flatten_gen is a list of [K, T] tensors, so shift each span individually
+            flatten_gen = [span - int(self.args.n_special) for span in flatten_gen]
+        return res, flatten_gen[0].unsqueeze(0)
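
To make the "revert the delay pattern" loop above easier to follow, here is a hedged, self-contained sketch of the inverse operation: codebook `j` was shifted right by `j + 1` during generation, so dropping `j` leading and `n_codebooks - j` trailing positions realigns the rows and removes the empty/EOG padding. `undo_delay_pattern` is an illustrative name, not an API of this repo.

```python
import torch

def undo_delay_pattern(shifted: torch.Tensor, n_codebooks: int) -> torch.Tensor:
    """Realign delayed codebook rows, mirroring the unshift loop above.

    shifted: [K, S] tokens as generated step by step (still containing the
             empty/EOG padding introduced by the delay pattern).
    returns: [K, S - K] realigned codec tokens.
    """
    rows = []
    for j in range(n_codebooks):
        # codebook j lags by j steps relative to codebook 0, so trim j leading
        # and (n_codebooks - j) trailing positions
        rows.append(shifted[j, j: -(n_codebooks - j)])
    return torch.stack(rows, dim=0)

# Toy span matching the comment earlier in inference_tts (0 = empty, 4 = EOG):
span = torch.tensor([[1, 2, 3, 4, 0, 0],
                     [0, 1, 2, 3, 4, 0],
                     [0, 0, 1, 2, 3, 4]])
print(undo_delay_pattern(span, n_codebooks=3))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
```
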