forked from yl4579/StyleTTS2
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
95 lines (79 loc) · 3.1 KB
/
train.py
File metadata and controls
95 lines (79 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import shutil
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional
import lightning as L
import typer
import yaml
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
from styletts2.lightning import StyleTTS2DataModule, StyleTTS2Module
class Mode(str, Enum):
first = "first"
second = "second"
finetune = "finetune"
def main(
config_path: Annotated[
Path, typer.Option("-p", "--config", help="Path to YAML config file.")
] = Path("configs/base.yml"),
mode: Annotated[
Mode, typer.Option("-m", "--mode", help="Training mode.")
] = Mode.first,
resume: Annotated[
Optional[Path],
typer.Option(help="Path to a Lightning checkpoint to resume from."),
] = None,
devices: Annotated[int, typer.Option(help="Number of GPUs to use.")] = 1,
strategy: Annotated[
str, typer.Option(help="Lightning strategy (auto, ddp, deepspeed, etc.).")
] = "auto",
precision: Annotated[
str, typer.Option(help="Floating-point precision (32, 16-mixed, bf16-mixed).")
] = "32",
):
config = yaml.safe_load(open(config_path))
log_dir = config["log_dir"]
os.makedirs(log_dir, exist_ok=True)
shutil.copy(config_path, os.path.join(log_dir, os.path.basename(config_path)))
epochs_key = (
"epochs_1st"
if mode == Mode.first
else ("epochs_2nd" if mode == Mode.second else "epochs")
)
max_epochs = config.get(epochs_key, 200)
datamodule = StyleTTS2DataModule(config, load_for_everyvoice=False)
model = StyleTTS2Module(config, mode=mode.value)
tb_logger = TensorBoardLogger(save_dir=log_dir, name="tensorboard", version="")
checkpoint_cb = ModelCheckpoint(
dirpath=log_dir,
filename=f'epoch_{mode.value[0]}_{"{epoch:05d}"}',
every_n_epochs=config.get("save_freq", 2),
save_top_k=-1,
save_last=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
# GAN training uses separate discriminator and generator backward passes per
# step, so different subsets of trainable parameters participate in each
# backward. find_unused_parameters=True is required for DDP correctness here.
# The freezing done in setup() already removes the truly-never-trained networks
# from DDP's parameter list, keeping the graph traversal cheap.
strategy_str = strategy
if strategy_str == "ddp" or (devices > 1 and strategy_str == "auto"):
resolved_strategy = DDPStrategy(find_unused_parameters=True)
else:
resolved_strategy = strategy_str
trainer = L.Trainer(
max_epochs=max_epochs,
devices=devices,
strategy=resolved_strategy,
precision=precision,
logger=tb_logger,
callbacks=[checkpoint_cb, lr_monitor],
log_every_n_steps=config.get("log_interval", 10),
enable_progress_bar=True,
)
trainer.fit(model, datamodule=datamodule, ckpt_path=resume)
if __name__ == "__main__":
typer.run(main)