import argparse
import os
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy

import pytorch_lightning as pl
import torch  # needed for torch.autograd.set_detect_anomaly below
from omegaconf import OmegaConf

# utils is expected to provide generate_spk_meta and find_gpus (used below)
from utils import *

warnings.filterwarnings("ignore", category=FutureWarning)
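
# The OmegaConf file passed via --config is expected to define at least the
# fields this script reads (inferred from their uses below, not a complete
# schema): seed, dirs.spk_meta, pl_system, progbar_refresh,
# val_interval_epoch, save_top_k, ngpus, fast_dev_run, gradient_clip, epoch.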


def main(args):
    # load configurations and set seed
    config = OmegaConf.load(args.config)
    output_dir = Path(args.output_dir)
    pl.seed_everything(config.seed, workers=True)
    # torch.set_grad_enabled(True)
    torch.autograd.set_detect_anomaly(True)

    # generate speaker-utterance meta information
    if not (
        os.path.exists(config.dirs.spk_meta + "spk_meta_trn.pk")
        and os.path.exists(config.dirs.spk_meta + "spk_meta_dev.pk")
        and os.path.exists(config.dirs.spk_meta + "spk_meta_eval.pk")
    ):
        generate_spk_meta(config)

    # configure paths
    model_tag = os.path.splitext(os.path.basename(args.config))[0]
    model_tag = output_dir / model_tag
    model_save_path = model_tag / "weights"
    model_save_path.mkdir(parents=True, exist_ok=True)
    copy(args.config, model_tag / "config.conf")

    _system = import_module("systems.{}".format(config.pl_system))
    _system = getattr(_system, "System")
    system = _system(config)
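
    # `systems/<config.pl_system>.py` is expected to expose a class named
    # `System`. A minimal sketch of the assumed interface (illustrative only,
    # not the actual baseline implementation; since trainer.fit() below is
    # called without a datamodule, the System also supplies its dataloaders):
    #
    #     class System(pl.LightningModule):
    #         def __init__(self, config):
    #             super().__init__()
    #             self.config = config
    #
    #         def training_step(self, batch, batch_idx): ...
    #         def validation_step(self, batch, batch_idx): ...
    #         def configure_optimizers(self): ...
    #         def train_dataloader(self): ...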

    # configure logging and callbacks
    logger = [
        pl.loggers.TensorBoardLogger(save_dir=model_tag, version=1, name="tsbd_logs"),
        pl.loggers.csv_logs.CSVLogger(
            save_dir=model_tag,
            version=1,
            name="csv_logs",
            flush_logs_every_n_steps=config.progbar_refresh * 100,
        ),
    ]
    callbacks = [
        pl.callbacks.ModelSummary(max_depth=3),
        pl.callbacks.LearningRateMonitor(logging_interval="step"),
        pl.callbacks.ModelCheckpoint(
            dirpath=model_save_path,
            filename="{epoch}-{sasv_eer_dev:.5f}",
            monitor="sasv_eer_dev",
            mode="min",
            every_n_epochs=config.val_interval_epoch,
            save_top_k=config.save_top_k,
        ),
    ]
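
    # NOTE: checkpoint selection assumes the System logs a `sasv_eer_dev`
    # metric during validation; both `monitor` and the filename template
    # above reference it.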

    # train / evaluate
    gpus = find_gpus(config.ngpus)
    if gpus == -1:
        raise ValueError("Required GPUs are not available")
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
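
    # find_gpus (assumed to come from utils) should return a comma-separated
    # device index string such as "0,1" suitable for CUDA_VISIBLE_DEVICES,
    # or -1 when config.ngpus free GPUs cannot be found (hence the check above).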

    trainer = pl.Trainer(
        accelerator="gpu",
        callbacks=callbacks,
        check_val_every_n_epoch=1,
        devices=config.ngpus,
        fast_dev_run=config.fast_dev_run,
        gradient_clip_val=config.gradient_clip
        if config.gradient_clip is not None
        else 0,
        limit_train_batches=1.0,
        limit_val_batches=1.0,
        logger=logger,
        max_epochs=config.epoch,
        num_sanity_val_steps=0,
        progress_bar_refresh_rate=config.progbar_refresh,  # 0 to disable
        reload_dataloaders_every_n_epochs=config.epoch,
        strategy="ddp",
        sync_batchnorm=True,
        val_check_interval=1.0,  # 0.25 validates 4 times every epoch
        # replace_sampler_ddp=False,
    )
    trainer.fit(system)
    trainer.test(ckpt_path="best")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SASVC2022 Baseline framework.")
    parser.add_argument(
        "--config",
        dest="config",
        type=str,
        help="configuration file",
        required=True,  # required, so a default would be ignored by argparse
    )
    parser.add_argument(
        "--output_dir",
        dest="output_dir",
        type=str,
        help="output directory for results",
        default="./exp_result",
    )
    main(parser.parse_args())
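
# Example invocation (config path taken from this script's original default):
#   python main.py --config configs/Baseline2.conf --output_dir ./exp_result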