- uvをインストールしてください
- 仮想環境を作成し、ライブラリをインストール
uv sync- wandbにアカウント作成
- CLIでwandbにログイン
uv run wandb loginuv run scripts/write_shar.py- JSUTコーパスをダウンロードし、lhotse shar形式で保存
config/data/jsut.yamlのdataset.shar_dirを適宜変更してください
解説
hydraを用いてパラメータ管理しています。
scripts/write_shar.py
@hydra.main(config_path="../config", config_name="default", version_base=None) # config/default.yaml が読み込まれる
def main(cfg: DictConfig) -> None:
corpus = hydra.utils.instantiate(cfg.data.dataset) # cfg.data.dataset の内容をもとにインスタンス化
corpus.write_shar()実際の処理内容は以下の2ファイルを見ればわかります。
src/matcha/data/corpus/base.pysrc/matcha/data/corpus/jsut.py- (
tests/data/corpus/test_jsut.py)
音声をlhotse.cutオブジェクトとして保存します。実際のオブジェクトはこんな感じで、音声とメタデータがセットになっています。
MonoCut(id='BASIC5000_0001', start=0, duration=3.19, channel=0, supervisions=[SupervisionSegment(id='transcript_BASIC5000_0001', recording_id='recording_BASIC5000_0001', start=0, duration=3.19, channel=0, text='水をマレーシアから買わなくてはならないのです。', language='ja', speaker='JSUT', gender='female', custom=None, alignment=None)], features=None, recording=Recording(id='recording_BASIC5000_0001', sources=[AudioSource(type='memory', channels=[0], source='<binary-data>')], sampling_rate=48000, num_samples=153120, duration=3.19, channel_ids=[0], transforms=None), custom=None)以下のようなスクリプトで確認できます。
from pathlib import Path
from lhotse import CutSet
if __name__ == "__main__":
shar_dir = Path("shar/jsut")
cut_paths = sorted(map(str, shar_dir.glob("cuts.*.jsonl.gz")))
recording_paths = sorted(map(str, shar_dir.glob("recording.*.tar")))
cuts = CutSet.from_shar({"cuts": cut_paths, "recording": recording_paths})
for cut in cuts.data:
print(cut)
audio = cut.load_audio()
print(audio)
breakuv run scripts/preprocess.py- JSUTコーパスからメルスペクトログラムなどの特徴量を抽出し、webdataset形式で保存します
config/data/jsut.yamlのpreprocess.webdataset_dirを適宜変更してください- GPUを使用したほうが良いです
解説
src/matcha/data/preprocessor.py を見ればすべての処理が書いてあります。
shar形式て保存したコーパスを読み込み、それを前処理してwebdataset形式で保存しています。
def process_cut(self, cut: Cut) -> dict[str, Any]:
assert cut.supervisions[0].text is not None
text, cleaned_text = self.get_text(
cut.supervisions[0].text, add_blank=self.cfg.add_blank
)
mel = self.get_mel(cut)
return {
"__key__": uuid.uuid1().hex,
"x.pth": wds.torch_dumps(text), # トークン列にしたテキスト
"y.pth": wds.torch_dumps(mel), # メルスペクトログラム
"spk.txt": cut.supervisions[0].speaker, # 話者
"x_text.txt": "".join(cleaned_text), # テキスト
}tests/data/test_preprocessor.py を見ると処理が掴みやすいかもしれません。
uv run scripts/train.py- もし複数GPUを用いて学習する場合は、
config/data/jsut.yamlのdatamodule.use_ddpをTrueにしてください
解説
PyTorch Lightningを用いた実装になっています。
src/matcha/data/datamodule.pyに実装されています。
Preprocessで作ったwebdatasetを読みこみ、collate_fnによってバッチ化しています。
src/matcha/model/lightning_module.pyに実装されています。
training_stepとvalidation_stepが実行されます。
uv run scripts/infer.py --ckpt_path <CKPT_PATH> --text "<INPUT TEXT>"