Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions styletts2/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,28 @@ def train(

tb_logger = TensorBoardLogger(save_dir=log_dir, name="tensorboard", version="")

checkpoint_cb = ModelCheckpoint(
ckpt_filename = f"epoch_{mode.value[0]}_" + "{epoch:05d}"
# Always keep the last checkpoint regardless of performance.
last_ckpt_callback = ModelCheckpoint(
dirpath=log_dir,
filename=f"epoch_{mode.value[0]}_" + "{epoch:05d}",
every_n_epochs=config.get("save_freq", 2),
save_top_k=-1,
filename=ckpt_filename,
save_top_k=1,
save_last=True,
every_n_train_steps=tr.ckpt_steps,
every_n_epochs=tr.ckpt_epochs,
enable_version_counter=True,
save_on_train_epoch_end=True,
)
# Keep only the top-k checkpoints ranked by val/mel (lower is better).
monitored_ckpt_callback = ModelCheckpoint(
dirpath=log_dir,
filename=ckpt_filename,
monitor="val/mel",
mode="min",
save_top_k=tr.save_top_k_ckpts,
every_n_train_steps=tr.ckpt_steps,
every_n_epochs=tr.ckpt_epochs,
enable_version_counter=False,
)
lr_monitor = LearningRateMonitor(logging_interval="step")

Expand Down Expand Up @@ -119,7 +133,7 @@ def train(
strategy=resolved_strategy,
precision=precision,
logger=tb_logger,
callbacks=[checkpoint_cb, lr_monitor],
callbacks=[monitored_ckpt_callback, last_ckpt_callback, lr_monitor],
log_every_n_steps=config.get("log_interval", 10),
enable_progress_bar=True,
)
Expand Down
92 changes: 86 additions & 6 deletions styletts2/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule
from .modules.slmadv import SLMAdversarialLoss
from .pretrained.plbert.util import load_plbert
from .text_utils import symbols as _text_symbols
from .utils import (
get_data_path_list,
get_image,
Expand Down Expand Up @@ -393,8 +394,13 @@ def on_train_epoch_end(self):
)
self._running_std.clear()

_MAX_VAL_AUDIO = 7

def on_validation_epoch_start(self):
self._val_batch = None

def on_validation_epoch_end(self):
if not self.trainer.is_global_zero or not hasattr(self, "_val_batch"):
if not self.trainer.is_global_zero or self._val_batch is None:
return

tb = self.logger.experiment
Expand All @@ -406,7 +412,7 @@ def on_validation_epoch_end(self):
tb.add_figure(
"eval/attn", get_image(b["s2s_attn"][0].numpy().squeeze()), epoch
)
for bib in range(min(len(b["asr"]), 7)):
for bib in range(min(len(b["asr"]), self._MAX_VAL_AUDIO)):
mel_length = int(b["mel_input_length"][bib].item())
gt = b["mels"][bib, :, :mel_length].unsqueeze(0).to(self.device)
en = (
Expand All @@ -425,6 +431,11 @@ def on_validation_epoch_end(self):
sample_rate=self.sr,
)
if epoch == 0:
text_label = "".join(
_text_symbols[i]
for i in b["texts"][bib, : b["input_lengths"][bib]].tolist()
)
tb.add_text(f"text/y{bib}", text_label, epoch)
tb.add_audio(
f"gt/y{bib}",
b["waves"][bib].squeeze(),
Expand All @@ -434,7 +445,7 @@ def on_validation_epoch_end(self):

else:
if epoch < self.joint_epoch and self.mode != "finetune":
for bib in range(min(len(b["asr"]), 6)):
for bib in range(min(len(b["asr"]), self._MAX_VAL_AUDIO)):
mel_length = int(b["mel_input_length"][bib].item())
gt = b["mels"][bib, :, :mel_length].unsqueeze(0).to(self.device)
en = (
Expand All @@ -447,6 +458,7 @@ def on_validation_epoch_end(self):
s = self.style_encoder(gt.unsqueeze(1))
real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
y_rec = self.decoder(en, F0_real, real_norm, s)

tb.add_audio(
f"eval/y{bib}",
y_rec.cpu().numpy().squeeze(),
Expand All @@ -470,6 +482,13 @@ def on_validation_epoch_end(self):
)

if epoch == 0:
text_label = "".join(
_text_symbols[i]
for i in b["texts"][
bib, : b["input_lengths"][bib]
].tolist()
)
tb.add_text(f"text/y{bib}", text_label, epoch)
tb.add_audio(
f"gt/y{bib}",
b["waves"][bib].squeeze(),
Expand All @@ -493,7 +512,7 @@ def on_validation_epoch_end(self):

t_en = self.text_encoder(texts, input_lengths, text_mask)

for bib in range(min(d_en.size(0), 6)):
for bib in range(min(d_en.size(0), self._MAX_VAL_AUDIO)):
noise = torch.randn((1, 256), device=self.device).unsqueeze(1)
sampler_kwargs = dict(
noise=noise,
Expand Down Expand Up @@ -546,6 +565,20 @@ def on_validation_epoch_end(self):
epoch,
sample_rate=self.sr,
)
if epoch == 0:
text_label = "".join(
_text_symbols[i]
for i in b["texts"][
bib, : b["input_lengths"][bib]
].tolist()
)
tb.add_text(f"text/y{bib}", text_label, epoch)
tb.add_audio(
f"gt/y{bib}",
b["waves"][bib].squeeze(),
epoch,
sample_rate=self.sr,
)

# ------------------------------------------------------------------
# Optimizer helpers
Expand Down Expand Up @@ -1252,13 +1285,34 @@ def _validate_first(self, batch, batch_idx):
)

if self.trainer.is_global_zero:
self._val_batch = dict(
new = dict(
s2s_attn=s2s_attn.detach().cpu(),
asr=asr.detach().cpu(),
mels=mels.detach().cpu(),
mel_input_length=mel_input_length.detach().cpu(),
waves=waves,
texts=texts.detach().cpu(),
input_lengths=input_lengths.detach().cpu(),
)
if self._val_batch is None:
self._val_batch = new
elif len(self._val_batch["asr"]) < self._MAX_VAL_AUDIO:
n_have = len(self._val_batch["asr"])
n_take = min(len(new["asr"]), self._MAX_VAL_AUDIO - n_have)
for key in (
"s2s_attn",
"asr",
"mels",
"mel_input_length",
"texts",
"input_lengths",
):
self._val_batch[key] = torch.cat(
[self._val_batch[key], new[key][:n_take]], dim=0
)
self._val_batch["waves"] = (
self._val_batch["waves"] + new["waves"][:n_take]
)

def _validate_second(self, batch, batch_idx):
device = self.device
Expand Down Expand Up @@ -1331,7 +1385,7 @@ def _validate_second(self, batch, batch_idx):
)

if self.trainer.is_global_zero:
self._val_batch = dict(
new = dict(
asr=asr.detach().cpu(),
mels=mels.detach().cpu(),
mel_input_length=mel_input_length.detach().cpu(),
Expand All @@ -1344,3 +1398,29 @@ def _validate_second(self, batch, batch_idx):
text_mask=text_mask.detach().cpu(),
ref_mels=ref_mels.detach().cpu() if self.multispeaker else None,
)
if self._val_batch is None:
self._val_batch = new
elif len(self._val_batch["asr"]) < self._MAX_VAL_AUDIO:
n_have = len(self._val_batch["asr"])
n_take = min(len(new["asr"]), self._MAX_VAL_AUDIO - n_have)
for key in (
"asr",
"mels",
"mel_input_length",
"p",
"d_en",
"bert_dur",
"texts",
"input_lengths",
"text_mask",
):
self._val_batch[key] = torch.cat(
[self._val_batch[key], new[key][:n_take]], dim=0
)
self._val_batch["waves"] = (
self._val_batch["waves"] + new["waves"][:n_take]
)
if self.multispeaker and self._val_batch["ref_mels"] is not None:
self._val_batch["ref_mels"] = torch.cat(
[self._val_batch["ref_mels"], new["ref_mels"][:n_take]], dim=0
)