Skip to content

sarulab-speech/matcha-tts

Repository files navigation

Matcha-TTS For Japanese

Setup

  1. uvをインストールしてください
  2. 仮想環境を作成し、ライブラリをインストール
uv sync
  1. wandbにアカウント作成
  2. CLIでwandbにログイン
uv run wandb login

Usage

Download Dataset

uv run scripts/write_shar.py
  • JSUTコーパスをダウンロードし、lhotse shar形式で保存
  • config/data/jsut.yamldataset.shar_dirを適宜変更してください
解説

hydraを用いてパラメータ管理しています。

scripts/write_shar.py

@hydra.main(config_path="../config", config_name="default", version_base=None) # config/default.yaml が読み込まれる
def main(cfg: DictConfig) -> None:
    corpus = hydra.utils.instantiate(cfg.data.dataset) # cfg.data.dataset の内容をもとにインスタンス化
    corpus.write_shar()

実際の処理内容は以下の2ファイルを見ればわかります。

  • src/matcha/data/corpus/base.py
  • src/matcha/data/corpus/jsut.py
  • tests/data/corpus/test_jsut.py

音声をlhotse.cutオブジェクトとして保存します。実際のオブジェクトはこんな感じで、音声とメタデータがセットになっています。

MonoCut(id='BASIC5000_0001', start=0, duration=3.19, channel=0, supervisions=[SupervisionSegment(id='transcript_BASIC5000_0001', recording_id='recording_BASIC5000_0001', start=0, duration=3.19, channel=0, text='水をマレーシアから買わなくてはならないのです。', language='ja', speaker='JSUT', gender='female', custom=None, alignment=None)], features=None, recording=Recording(id='recording_BASIC5000_0001', sources=[AudioSource(type='memory', channels=[0], source='<binary-data>')], sampling_rate=48000, num_samples=153120, duration=3.19, channel_ids=[0], transforms=None), custom=None)

以下のようなスクリプトで確認できます。

from pathlib import Path

from lhotse import CutSet

if __name__ == "__main__":
    shar_dir = Path("shar/jsut")
    cut_paths = sorted(map(str, shar_dir.glob("cuts.*.jsonl.gz")))
    recording_paths = sorted(map(str, shar_dir.glob("recording.*.tar")))
    cuts = CutSet.from_shar({"cuts": cut_paths, "recording": recording_paths})

    for cut in cuts.data:
        print(cut)
        audio = cut.load_audio()
        print(audio)
        break

Preprocess

uv run scripts/preprocess.py
  • JSUTコーパスからメルスペクトログラムなどの特徴量を抽出し、webdataset形式で保存します
  • config/data/jsut.yamlpreprocess.webdataset_dirを適宜変更してください
  • GPUを使用したほうが良いです
解説

src/matcha/data/preprocessor.py を見ればすべての処理が書いてあります。

shar形式て保存したコーパスを読み込み、それを前処理してwebdataset形式で保存しています。

def process_cut(self, cut: Cut) -> dict[str, Any]:
    assert cut.supervisions[0].text is not None
    text, cleaned_text = self.get_text(
        cut.supervisions[0].text, add_blank=self.cfg.add_blank
    )
    mel = self.get_mel(cut)

    return {
        "__key__": uuid.uuid1().hex,
        "x.pth": wds.torch_dumps(text), # トークン列にしたテキスト
        "y.pth": wds.torch_dumps(mel), # メルスペクトログラム
        "spk.txt": cut.supervisions[0].speaker, # 話者
        "x_text.txt": "".join(cleaned_text), # テキスト
    }

tests/data/test_preprocessor.py を見ると処理が掴みやすいかもしれません。

Training

uv run scripts/train.py
  • もし複数GPUを用いて学習する場合は、config/data/jsut.yamldatamodule.use_ddpTrueにしてください
解説

PyTorch Lightningを用いた実装になっています。

LightningDataModule

src/matcha/data/datamodule.pyに実装されています。

Preprocessで作ったwebdatasetを読みこみ、collate_fnによってバッチ化しています。

LightningModule

src/matcha/model/lightning_module.pyに実装されています。

training_stepvalidation_stepが実行されます。

Infer

uv run scripts/infer.py --ckpt_path <CKPT_PATH> --text "<INPUT TEXT>"

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages