From efa894c470f55530b56f653218be6aeb768d0d97 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Fri, 15 May 2026 16:44:49 -0700 Subject: [PATCH 1/2] fix: save checkpoints in the same way as fs2 save the last checkpoint, and also the top-k with the lowest Mel validation loss --- styletts2/cli/train.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/styletts2/cli/train.py b/styletts2/cli/train.py index a5dc90a7..d48c4933 100644 --- a/styletts2/cli/train.py +++ b/styletts2/cli/train.py @@ -84,14 +84,28 @@ def train( tb_logger = TensorBoardLogger(save_dir=log_dir, name="tensorboard", version="") - checkpoint_cb = ModelCheckpoint( + ckpt_filename = f"epoch_{mode.value[0]}_" + "{epoch:05d}" + # Always keep the last checkpoint regardless of performance. + last_ckpt_callback = ModelCheckpoint( dirpath=log_dir, - filename=f"epoch_{mode.value[0]}_" + "{epoch:05d}", - every_n_epochs=config.get("save_freq", 2), - save_top_k=-1, + filename=ckpt_filename, + save_top_k=1, save_last=True, + every_n_train_steps=tr.ckpt_steps, + every_n_epochs=tr.ckpt_epochs, + enable_version_counter=True, + save_on_train_epoch_end=True, + ) + # Keep only the top-k checkpoints ranked by val/mel (lower is better). 
+ monitored_ckpt_callback = ModelCheckpoint( + dirpath=log_dir, + filename=ckpt_filename, monitor="val/mel", mode="min", + save_top_k=tr.save_top_k_ckpts, + every_n_train_steps=tr.ckpt_steps, + every_n_epochs=tr.ckpt_epochs, + enable_version_counter=False, ) lr_monitor = LearningRateMonitor(logging_interval="step") @@ -119,7 +133,7 @@ def train( strategy=resolved_strategy, precision=precision, logger=tb_logger, - callbacks=[checkpoint_cb, lr_monitor], + callbacks=[monitored_ckpt_callback, last_ckpt_callback, lr_monitor], log_every_n_steps=config.get("log_interval", 10), enable_progress_bar=True, ) From b6bf3d602173289b723516af1b9bdaea7a67fc18 Mon Sep 17 00:00:00 2001 From: Aidan Pine Date: Fri, 15 May 2026 17:10:51 -0700 Subject: [PATCH 2/2] feat: add text logging to tensorboard also log up to 7 audio files to tensorboard --- styletts2/lightning.py | 92 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 6 deletions(-) diff --git a/styletts2/lightning.py b/styletts2/lightning.py index 600a4c17..9fab698d 100644 --- a/styletts2/lightning.py +++ b/styletts2/lightning.py @@ -17,6 +17,7 @@ from .modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule from .modules.slmadv import SLMAdversarialLoss from .pretrained.plbert.util import load_plbert +from .text_utils import symbols as _text_symbols from .utils import ( get_data_path_list, get_image, @@ -393,8 +394,13 @@ def on_train_epoch_end(self): ) self._running_std.clear() + _MAX_VAL_AUDIO = 7 + + def on_validation_epoch_start(self): + self._val_batch = None + def on_validation_epoch_end(self): - if not self.trainer.is_global_zero or not hasattr(self, "_val_batch"): + if not self.trainer.is_global_zero or self._val_batch is None: return tb = self.logger.experiment @@ -406,7 +412,7 @@ def on_validation_epoch_end(self): tb.add_figure( "eval/attn", get_image(b["s2s_attn"][0].numpy().squeeze()), epoch ) - for bib in range(min(len(b["asr"]), 7)): + for bib in 
range(min(len(b["asr"]), self._MAX_VAL_AUDIO)): mel_length = int(b["mel_input_length"][bib].item()) gt = b["mels"][bib, :, :mel_length].unsqueeze(0).to(self.device) en = ( @@ -425,6 +431,11 @@ def on_validation_epoch_end(self): sample_rate=self.sr, ) if epoch == 0: + text_label = "".join( + _text_symbols[i] + for i in b["texts"][bib, : b["input_lengths"][bib]].tolist() + ) + tb.add_text(f"text/y{bib}", text_label, epoch) tb.add_audio( f"gt/y{bib}", b["waves"][bib].squeeze(), @@ -434,7 +445,7 @@ def on_validation_epoch_end(self): else: if epoch < self.joint_epoch and self.mode != "finetune": - for bib in range(min(len(b["asr"]), 6)): + for bib in range(min(len(b["asr"]), self._MAX_VAL_AUDIO)): mel_length = int(b["mel_input_length"][bib].item()) gt = b["mels"][bib, :, :mel_length].unsqueeze(0).to(self.device) en = ( @@ -447,6 +458,7 @@ def on_validation_epoch_end(self): s = self.style_encoder(gt.unsqueeze(1)) real_norm = log_norm(gt.unsqueeze(1)).squeeze(1) y_rec = self.decoder(en, F0_real, real_norm, s) + tb.add_audio( f"eval/y{bib}", y_rec.cpu().numpy().squeeze(), @@ -470,6 +482,13 @@ def on_validation_epoch_end(self): ) if epoch == 0: + text_label = "".join( + _text_symbols[i] + for i in b["texts"][ + bib, : b["input_lengths"][bib] + ].tolist() + ) + tb.add_text(f"text/y{bib}", text_label, epoch) tb.add_audio( f"gt/y{bib}", b["waves"][bib].squeeze(), @@ -493,7 +512,7 @@ def on_validation_epoch_end(self): t_en = self.text_encoder(texts, input_lengths, text_mask) - for bib in range(min(d_en.size(0), 6)): + for bib in range(min(d_en.size(0), self._MAX_VAL_AUDIO)): noise = torch.randn((1, 256), device=self.device).unsqueeze(1) sampler_kwargs = dict( noise=noise, @@ -546,6 +565,20 @@ def on_validation_epoch_end(self): epoch, sample_rate=self.sr, ) + if epoch == 0: + text_label = "".join( + _text_symbols[i] + for i in b["texts"][ + bib, : b["input_lengths"][bib] + ].tolist() + ) + tb.add_text(f"text/y{bib}", text_label, epoch) + tb.add_audio( + f"gt/y{bib}", + 
b["waves"][bib].squeeze(), + epoch, + sample_rate=self.sr, + ) # ------------------------------------------------------------------ # Optimizer helpers @@ -1252,13 +1285,34 @@ def _validate_first(self, batch, batch_idx): ) if self.trainer.is_global_zero: - self._val_batch = dict( + new = dict( s2s_attn=s2s_attn.detach().cpu(), asr=asr.detach().cpu(), mels=mels.detach().cpu(), mel_input_length=mel_input_length.detach().cpu(), waves=waves, + texts=texts.detach().cpu(), + input_lengths=input_lengths.detach().cpu(), ) + if self._val_batch is None: + self._val_batch = new + elif len(self._val_batch["asr"]) < self._MAX_VAL_AUDIO: + n_have = len(self._val_batch["asr"]) + n_take = min(len(new["asr"]), self._MAX_VAL_AUDIO - n_have) + for key in ( + "s2s_attn", + "asr", + "mels", + "mel_input_length", + "texts", + "input_lengths", + ): + self._val_batch[key] = torch.cat( + [self._val_batch[key], new[key][:n_take]], dim=0 + ) + self._val_batch["waves"] = ( + self._val_batch["waves"] + new["waves"][:n_take] + ) def _validate_second(self, batch, batch_idx): device = self.device @@ -1331,7 +1385,7 @@ def _validate_second(self, batch, batch_idx): ) if self.trainer.is_global_zero: - self._val_batch = dict( + new = dict( asr=asr.detach().cpu(), mels=mels.detach().cpu(), mel_input_length=mel_input_length.detach().cpu(), @@ -1344,3 +1398,29 @@ def _validate_second(self, batch, batch_idx): text_mask=text_mask.detach().cpu(), ref_mels=ref_mels.detach().cpu() if self.multispeaker else None, ) + if self._val_batch is None: + self._val_batch = new + elif len(self._val_batch["asr"]) < self._MAX_VAL_AUDIO: + n_have = len(self._val_batch["asr"]) + n_take = min(len(new["asr"]), self._MAX_VAL_AUDIO - n_have) + for key in ( + "asr", + "mels", + "mel_input_length", + "p", + "d_en", + "bert_dur", + "texts", + "input_lengths", + "text_mask", + ): + self._val_batch[key] = torch.cat( + [self._val_batch[key], new[key][:n_take]], dim=0 + ) + self._val_batch["waves"] = ( + self._val_batch["waves"] + 
new["waves"][:n_take] + ) + if self.multispeaker and self._val_batch["ref_mels"] is not None: + self._val_batch["ref_mels"] = torch.cat( + [self._val_batch["ref_mels"], new["ref_mels"][:n_take]], dim=0 + )