class BasePredictionWritingCallback(Callback):
    """Shared plumbing for callbacks that write synthesis output to disk.

    The constructor creates the output directory; ``get_filename`` builds the
    path for one utterance. Subclasses supply ``on_predict_batch_end`` with
    their format-specific writing logic.
    """

    def __init__(
        self,
        save_dir: Path,
        file_extension: str,
        global_step: int,
        include_global_step_in_filename: bool = False,
    ) -> None:
        super().__init__()
        self.sep = "--"
        self.save_dir = save_dir
        self.file_extension = file_extension
        self.include_global_step_in_filename = include_global_step_in_filename
        # Kept pre-formatted so it can be dropped straight into a filename.
        self.global_step = f"ckpt={global_step}"
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def get_filename(self, basename: str, speaker: str, language: str) -> str:
        """Return the output path (as a string) for one utterance.

        The name is ``basename``, ``speaker``, ``language``, optionally the
        ``ckpt=<step>`` tag, and the file extension, joined with ``self.sep``.
        The parent directory is created if needed.
        """
        step_tag = [self.global_step] if self.include_global_step_in_filename else []
        filename = self.sep.join(
            [basename, speaker, language, *step_tag, self.file_extension]
        )
        target = self.save_dir / filename
        # Created here as well as in __init__, as in the original; only has an
        # effect when the joined name resolves outside an existing directory.
        target.parent.mkdir(parents=True, exist_ok=True)
        return str(target)
It will take text or a filelist with text and produce either waveform audio or spectrogram. This option uses FastSpeech2 & HiFiGAN. If you want to do end-to-end synthesis with StyleTTS2, run `everyvoice synthesize text-to-wav` instead. + + - **text-to-wav** --- Synthesize audio directly from text using a trained end-to-end (StyleTTS2) model. Only supports the wav output format. - **from-spec** --- This is the model that turns your spectral features into audio. This type of synthesis is also known as copy synthesis and unless you know what you are doing, you probably don't want to do this. + """, ) @@ -587,6 +592,11 @@ def new_project( name="from-spec", )(synthesize_hfg) +synthesize_group.command( + name="text-to-wav", + short_help="Synthesize audio from text using a trained StyleTTS2 model", +)(synthesize_styletts2) + app.add_typer( synthesize_group, name="synthesize", @@ -648,172 +658,440 @@ def inspect_checkpoint(model_path: Path): [("all", "all")] + [(i.name, i.value) for i in SynthesizeOutputFormats], ) +_VOCODER_CLASS_NAMES = {"HiFiGAN", "HiFiGANGenerator"} +_FS2_CLASS_NAMES = {"FastSpeech2"} +_STYLETTS2_CLASS_NAMES = {"StyleTTS2Module"} -@app.command() + +def _peek_model_class(checkpoint_path: Path) -> str: + """Load a checkpoint header and return the stored model class name. + + Returns an empty string for legacy checkpoints that predate the model_info field. + Raises typer.BadParameter if the file cannot be read as a PyTorch checkpoint. 
+ """ + import torch + + try: + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + except Exception as e: + raise typer.BadParameter( + f"Could not read checkpoint '{checkpoint_path}': {e}", + param_hint="CHECKPOINT", + ) + return ckpt.get("model_info", {}).get("name", "") + + +def _load_list_file(path: Optional[Path]) -> list[str]: + """Read a plain-text word/utterance list from *path*, one entry per line.""" + if path is None: + return [] + with open(path) as f: + return [line.strip() for line in f if line.strip()] + + +def _parse_ref_speakers(ref_speaker: list[str]) -> dict[str, Path]: + """Parse --ref-speaker 'Name=path' values into a display-name → Path mapping.""" + speakers_dict: dict[str, Path] = {} + for s in ref_speaker: + if "=" not in s: + raise typer.BadParameter( + f"--ref-speaker '{s}' must be in the format 'Display Name=path/to/audio.wav'.", + param_hint="--ref-speaker", + ) + display_name, path_str = s.split("=", 1) + audio_path = Path(path_str.strip()).expanduser() + if not audio_path.exists(): + raise typer.BadParameter( + f"Reference audio file not found: {audio_path}", + param_hint="--ref-speaker", + ) + speakers_dict[display_name.strip()] = audio_path + return speakers_dict + + +def _load_fs2_ui_config(ui_config_file: Optional[Path]) -> "dict | None": + """Read and JSON-parse the FS2 UI config file, or return None if not given.""" + if ui_config_file is None: + print(" - UI Config file: None") + return None + print(f" - UI Config file: {ui_config_file}") + with open(ui_config_file) as f: + try: + ui_config_json = json.load(f) + print("\t config loaded") + return ui_config_json + except Exception as e: + raise typer.BadParameter( + f"Your config file {ui_config_file} has errors\n {e}" + ) + + +def _run_styletts2_demo( + checkpoint: Path, + ref_speaker: list[str], + reference: Optional[Path], + speakers: list[str], + vocoder: Optional[Path], + allowlist: Optional[Path], + denylist: Optional[Path], + allowlist_data: 
list[str], + denylist_data: list[str], + output_dir: Path, + accelerator: str, + port: int, + share: bool, + server_name: str, +) -> None: + """Validate StyleTTS2-specific options and launch the Gradio demo.""" + if vocoder is not None: + raise typer.BadParameter( + "StyleTTS2 does not use a separate vocoder. Remove --vocoder.", + param_hint="--vocoder", + ) + if speakers != ["all"]: + raise typer.BadParameter( + "StyleTTS2 does not use --speaker as a filter. " + "To define named speakers with reference audio use --ref-speaker 'Name=path/to/audio.wav'.", + param_hint="--speaker", + ) + if not ref_speaker and reference is None: + raise typer.BadParameter( + "Provide at least one --ref-speaker 'Name=path/to/audio.wav' or a --reference path.", + param_hint="--ref-speaker / --reference", + ) + + speakers_dict = _parse_ref_speakers(ref_speaker) + default_reference = reference if not speakers_dict else None + + import torch + + print("INFO - Starting the StyleTTS2 demo with the following parameters:") + print(f" - Checkpoint: {checkpoint}") + try: + _state = torch.load(checkpoint, map_location="cpu", weights_only=False) + _hp = _state.get("hyper_parameters", {}) + if "config" in _hp: + print(" - Checkpoint config:") + print(json.dumps(_hp["config"], indent=4, default=str)) + del _state + except Exception as e: + print(f" - (Could not read checkpoint config: {e})") + if speakers_dict: + for name, path in speakers_dict.items(): + print(f" - Ref speaker: {name} = {path}") + else: + print(f" - Reference: {reference}") + print(f" - Allowlist: {allowlist if allowlist else 'None'}") + print(f" - Denylist: {denylist if denylist else 'None'}") + print(f" - Output Dir: {output_dir}") + print(f" - Accelerator: {accelerator}") + print(f" - Port: {port}") + print(f" - Share: {share}") + print(f" - Server Name: {server_name}") + + from everyvoice.utils import spinner + + with spinner("Loading software"): + from everyvoice.demo.app import create_demo_app_styletts2 + + with 
spinner("Loading model"): + demo_app = create_demo_app_styletts2( + model_path=checkpoint, + output_dir=output_dir, + speakers=speakers_dict, + default_reference=default_reference, + accelerator=accelerator, + allowlist=allowlist_data, + denylist=denylist_data, + ) + + demo_app.launch( + share=share, + server_port=port, + server_name=server_name, + allowed_paths=[str(output_dir), tempfile.gettempdir()], + ) + + +def _run_fs2_demo( + checkpoint: Path, + vocoder: Optional[Path], + speakers: list[str], + languages: list[str], + outputs: list, + ui_config_file: Optional[Path], + ref_speaker: list[str], + reference: Optional[Path], + allowlist: Optional[Path], + denylist: Optional[Path], + allowlist_data: list[str], + denylist_data: list[str], + output_dir: Path, + accelerator: str, + port: int, + share: bool, + server_name: str, + **kwargs, +) -> None: + """Validate FastSpeech2-specific options and launch the Gradio demo.""" + if vocoder is None: + raise typer.BadParameter( + "FastSpeech2 requires a vocoder checkpoint. " + "Pass it with --vocoder path/to/hifigan.ckpt.", + param_hint="--vocoder", + ) + if ref_speaker: + raise typer.BadParameter( + "--ref-speaker is only used with StyleTTS2 models. 
" + "To filter FastSpeech2 speakers use --speaker.", + param_hint="--ref-speaker", + ) + if reference is not None: + raise typer.BadParameter( + "--reference is only used with StyleTTS2 models.", + param_hint="--reference", + ) + + print("INFO - Starting the demo with the following parameters:") + print(f" - Checkpoint: {checkpoint}") + print(f" - Vocoder: {vocoder}") + print(f" - Languages: {languages}") + print(f" - Speakers: {speakers}") + print(f" - Outputs: {outputs}") + print(f" - Output Dir: {output_dir}") + print(f" - Accelerator: {accelerator}") + print(f" - Allowlist: {allowlist_data if allowlist else 'None'}") + print(f" - Denylist: {denylist_data if denylist else 'None'}") + print(f" - Port: {port}") + print(f" - Share: {share}") + print(f" - Server Name: {server_name}") + ui_config_json = _load_fs2_ui_config(ui_config_file) + + from everyvoice.utils import spinner + + with spinner("Loading software"): + from everyvoice.demo.app import create_demo_app + + with spinner("Loading models"): + demo_app = create_demo_app( + text_to_spec_model_path=checkpoint, + spec_to_wav_model_path=vocoder, + languages=languages, + speakers=speakers, + outputs=outputs, + output_dir=output_dir, + accelerator=accelerator, + allowlist=allowlist_data, + denylist=denylist_data, + app_ui_config=ui_config_json, + **kwargs, + ) + + demo_app.launch( + share=share, + server_port=port, + server_name=server_name, + allowed_paths=[str(output_dir), tempfile.gettempdir()], + ) + + +@app.command( + name="demo", + short_help="Launch an interactive Gradio demo for any EveryVoice model", +) @merge_args(inference_base_command_interface) def demo( - text_to_spec_model: Annotated[ + checkpoint: Annotated[ Path, typer_file_argument( - help="The path to a trained text-to-spec (i.e., feature prediction) EveryVoice model." + help="Path to a trained EveryVoice checkpoint (.ckpt). " + "The model type is detected automatically from the checkpoint. " + "For FastSpeech2 models also pass --vocoder." 
), ], - spec_to_wav_model: Annotated[ - Path, typer_file_argument(help="The path to a trained vocoder.") - ], - allowlist: Annotated[ + # ---- FastSpeech2 options ------------------------------------------------ + vocoder: Annotated[ Optional[Path], typer_file_option( - "--allowlist", - help="A plain text file containing a list of words or utterances to allow synthesizing. Words/utterances should be separated by a new line in a plain text file. All other words are disallowed.", + "--vocoder", + "-V", + help="[FastSpeech2] Path to a trained HiFiGAN vocoder checkpoint. " + "Required when the primary checkpoint is a FastSpeech2 model. " + "Not used with StyleTTS2 checkpoints.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", ), ] = None, - denylist: Annotated[ - Optional[Path], - typer_file_option( - "--denylist", - help="A plain text file containing a list of words or utterances to disallow synthesizing. Words/utterances should be separated by a new line in a plain text file. All other words are allowed. IMPORTANT: there are many ways to 'hack' the denylist that we do not protect against. We suggest using the 'allowlist' instead for maximum security if you know the full list of utterances you want to allow synthesis for.", - ), - ] = None, - languages: list[str] = typer.Option( - ["all"], - "--language", - "-l", - help="Specify languages to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --language eng --language fin", - ), speakers: list[str] = typer.Option( ["all"], "--speaker", "-s", - help="Specify speakers to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --speaker speaker_1 --speaker Sue", + help="[FastSpeech2] Speaker names to expose in the demo UI. " + "Repeat the flag to include multiple speakers. " + "Defaults to all speakers in the model. 
" + "Not applicable to StyleTTS2 — use --ref-speaker instead.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + languages: list[str] = typer.Option( + ["all"], + "--language", + "-l", + help="[FastSpeech2] Languages to expose in the demo UI. " + "Repeat the flag to include multiple languages. " + "Defaults to all languages in the model.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", ), outputs: list[AllowedDemoOutputFormats] = typer.Option( ["all"], "--output-format", "-O", - help="Specify output formats to be included in the demo. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --output-format wav --output-format readalong-html", + help="[FastSpeech2] Output formats to expose in the demo UI.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + ui_config_file: Annotated[ + Optional[Path], + typer_file_option( + "--ui-config-file", + "-C", + help="[FastSpeech2] JSON file to override UI labels (app_title, app_description, " + "app_instructions, speakers, languages, input_text_label, " + "duration_multiplier_label, language_label, speaker_label, " + "output_format_label, synthesize_label, file_output_label).", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + ] = None, + # ---- StyleTTS2 options -------------------------------------------------- + ref_speaker: list[str] = typer.Option( + [], + "--ref-speaker", + "-R", + help="[StyleTTS2] Named speaker with reference audio, in the format " + "'Display Name=path/to/audio.wav'. " + "Repeat the flag to add multiple speakers. " + "Their style encodings are pre-computed at startup and shown in a dropdown. " + "Example: --ref-speaker 'Eric=eric.wav' --ref-speaker 'Darlene=darlene.wav'", + rich_help_panel="StyleTTS2 (text-to-wav) Options", + ), + reference: Optional[Path] = typer.Option( + None, + "--reference", + "-r", + help="[StyleTTS2] Default reference audio file that sets the initial speaker style. 
" + "Use this for reference-upload mode (no speaker dropdown). " + "Use --ref-speaker instead to pre-define named speakers.", + exists=True, + rich_help_panel="StyleTTS2 (text-to-wav) Options", ), + # ---- Shared options ----------------------------------------------------- + allowlist: Annotated[ + Optional[Path], + typer_file_option( + "--allowlist", + help="Plain text file with allowed words or utterances (one per line). " + "All other input is rejected. Cannot be combined with --denylist.", + ), + ] = None, + denylist: Annotated[ + Optional[Path], + typer_file_option( + "--denylist", + help="Plain text file with disallowed words or utterances (one per line). " + "All other input is allowed. Cannot be combined with --allowlist. " + "IMPORTANT: there are many ways to bypass a denylist. " + "Use --allowlist for maximum security.", + ), + ] = None, output_dir: Path = typer_directory_option( "synthesis_output", "--output-dir", "-o", exists=False, - help="The directory where your synthesized audio should be written", + help="Directory where synthesized audio files are written.", ), accelerator: str = typer.Option( "auto", "--accelerator", "-a", - help="Specify the Pytorch Lightning accelerator to use", + help="PyTorch Lightning accelerator (e.g. 'auto', 'cpu', 'gpu').", ), - port: int = typer.Option(7860, "--port", "-p", help="The port to run the demo on."), + port: int = typer.Option(7860, "--port", "-p", help="Port to serve the demo on."), share: bool = typer.Option( False, "--share", - help="Share the demo using Gradio's share feature. This will make the demo accessible from the internet.", + help="Publish the demo via Gradio's share tunnel (accessible from the internet).", ), server_name: str = typer.Option( "0.0.0.0", "--server-name", "-n", - help="The server name to run the demo on. 
This is useful if you want to run the demo on a specific IP address.", + help="Host/IP address to bind the demo server to.", ), - ui_config_file: Annotated[ - Optional[Path], - typer_file_option( - "--ui-config-file", - "-C", - help="""A path to a configuration file that will be used to override parts of the default configuration for the demo UI. This is useful if you want to override some of the text in the UI. - The config file should be a valid JSON FORMAT. The expected optional values and types are: - - "app_title": string, - "app_description": string, - "app_instructions": string, - "languages": dict ["id":"name"], - "speakers": dict ["id":"name"], - "input_text_label": string, - "duration_multiplier_label": string, - "language_label": string, - "speaker_label": string, - "output_format_label":string, - "synthesize_label": string, - "file_output_label": string - - """, - ), - ] = None, **kwargs, ): - if allowlist and denylist: - raise typer.BadParameter( - "You provided a value for both the allowlist and the denylist but you can only provide one." - ) + """Launch an interactive Gradio demo for any EveryVoice model. - allowlist_data = [] - denylist_data = [] - ui_config_json: dict | None = None + The model type is detected automatically from the checkpoint. 
+ Pass a single checkpoint for **StyleTTS2** (text-to-wav) models: - if allowlist: - with open(allowlist) as f: - allowlist_data = [x.strip() for x in f] + everyvoice demo path/to/styletts2.ckpt --ref-speaker 'Eric=eric.wav' - if denylist: - with open(denylist) as f: - denylist_data = [x.strip() for x in f] + Pass a FastSpeech2 (text-to-spec) checkpoint plus a vocoder (spec-to-wav) for **FastSpeech2 + HiFiGAN** models: - # print the parameters to the console - print("INFO - Starting the demo with the following parameters:") - print(f" - Text-to-Spec Model: {text_to_spec_model}") - print(f" - Spec-to-Wav Model: {spec_to_wav_model}") - print(f" - Languages: {languages}") - print(f" - Speakers: {speakers}") - print(f" - Outputs: {outputs}") - print(f" - Output Directory: {output_dir}") - print(f" - Accelerator: {accelerator}") - print(f" - Allowlist: {allowlist_data if allowlist else 'None'}") - print(f" - Denylist: {denylist_data if denylist else 'None'}") - print(f" - Port: {port}") - print(f" - Share: {share}") - print(f" - Server Name: {server_name}") - if ui_config_file: - print(f" - UI Config file path: {ui_config_file}") - with open(ui_config_file) as f: - try: - ui_config_json = json.load(f) - print("\t config loaded") - except Exception as e: - raise typer.BadParameter( - f"Your config file {ui_config_file} has errors\n {e}" - ) - else: - print(" - UI Config file path: None") + everyvoice demo path/to/fs2.ckpt --vocoder path/to/hifigan.ckpt + """ + if allowlist and denylist: + raise typer.BadParameter( + "Provide either --allowlist or --denylist, not both.", + ) - with spinner("Loading software"): - from everyvoice.demo.app import create_demo_app + model_class = _peek_model_class(checkpoint) - with spinner("Loading models"): - demo = create_demo_app( - text_to_spec_model_path=text_to_spec_model, - spec_to_wav_model_path=spec_to_wav_model, - languages=languages, - speakers=speakers, - outputs=outputs, - output_dir=output_dir, - accelerator=accelerator, - 
allowlist=allowlist_data, - denylist=denylist_data, - app_ui_config=ui_config_json, - **kwargs, + if model_class in _VOCODER_CLASS_NAMES: + raise typer.BadParameter( + f"'{checkpoint}' appears to be a standalone vocoder checkpoint ({model_class}). " + "Pass your FastSpeech2 checkpoint as the CHECKPOINT argument and provide " + "this vocoder with --vocoder.", + param_hint="CHECKPOINT", + ) + if model_class not in _FS2_CLASS_NAMES | _STYLETTS2_CLASS_NAMES: + raise typer.BadParameter( + f"Unrecognized model type '{model_class}' in checkpoint '{checkpoint}'. " + "Expected a FastSpeech2 or StyleTTS2 checkpoint.", + param_hint="CHECKPOINT", ) - demo.launch( + allowlist_data = _load_list_file(allowlist) + denylist_data = _load_list_file(denylist) + + shared = dict( + allowlist=allowlist, + denylist=denylist, + allowlist_data=allowlist_data, + denylist_data=denylist_data, + output_dir=output_dir, + accelerator=accelerator, + port=port, share=share, - server_port=port, server_name=server_name, - # explicitly give permission to gradio to write to the output directory and temp directory - allowed_paths=[str(output_dir), tempfile.gettempdir()], ) + if model_class in _STYLETTS2_CLASS_NAMES: + _run_styletts2_demo( + checkpoint, ref_speaker, reference, speakers, vocoder, **shared # type: ignore[arg-type] + ) + else: + _run_fs2_demo( + checkpoint, + vocoder, + speakers, + languages, + outputs, + ui_config_file, + ref_speaker, + reference, + **shared, # type: ignore[arg-type] + **kwargs, + ) + # Deferred full initialization to optimize the CLI, but still exposed for unit testing. 
SCHEMAS_TO_OUTPUT: dict[str, Any] = {} # dict[str, type[BaseModel]] diff --git a/everyvoice/demo/app.py b/everyvoice/demo/app.py index 01abba4c..13ff7554 100644 --- a/everyvoice/demo/app.py +++ b/everyvoice/demo/app.py @@ -474,6 +474,244 @@ def load_model_from_checkpoint( return model, vocoder_model, vocoder_config, device +def synthesize_audio_styletts2( + text: str, + speaker: "str | None", # selected display name from dropdown, or None in reference-upload mode + user_reference, # filepath from Gradio Audio component, or None if not uploaded / cleared + diffusion_steps: int, + embedding_scale: float, + acoustic_blend: float, + prosody_blend: float, + *, + module, + mel_transform, + device, + output_dir: Path, + speaker_ref_s: "dict[str, torch.Tensor]", # pre-computed at startup; empty in reference-only mode + default_ref_s: "torch.Tensor | None", # pre-computed from --reference; None in speaker mode + allowlist: list[str], + denylist: list[str], +) -> str: + """Synthesize one utterance with StyleTTS2 and return the path to the saved WAV.""" + if not text or not text.strip(): + raise gr.Error("Please provide text to synthesize.") + + norm_text = normalize_text(text) + if allowlist and norm_text not in allowlist: + raise gr.Error( + f"The text '{text}' is not allowed to be synthesized by this model. " + "Please contact the model owner." + ) + if denylist: + for word in norm_text.split(): + if word in denylist: + raise gr.Error( + f"The text '{text}' contains a word that is not allowed. " + "Please contact the model owner." 
+ ) + + # Determine ref_s — prefer user-uploaded audio, then pre-loaded speaker, then default + if user_reference is not None: + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.synthesize import ( + load_reference_style, + ) + + try: + ref_s = load_reference_style( + module, mel_transform, Path(user_reference), device + ) + except Exception as e: + raise gr.Error(f"Could not load reference audio: {e}") + elif speaker is not None and speaker in speaker_ref_s: + ref_s = speaker_ref_s[speaker] + elif default_ref_s is not None: + ref_s = default_ref_s + else: + raise gr.Error( + "No reference audio available. Please upload a reference audio file." + ) + + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.text_utils import ( + TextCleaner, + ) + + text_cleaner = TextCleaner() + tokens = torch.LongTensor(text_cleaner(text)).unsqueeze(0).to(device) + if tokens.numel() == 0: + raise gr.Error(f"Text produced no tokens: {text!r}") + input_lengths = torch.LongTensor([tokens.size(1)]).to(device) + + try: + audio = module._synthesize_text( + tokens, + input_lengths, + ref_s=ref_s, + diffusion_steps=diffusion_steps, + embedding_scale=embedding_scale, + acoustic_blend=acoustic_blend, + prosody_blend=prosody_blend, + ) + except Exception as e: + raise gr.Error(str(e)) + + import soundfile as sf + + out_path = output_dir / (slugify(text[:50]) + ".wav") + sf.write(str(out_path), audio, module.sr) + return str(out_path) + + +def make_gradio_display_styletts2( + synthesize_fn, + speaker_list: "GradioChoices", + default_reference: "Path | None" = None, +) -> "gr.Blocks": + """Build the Gradio Blocks for the StyleTTS2 demo. + + When ``speaker_list`` is non-empty a speaker dropdown is shown and the + reference audio widget becomes an optional style override. When it is + empty the reference audio widget is the primary input (reference-upload + mode) and the ``speaker`` argument is pre-bound as ``None``. 
+ """ + has_speakers = bool(speaker_list) + interactive_speaker = len(speaker_list) > 1 + + with gr.Blocks() as demo: + gr.Markdown("

EveryVoice StyleTTS2 Demo

") + with gr.Row(): + with gr.Column(): + inp_text = gr.Text( + placeholder="This text will be turned into speech.", + label="Input Text", + ) + inputs = [inp_text] + + if has_speakers: + inp_speaker = gr.Dropdown( + choices=speaker_list, + value=speaker_list[0][1], + interactive=interactive_speaker, + label="Speaker", + ) + inputs.append(inp_speaker) + else: + synthesize_fn = partial(synthesize_fn, speaker=None) + + inp_reference = gr.Audio( + value=( + str(default_reference) + if (not has_speakers and default_reference) + else None + ), + label=( + "Override Reference Audio (optional)" + if has_speakers + else "Reference Audio" + ), + type="filepath", + ) + inputs.append(inp_reference) + + with gr.Accordion("Advanced synthesis options", open=False): + inp_diffusion_steps = gr.Slider( + 1, 20, value=5, step=1, label="Diffusion Steps" + ) + inp_embedding_scale = gr.Slider( + 0.1, 3.0, value=1.0, step=0.1, label="Embedding Scale" + ) + inp_acoustic_blend = gr.Slider( + 0.0, 1.0, value=0.3, step=0.05, label="Acoustic Blend" + ) + inp_prosody_blend = gr.Slider( + 0.0, 1.0, value=0.7, step=0.05, label="Prosody Blend" + ) + inputs.extend( + [ + inp_diffusion_steps, + inp_embedding_scale, + inp_acoustic_blend, + inp_prosody_blend, + ] + ) + btn = gr.Button("Synthesize") + with gr.Column(): + out_audio = gr.Audio(format="wav", label="Output Audio") + + btn.click(synthesize_fn, inputs=inputs, outputs=[out_audio]) + return demo + + +def create_demo_app_styletts2( + model_path: Path, + output_dir: Path, + speakers: "dict[str, Path]", + default_reference: "Path | None" = None, + accelerator: str = "auto", + allowlist: list[str] = [], + denylist: list[str] = [], +) -> "gr.Blocks": + """Load a StyleTTS2 model and return a Gradio Blocks demo app. + + ``speakers`` maps display names to reference audio paths; their style + encodings are pre-computed at startup so each synthesis call is fast. 
+ When ``speakers`` is empty the demo falls back to reference-upload mode + using ``default_reference`` as the pre-populated audio widget value. + """ + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.synthesize import ( + load_reference_style, + load_styletts2_model, + ) + from everyvoice.utils.heavy import get_device_from_accelerator + + require_ffmpeg() + device = get_device_from_accelerator(accelerator) + + model, mel_transform = load_styletts2_model(model_path, device) + output_dir.mkdir(exist_ok=True, parents=True) + + # Pre-compute style encodings for each named speaker + speaker_ref_s: dict[str, torch.Tensor] = {} + for display_name, audio_path in speakers.items(): + logger.info( + f"Pre-computing style encoding for speaker '{display_name}' from {audio_path}" + ) + speaker_ref_s[display_name] = load_reference_style( + model, mel_transform, audio_path, device + ) + + # Pre-compute the default reference encoding so synthesis never re-reads disk + default_ref_s: "torch.Tensor | None" = None + if default_reference is not None and default_reference.exists(): + logger.info( + f"Pre-computing style encoding for default reference {default_reference}" + ) + default_ref_s = load_reference_style( + model, mel_transform, default_reference, device + ) + + norm_allowlist = [normalize_text(w) for w in allowlist] + norm_denylist = [normalize_text(w) for w in denylist] + + synthesize_fn = partial( + synthesize_audio_styletts2, + module=model, + mel_transform=mel_transform, + device=device, + output_dir=output_dir, + speaker_ref_s=speaker_ref_s, + default_ref_s=default_ref_s, + allowlist=norm_allowlist, + denylist=norm_denylist, + ) + + speaker_list: GradioChoices = [(name, name) for name in speakers] + return make_gradio_display_styletts2( + synthesize_fn, + speaker_list, + default_reference=default_reference if not speakers else None, + ) + + def create_demo_app( text_to_spec_model_path: os.PathLike, spec_to_wav_model_path: os.PathLike, diff --git 
def flatten_log(log_output: str) -> str:
    """Normalize Rich/Typer CLI output to a plain, single-line string.

    Usage: assert "some text" in flatten_log(captured_output)

    Strips ANSI CSI escape sequences (emitted when FORCE_COLOR=1 in CI),
    replaces runs of Unicode box-drawing characters (U+2500–U+257F, which
    covers every Rich panel/table border style, not only the rounded one)
    with a space, and collapses all remaining whitespace to single spaces —
    so that substrings remain continuous regardless of terminal-width
    word-wrap or colour formatting.
    """
    # Full CSI grammar: parameter bytes 0x30–0x3F, intermediate bytes
    # 0x20–0x2F, final byte 0x40–0x7E. This also removes private-mode
    # sequences such as "\x1b[?25l" (hide cursor), which a digits-and-";"
    # parameter pattern leaves behind.
    log_output = re.sub(r"\x1b\[[0-?]*[ -/]*[@-~]", "", log_output)
    log_output = re.sub(r"[\u2500-\u257f]+", " ", log_output)
    return re.sub(r"\s+", " ", log_output)
"everyvoice.cli._peek_model_class", + return_value="FastSpeech2", + ), mock.patch( "everyvoice.demo.app.load_model_from_checkpoint", side_effect=self.mock_demo_load_model_from_checkpoint, @@ -721,6 +740,7 @@ def test_create_demo_app_with_ui_config_file(self) -> None: [ "demo", str(spec_model_path), + "--vocoder", str(vocoder_path), "--port", port, @@ -785,6 +805,10 @@ def test_create_demo_app_with_malformed_ui_config_file(self): ip = "123.456.78.90" with ( + mock.patch( + "everyvoice.cli._peek_model_class", + return_value="FastSpeech2", + ), mock.patch( "everyvoice.demo.app.load_model_from_checkpoint", side_effect=self.mock_demo_load_model_from_checkpoint, @@ -808,6 +832,7 @@ def test_create_demo_app_with_malformed_ui_config_file(self): [ "demo", str(spec_model_path), + "--vocoder", str(vocoder_path), "--port", port, @@ -901,6 +926,113 @@ def test_create_demo_load_app_ui_labels_errors(self): str(cm.exception), ) + def test_demo_dispatch_styletts2_rejects_vocoder_flag(self): + """Passing --vocoder with a StyleTTS2 checkpoint should produce a clear error.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + import torch + + tmpdir = Path(tmpdir_str) + fake_ckpt = tmpdir / "styletts2.ckpt" + torch.save( + { + "model_info": {"name": "StyleTTS2Module"}, + "hyper_parameters": {"mode": "second", "config": {}}, + "state_dict": {}, + }, + fake_ckpt, + ) + fake_vocoder = tmpdir / "hifigan.ckpt" + fake_vocoder.touch() + + result = self.runner.invoke( + app, + [ + "demo", + str(fake_ckpt), + "--vocoder", + str(fake_vocoder), + "--ref-speaker", + f"Eric={fake_ckpt}", # reuse fake_ckpt as a dummy audio file + ], + ) + assert result.exit_code != 0 + assert "StyleTTS2 does not use a separate vocoder" in flatten_log( + result.output + ) + + def test_demo_dispatch_fs2_requires_vocoder(self): + """Invoking demo with a FastSpeech2 checkpoint but no --vocoder should error.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + _, spec_model_path = 
everyvoice.tests.model_stubs.get_stubbed_model( + tmpdir / "spec_model" + ) + + with mock.patch( + "everyvoice.cli._peek_model_class", + return_value="FastSpeech2", + ): + result = self.runner.invoke( + app, + ["demo", str(spec_model_path)], + ) + assert result.exit_code != 0 + assert "FastSpeech2 requires a vocoder checkpoint" in flatten_log( + result.output + ) + + def test_demo_dispatch_fs2_rejects_ref_speaker(self): + """Passing --ref-speaker with a FastSpeech2 checkpoint should produce a clear error.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + tmpdir = Path(tmpdir_str) + _, vocoder_path = everyvoice.tests.model_stubs.get_stubbed_vocoder( + tmpdir / "vocoder" + ) + _, spec_model_path = everyvoice.tests.model_stubs.get_stubbed_model( + tmpdir / "spec_model" + ) + + with mock.patch( + "everyvoice.cli._peek_model_class", + return_value="FastSpeech2", + ): + result = self.runner.invoke( + app, + [ + "demo", + str(spec_model_path), + "--vocoder", + str(vocoder_path), + "--ref-speaker", + f"Eric={spec_model_path}", + ], + ) + assert result.exit_code != 0 + assert "--ref-speaker is only used with StyleTTS2" in flatten_log( + result.output + ) + + def test_demo_dispatch_vocoder_checkpoint_as_primary(self): + """Passing a HiFiGAN checkpoint as the primary CHECKPOINT should give a helpful error.""" + with tempfile.TemporaryDirectory() as tmpdir_str: + import torch + + tmpdir = Path(tmpdir_str) + fake_vocoder_ckpt = tmpdir / "hifigan.ckpt" + torch.save( + {"model_info": {"name": "HiFiGAN"}, "state_dict": {}}, + fake_vocoder_ckpt, + ) + + result = self.runner.invoke( + app, + ["demo", str(fake_vocoder_ckpt)], + ) + assert result.exit_code != 0 + assert "appears to be a standalone vocoder checkpoint" in flatten_log( + result.output + ) + def test_rename_speaker(self): with tempfile.TemporaryDirectory() as tmpdir_str: import torch