def summarize_styletts2_model(model_path: Path, checkpoint: dict) -> None:
    """Print a torchinfo summary of the StyleTTS2 module stored at ``model_path``.

    Args:
        model_path: Checkpoint file handed to ``StyleTTS2Module.load_from_checkpoint``.
        checkpoint: Already-loaded checkpoint dict. It is not read here; the
            parameter exists so this function has the same signature as the
            other summarizers it is registered alongside.
    """
    # Function-local imports, mirroring the original block.
    import torchinfo

    from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.lightning import (
        StyleTTS2Module,
    )

    styletts2_module = StyleTTS2Module.load_from_checkpoint(model_path)
    architecture = torchinfo.summary(styletts2_module, None, verbose=0)
    print(architecture)
class BasePredictionWritingCallback(Callback):
    """Shared plumbing for callbacks that write synthesis output to disk.

    The base class creates the output directory and builds output file paths.
    It does not define ``on_predict_batch_end``; concrete subclasses must
    implement it.
    """

    def __init__(
        self,
        save_dir: Path,
        file_extension: str,
        global_step: int,
        include_global_step_in_filename: bool = False,
    ) -> None:
        """Store the naming settings and make sure ``save_dir`` exists.

        Args:
            save_dir: Directory that receives the written files; created
                (with parents) if it does not exist yet.
            file_extension: Final component of every generated file name.
            global_step: Training step of the checkpoint; stored as the
                string ``"ckpt=<global_step>"``.
            include_global_step_in_filename: When true, the ``ckpt=...`` tag
                is inserted just before the file extension.
        """
        super().__init__()
        self.save_dir = save_dir
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.sep = "--"
        self.file_extension = file_extension
        self.global_step = f"ckpt={global_step}"
        self.include_global_step_in_filename = include_global_step_in_filename

    def get_filename(self, basename: str, speaker: str, language: str) -> str:
        """Return the output path for one utterance as a string.

        The name is ``basename``, ``speaker``, ``language``, optionally the
        checkpoint tag, and the file extension, joined with ``self.sep``.
        The parent directory of the resulting path is created if needed.
        """
        pieces = [basename, speaker, language]
        if self.include_global_step_in_filename:
            pieces.append(self.global_step)
        pieces.append(self.file_extension)

        target = self.save_dir / self.sep.join(pieces)
        # basename may itself contain directory components, so ensure the parent.
        target.parent.mkdir(parents=True, exist_ok=True)
        return str(target)
+ + - **text-to-wav** --- Synthesize audio directly from text using a trained end-to-end (StyleTTS2) model. Only supports the wav output format. - **from-spec** --- This is the model that turns your spectral features into audio. This type of synthesis is also known as copy synthesis and unless you know what you are doing, you probably don't want to do this. + """, ) @@ -585,6 +591,11 @@ def new_project( name="from-spec", )(synthesize_hfg) +synthesize_group.command( + name="text-to-wav", + short_help="Synthesize audio from text using a trained StyleTTS2 model", +)(synthesize_styletts2) + app.add_typer( synthesize_group, name="synthesize", @@ -669,9 +680,29 @@ def test(suite: TestSuites = typer.Argument("dev")): # pragma: no cover ) -@app.command() +# Add the demo commands +demo_group = typer.Typer( + pretty_exceptions_show_locals=False, + no_args_is_help=True, + context_settings={"help_option_names": ["-h", "--help"]}, + rich_markup_mode="markdown", + cls=TyperGroupOrderAsDeclared, + help=""" + # Demo Help + + - **text-to-spec** --- Launch an interactive Gradio demo for a two-stage (FastSpeech2 + HiFiGAN) model. + + - **text-to-wav** --- Launch an interactive Gradio demo for an end-to-end (StyleTTS2) model. + """, +) + + +@demo_group.command( + name="text-to-spec", + short_help="Launch a Gradio demo for a text-to-spec (FastSpeech2 + HiFiGAN) model", +) @merge_args(inference_base_command_interface) -def demo( +def demo_text_to_spec( text_to_spec_model: Annotated[ Path, typer_file_argument( @@ -699,19 +730,19 @@ def demo( ["all"], "--language", "-l", - help="Specify languages to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --language eng --language fin", + help="Specify languages to be included in the demo. Must be supported by your model. 
def _parse_demo_speakers(entries: list[str]) -> dict[str, Path]:
    """Parse ``'Display Name=path/to/audio.wav'`` entries into a name -> path map.

    Raises:
        typer.BadParameter: if an entry has no ``=``, has an empty display
            name, or points at an audio file that does not exist.
    """
    parsed: dict[str, Path] = {}
    for entry in entries:
        # Split on the first "=" only, so paths may themselves contain "=".
        raw_name, separator, raw_path = entry.partition("=")
        if not separator:
            raise typer.BadParameter(
                f"Speaker '{entry}' must be in the format 'Display Name=path/to/audio.wav'.",
                param_hint="--speaker",
            )
        display_name = raw_name.strip()
        if not display_name:
            # An empty name would become a blank key in the speaker mapping.
            raise typer.BadParameter(
                f"Speaker '{entry}' has an empty display name; "
                "it must be in the format 'Display Name=path/to/audio.wav'.",
                param_hint="--speaker",
            )
        audio_path = Path(raw_path.strip()).expanduser()
        if not audio_path.exists():
            raise typer.BadParameter(
                f"Speaker audio file not found: {audio_path}",
                param_hint="--speaker",
            )
        parsed[display_name] = audio_path
    return parsed


def _read_word_list(list_path: Path) -> list[str]:
    """Return the non-blank, stripped lines of a plain-text allow/deny list."""
    # Explicit UTF-8 so the list is decoded the same way on every platform.
    with open(list_path, encoding="utf-8") as f:
        return [stripped for line in f if (stripped := line.strip())]


@demo_group.command(
    name="text-to-wav",
    short_help="Launch a Gradio demo for an end-to-end (StyleTTS2) model",
)
def demo_text_to_wav(
    model_path: Annotated[
        Path,
        typer_file_argument(help="The path to a trained StyleTTS2 checkpoint (.ckpt)."),
    ],
    reference: Optional[Path] = typer.Option(
        None,
        "--reference",
        "-r",
        help="Path to a reference audio file that sets the default speaker style in the UI. "
        "Use --speaker for a named multi-speaker dropdown.",
        exists=True,
    ),
    speaker: list[str] = typer.Option(
        [],
        "--speaker",
        "-s",
        help="Named speaker defined as 'Display Name=path/to/audio.wav'. "
        "Repeat the flag to add multiple speakers. "
        "Their style encodings are pre-computed at startup and shown in a dropdown. "
        "Example: everyvoice demo text-to-wav MODEL_PATH --speaker 'Alice=alice.wav' --speaker 'Bob=bob.wav'",
    ),
    allowlist: Annotated[
        Optional[Path],
        typer_file_option(
            "--allowlist",
            help="A plain text file containing a list of words or utterances to allow synthesizing. "
            "Words/utterances should be separated by a new line in a plain text file. "
            "All other words are disallowed.",
        ),
    ] = None,
    denylist: Annotated[
        Optional[Path],
        typer_file_option(
            "--denylist",
            help="A plain text file containing a list of words or utterances to disallow synthesizing. "
            "Words/utterances should be separated by a new line in a plain text file. "
            "All other words are allowed. "
            "IMPORTANT: there are many ways to 'hack' the denylist that we do not protect against. "
            "We suggest using the 'allowlist' instead for maximum security if you know the full list "
            "of utterances you want to allow synthesis for.",
        ),
    ] = None,
    output_dir: Path = typer_directory_option(
        "synthesis_output",
        "--output-dir",
        "-o",
        exists=False,
        help="The directory where your synthesized audio should be written.",
    ),
    accelerator: str = typer.Option(
        "auto",
        "--accelerator",
        "-a",
        help="Specify the Pytorch Lightning accelerator to use.",
    ),
    port: int = typer.Option(7860, "--port", "-p", help="The port to run the demo on."),
    share: bool = typer.Option(
        False,
        "--share",
        help="Share the demo using Gradio's share feature. This will make the demo accessible from the internet.",
    ),
    server_name: str = typer.Option(
        "0.0.0.0",
        "--server-name",
        "-n",
        help="The server name to run the demo on. This is useful if you want to run the demo on a specific IP address.",
    ),
):
    """Launch an interactive Gradio demo for an end-to-end (StyleTTS2) model."""
    # NOTE: the docstring above is kept to one line on purpose; it is the text
    # of a Typer command and therefore doubles as user-facing CLI help.
    #
    # Flow: validate the option combination, parse the --speaker entries, read
    # the allow/deny lists, print a parameter summary (including a best-effort
    # dump of the checkpoint's hyper-parameters), build the Gradio app and
    # launch it.
    if not speaker and reference is None:
        raise typer.BadParameter(
            "Provide at least one --speaker 'Name=path/to/audio.wav' or a --reference path.",
            param_hint="--speaker / --reference",
        )
    if allowlist and denylist:
        raise typer.BadParameter(
            "You provided a value for both the allowlist and the denylist but you can only provide one."
        )

    # Parse --speaker "Display Name=path/to/audio.wav" entries
    speakers_dict = _parse_demo_speakers(speaker)

    # --reference with no --speaker → reference-upload mode (no speaker dropdown)
    default_reference = reference if not speakers_dict else None
    if speakers_dict and reference is not None:
        # The reference is not forwarded in speaker mode; say so instead of
        # silently dropping a user-supplied option.
        print(
            "WARNING - --reference is ignored because --speaker entries were provided."
        )

    allowlist_data: list[str] = _read_word_list(allowlist) if allowlist else []
    denylist_data: list[str] = _read_word_list(denylist) if denylist else []

    import json

    import torch

    print("INFO - Starting the StyleTTS2 demo with the following parameters:")
    print(f" - Model Path: {model_path}")
    try:
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary Python
        # objects from the checkpoint. Only point this command at checkpoints
        # from a trusted source.
        checkpoint_state = torch.load(
            model_path, map_location="cpu", weights_only=False
        )
        hyper_parameters = checkpoint_state.get("hyper_parameters", {})
        print(f" - Mode: {hyper_parameters.get('mode', 'unknown')} (from checkpoint)")
        if "config" in hyper_parameters:
            print(" - Checkpoint config:")
            print(json.dumps(hyper_parameters["config"], indent=4, default=str))
        # Drop this copy of the checkpoint before the demo app loads the model.
        del checkpoint_state
    except Exception as e:
        # Deliberately best-effort: the summary is informational only, so a
        # checkpoint that cannot be inspected here must not abort the demo.
        print(f" - (Could not read checkpoint config: {e})")
    if speakers_dict:
        for name, path in speakers_dict.items():
            print(f" - Speaker: {name} = {path}")
    else:
        print(f" - Reference: {reference}")
    print(f" - Allowlist: {allowlist if allowlist else 'None'}")
    print(f" - Denylist: {denylist if denylist else 'None'}")
    print(f" - Output Dir: {output_dir}")
    print(f" - Accelerator: {accelerator}")
    print(f" - Port: {port}")
    print(f" - Share: {share}")
    print(f" - Server Name: {server_name}")

    with spinner("Loading software"):
        from everyvoice.demo.app import create_demo_app_styletts2

    with spinner("Loading model"):
        demo = create_demo_app_styletts2(
            model_path=model_path,
            output_dir=output_dir,
            speakers=speakers_dict,
            default_reference=default_reference,
            accelerator=accelerator,
            allowlist=allowlist_data,
            denylist=denylist_data,
        )

    demo.launch(
        share=share,
        server_port=port,
        server_name=server_name,
        # Gradio may only serve files from these locations.
        allowed_paths=[str(output_dir), tempfile.gettempdir()],
    )


app.add_typer(
    demo_group,
    name="demo",
    short_help="Launch an interactive Gradio demo for your EveryVoice models",
)
def synthesize_audio_styletts2(
    text: str,
    speaker: "str | None",  # selected display name from dropdown, or None in reference-upload mode
    user_reference,  # filepath from Gradio Audio component, or None if not uploaded / cleared
    diffusion_steps: int,
    embedding_scale: float,
    acoustic_blend: float,
    prosody_blend: float,
    *,
    module,
    mel_transform,
    device,
    output_dir: Path,
    speaker_ref_s: "dict[str, torch.Tensor]",  # pre-computed at startup; empty in reference-only mode
    default_ref_s: "torch.Tensor | None",  # pre-computed from --reference; None in speaker mode
    allowlist: list[str],
    denylist: list[str],
) -> str:
    """Synthesize one utterance with StyleTTS2 and return the path to the saved WAV.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        speaker: Display name used to look up a style in ``speaker_ref_s``.
        user_reference: Optional path to a reference audio file. When given it
            takes precedence over ``speaker`` and ``default_ref_s``.
        diffusion_steps, embedding_scale, acoustic_blend, prosody_blend:
            Forwarded unchanged to ``module._synthesize_text``.
        module: Object providing ``_synthesize_text`` and the sample rate ``sr``.
        mel_transform: Passed through to ``load_reference_style``.
        device: Device the token tensors are moved to.
        output_dir: Directory the WAV is written into; created if missing.
        speaker_ref_s: Style tensors keyed by speaker display name.
        default_ref_s: Fallback style when neither an upload nor a known
            speaker is available.
        allowlist: If non-empty, the normalized text must be one of its entries.
        denylist: If non-empty, no whitespace-separated word of the normalized
            text may be one of its entries.

    Returns:
        The path of the written WAV file, as a string.

    Raises:
        gr.Error: for empty or disallowed text, a reference that cannot be
            loaded, no available reference style, text that yields no tokens,
            or any failure raised by the synthesis call.
    """
    if not text or not text.strip():
        raise gr.Error("Please provide text to synthesize.")

    norm_text = normalize_text(text)
    if allowlist and norm_text not in allowlist:
        raise gr.Error(
            f"The text '{text}' is not allowed to be synthesized by this model. "
            "Please contact the model owner."
        )
    if denylist and any(word in denylist for word in norm_text.split()):
        raise gr.Error(
            f"The text '{text}' contains a word that is not allowed. "
            "Please contact the model owner."
        )

    # Determine ref_s — prefer user-uploaded audio, then pre-loaded speaker, then default
    if user_reference is not None:
        from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.synthesize import (
            load_reference_style,
        )

        try:
            ref_s = load_reference_style(
                module, mel_transform, Path(user_reference), device
            )
        except Exception as e:
            # Surface the failure in the UI while keeping the original cause chained.
            raise gr.Error(f"Could not load reference audio: {e}") from e
    elif speaker is not None and speaker in speaker_ref_s:
        ref_s = speaker_ref_s[speaker]
    elif default_ref_s is not None:
        ref_s = default_ref_s
    else:
        raise gr.Error(
            "No reference audio available. Please upload a reference audio file."
        )

    from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.text_utils import (
        TextCleaner,
    )

    text_cleaner = TextCleaner()
    # Note: the raw text (not norm_text) is tokenized, as in the original code.
    tokens = torch.LongTensor(text_cleaner(text)).unsqueeze(0).to(device)
    if tokens.numel() == 0:
        raise gr.Error(f"Text produced no tokens: {text!r}")
    input_lengths = torch.LongTensor([tokens.size(1)]).to(device)

    try:
        audio = module._synthesize_text(
            tokens,
            input_lengths,
            ref_s=ref_s,
            diffusion_steps=diffusion_steps,
            embedding_scale=embedding_scale,
            acoustic_blend=acoustic_blend,
            prosody_blend=prosody_blend,
        )
    except Exception as e:
        raise gr.Error(str(e)) from e

    import soundfile as sf

    stem = slugify(text[:50])
    if not stem:
        # Guard against slugify() yielding an empty stem (presumably possible
        # when the text has no slug-safe characters — confirm against slugify),
        # which would otherwise write every such utterance to a bare ".wav".
        import hashlib

        stem = "utterance-" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:12]

    # The directory is not created anywhere in this function's visible callers'
    # contract, so make sure it exists before writing.
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / (stem + ".wav")
    sf.write(str(out_path), audio, module.sr)
    return str(out_path)