Skip to content
13 changes: 13 additions & 0 deletions everyvoice/base_cli/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def summarize_hfgl_generator_model(model_path: Path, checkpoint: dict) -> None:
print(summary(vocoder_model, None, verbose=0))


def summarize_styletts2_model(model_path: Path, checkpoint: dict) -> None:
from torchinfo import summary

from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.lightning import (
StyleTTS2Module,
)

model = StyleTTS2Module.load_from_checkpoint(model_path)
print(summary(model, None, verbose=0))


def summarize_unknown_model(model_path: Path, checkpoint: dict) -> None:
from tabulate import tabulate

Expand Down Expand Up @@ -194,6 +205,7 @@ def inspect(

if show_architecture:
checkpoint = load_checkpoint(model_path, minimal=False)

if "model_info" in checkpoint:
print(
"Inspecting checkpoint according to its model info:",
Expand All @@ -203,6 +215,7 @@ def inspect(
"FastSpeech2": summarize_fs2_model,
"HiFiGAN": summarize_hfgl_model,
"HiFiGANGenerator": summarize_hfgl_generator_model,
"StyleTTS2Module": summarize_styletts2_model,
}
summarizer = model_summarizers.get(checkpoint["model_info"]["name"], None)
if summarizer:
Expand Down
41 changes: 41 additions & 0 deletions everyvoice/base_cli/prediction_writing_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Generic base for synthesis-output-writing Lightning callbacks.

Shared by the FS2 and StyleTTS2 prediction-writing callback hierarchies.
Subclasses override ``on_predict_batch_end`` with format-specific logic.
"""

from __future__ import annotations

from pathlib import Path

from pytorch_lightning.callbacks import Callback


class BasePredictionWritingCallback(Callback):
"""Handles output-directory creation and output-path construction.

Concrete subclasses must implement ``on_predict_batch_end``.
"""

def __init__(
self,
save_dir: Path,
file_extension: str,
global_step: int,
include_global_step_in_filename: bool = False,
) -> None:
super().__init__()
self.file_extension = file_extension
self.global_step = f"ckpt={global_step}"
self.save_dir = save_dir
self.sep = "--"
self.include_global_step_in_filename = include_global_step_in_filename
self.save_dir.mkdir(parents=True, exist_ok=True)

def get_filename(self, basename: str, speaker: str, language: str) -> str:
name_parts = [basename, speaker, language, self.file_extension]
if self.include_global_step_in_filename:
name_parts.insert(-1, self.global_step)
path = self.save_dir / self.sep.join(name_parts)
path.parent.mkdir(parents=True, exist_ok=True)
return str(path)
Loading
Loading