diff --git a/everyvoice/base_cli/checkpoint.py b/everyvoice/base_cli/checkpoint.py index 73410239..33231a4f 100644 --- a/everyvoice/base_cli/checkpoint.py +++ b/everyvoice/base_cli/checkpoint.py @@ -122,6 +122,17 @@ def summarize_hfgl_generator_model(model_path: Path, checkpoint: dict) -> None: print(summary(vocoder_model, None, verbose=0)) +def summarize_styletts2_model(model_path: Path, checkpoint: dict) -> None: + from torchinfo import summary + + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.lightning import ( + StyleTTS2Module, + ) + + model = StyleTTS2Module.load_from_checkpoint(model_path) + print(summary(model, None, verbose=0)) + + def summarize_unknown_model(model_path: Path, checkpoint: dict) -> None: from tabulate import tabulate @@ -194,6 +205,7 @@ def inspect( if show_architecture: checkpoint = load_checkpoint(model_path, minimal=False) + if "model_info" in checkpoint: print( "Inspecting checkpoint according to its model info:", @@ -203,6 +215,7 @@ def inspect( "FastSpeech2": summarize_fs2_model, "HiFiGAN": summarize_hfgl_model, "HiFiGANGenerator": summarize_hfgl_generator_model, + "StyleTTS2Module": summarize_styletts2_model, } summarizer = model_summarizers.get(checkpoint["model_info"]["name"], None) if summarizer: diff --git a/everyvoice/base_cli/prediction_writing_callback.py b/everyvoice/base_cli/prediction_writing_callback.py new file mode 100644 index 00000000..2c79e4f1 --- /dev/null +++ b/everyvoice/base_cli/prediction_writing_callback.py @@ -0,0 +1,41 @@ +"""Generic base for synthesis-output-writing Lightning callbacks. + +Shared by the FS2 and StyleTTS2 prediction-writing callback hierarchies. +Subclasses override ``on_predict_batch_end`` with format-specific logic. +""" + +from __future__ import annotations + +from pathlib import Path + +from pytorch_lightning.callbacks import Callback + + +class BasePredictionWritingCallback(Callback): + """Handles output-directory creation and output-path construction. + + Concrete subclasses must implement ``on_predict_batch_end``. + """ + + def __init__( + self, + save_dir: Path, + file_extension: str, + global_step: int, + include_global_step_in_filename: bool = False, + ) -> None: + super().__init__() + self.file_extension = file_extension + self.global_step = f"ckpt={global_step}" + self.save_dir = save_dir + self.sep = "--" + self.include_global_step_in_filename = include_global_step_in_filename + self.save_dir.mkdir(parents=True, exist_ok=True) + + def get_filename(self, basename: str, speaker: str, language: str) -> str: + name_parts = [basename, speaker, language, self.file_extension] + if self.include_global_step_in_filename: + name_parts.insert(-1, self.global_step) + path = self.save_dir / self.sep.join(name_parts) + path.parent.mkdir(parents=True, exist_ok=True) + return str(path) diff --git a/everyvoice/cli.py b/everyvoice/cli.py index f71c1fe3..ab009fa8 100644 --- a/everyvoice/cli.py +++ b/everyvoice/cli.py @@ -35,6 +35,9 @@ from everyvoice.model.aligner.wav2vec2aligner.aligner.cli import ( extract_segments_from_textgrid, ) +from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.synthesize import ( + synthesize as synthesize_styletts2, +) from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.train import ( train as train_styletts2, ) @@ -64,7 +67,6 @@ synthesize as synthesize_hfg, ) from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.cli import train as train_hfg -from everyvoice.utils import spinner from everyvoice.wizard import ( PREPROCESSING_CONFIG_FILENAME_PREFIX, SPEC_TO_WAV_CONFIG_FILENAME_PREFIX, @@ -573,9 +575,12 @@ def new_project( help=""" # Synthesize Help - - **from-text** --- This is the most common input for performing normal speech synthesis. It will take text or a filelist with text and produce either waveform audio or spectrogram. + - **from-text** --- This is the most common input for performing normal speech synthesis. It will take text or a filelist with text and produce either waveform audio or spectrogram. This option uses FastSpeech2 & HiFiGAN. If you want to do end-to-end synthesis with StyleTTS2, run `everyvoice synthesize text-to-wav` instead. + + - **text-to-wav** --- Synthesize audio directly from text using a trained end-to-end (StyleTTS2) model. Only supports the wav output format. - **from-spec** --- This is the model that turns your spectral features into audio. This type of synthesis is also known as copy synthesis and unless you know what you are doing, you probably don't want to do this. + """, ) @@ -587,6 +592,11 @@ def new_project( name="from-spec", )(synthesize_hfg) +synthesize_group.command( + name="text-to-wav", + short_help="Synthesize audio from text using a trained StyleTTS2 model", +)(synthesize_styletts2) + app.add_typer( synthesize_group, name="synthesize", @@ -648,172 +658,440 @@ def inspect_checkpoint(model_path: Path): [("all", "all")] + [(i.name, i.value) for i in SynthesizeOutputFormats], ) +_VOCODER_CLASS_NAMES = {"HiFiGAN", "HiFiGANGenerator"} +_FS2_CLASS_NAMES = {"FastSpeech2"} +_STYLETTS2_CLASS_NAMES = {"StyleTTS2Module"} -@app.command() + +def _peek_model_class(checkpoint_path: Path) -> str: + """Load a checkpoint header and return the stored model class name. + + Returns an empty string for legacy checkpoints that predate the model_info field. + Raises typer.BadParameter if the file cannot be read as a PyTorch checkpoint. + """ + import torch + + try: + ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + except Exception as e: + raise typer.BadParameter( + f"Could not read checkpoint '{checkpoint_path}': {e}", + param_hint="CHECKPOINT", + ) + return ckpt.get("model_info", {}).get("name", "") + + +def _load_list_file(path: Optional[Path]) -> list[str]: + """Read a plain-text word/utterance list from *path*, one entry per line.""" + if path is None: + return [] + with open(path) as f: + return [line.strip() for line in f if line.strip()] + + +def _parse_ref_speakers(ref_speaker: list[str]) -> dict[str, Path]: + """Parse --ref-speaker 'Name=path' values into a display-name → Path mapping.""" + speakers_dict: dict[str, Path] = {} + for s in ref_speaker: + if "=" not in s: + raise typer.BadParameter( + f"--ref-speaker '{s}' must be in the format 'Display Name=path/to/audio.wav'.", + param_hint="--ref-speaker", + ) + display_name, path_str = s.split("=", 1) + audio_path = Path(path_str.strip()).expanduser() + if not audio_path.exists(): + raise typer.BadParameter( + f"Reference audio file not found: {audio_path}", + param_hint="--ref-speaker", + ) + speakers_dict[display_name.strip()] = audio_path + return speakers_dict + + +def _load_fs2_ui_config(ui_config_file: Optional[Path]) -> "dict | None": + """Read and JSON-parse the FS2 UI config file, or return None if not given.""" + if ui_config_file is None: + print(" - UI Config file: None") + return None + print(f" - UI Config file: {ui_config_file}") + with open(ui_config_file) as f: + try: + ui_config_json = json.load(f) + print("\t config loaded") + return ui_config_json + except Exception as e: + raise typer.BadParameter( + f"Your config file {ui_config_file} has errors\n {e}" + ) + + +def _run_styletts2_demo( + checkpoint: Path, + ref_speaker: list[str], + reference: Optional[Path], + speakers: list[str], + vocoder: Optional[Path], + allowlist: Optional[Path], + denylist: Optional[Path], + allowlist_data: list[str], + denylist_data: list[str], + output_dir: Path, + accelerator: str, + port: int, + share: bool, + server_name: str, +) -> None: + """Validate StyleTTS2-specific options and launch the Gradio demo.""" + if vocoder is not None: + raise typer.BadParameter( + "StyleTTS2 does not use a separate vocoder. Remove --vocoder.", + param_hint="--vocoder", + ) + if speakers != ["all"]: + raise typer.BadParameter( + "StyleTTS2 does not use --speaker as a filter. " + "To define named speakers with reference audio use --ref-speaker 'Name=path/to/audio.wav'.", + param_hint="--speaker", + ) + if not ref_speaker and reference is None: + raise typer.BadParameter( + "Provide at least one --ref-speaker 'Name=path/to/audio.wav' or a --reference path.", + param_hint="--ref-speaker / --reference", + ) + + speakers_dict = _parse_ref_speakers(ref_speaker) + default_reference = reference if not speakers_dict else None + + import torch + + print("INFO - Starting the StyleTTS2 demo with the following parameters:") + print(f" - Checkpoint: {checkpoint}") + try: + _state = torch.load(checkpoint, map_location="cpu", weights_only=False) + _hp = _state.get("hyper_parameters", {}) + if "config" in _hp: + print(" - Checkpoint config:") + print(json.dumps(_hp["config"], indent=4, default=str)) + del _state + except Exception as e: + print(f" - (Could not read checkpoint config: {e})") + if speakers_dict: + for name, path in speakers_dict.items(): + print(f" - Ref speaker: {name} = {path}") + else: + print(f" - Reference: {reference}") + print(f" - Allowlist: {allowlist if allowlist else 'None'}") + print(f" - Denylist: {denylist if denylist else 'None'}") + print(f" - Output Dir: {output_dir}") + print(f" - Accelerator: {accelerator}") + print(f" - Port: {port}") + print(f" - Share: {share}") + print(f" - Server Name: {server_name}") + + from everyvoice.utils import spinner + + with spinner("Loading software"): + from everyvoice.demo.app import create_demo_app_styletts2 + + with spinner("Loading model"): + demo_app = create_demo_app_styletts2( + model_path=checkpoint, + output_dir=output_dir, + speakers=speakers_dict, + default_reference=default_reference, + accelerator=accelerator, + allowlist=allowlist_data, + denylist=denylist_data, + ) + + demo_app.launch( + share=share, + server_port=port, + server_name=server_name, + allowed_paths=[str(output_dir), tempfile.gettempdir()], + ) + + +def _run_fs2_demo( + checkpoint: Path, + vocoder: Optional[Path], + speakers: list[str], + languages: list[str], + outputs: list, + ui_config_file: Optional[Path], + ref_speaker: list[str], + reference: Optional[Path], + allowlist: Optional[Path], + denylist: Optional[Path], + allowlist_data: list[str], + denylist_data: list[str], + output_dir: Path, + accelerator: str, + port: int, + share: bool, + server_name: str, + **kwargs, +) -> None: + """Validate FastSpeech2-specific options and launch the Gradio demo.""" + if vocoder is None: + raise typer.BadParameter( + "FastSpeech2 requires a vocoder checkpoint. " + "Pass it with --vocoder path/to/hifigan.ckpt.", + param_hint="--vocoder", + ) + if ref_speaker: + raise typer.BadParameter( + "--ref-speaker is only used with StyleTTS2 models. " + "To filter FastSpeech2 speakers use --speaker.", + param_hint="--ref-speaker", + ) + if reference is not None: + raise typer.BadParameter( + "--reference is only used with StyleTTS2 models.", + param_hint="--reference", + ) + + print("INFO - Starting the demo with the following parameters:") + print(f" - Checkpoint: {checkpoint}") + print(f" - Vocoder: {vocoder}") + print(f" - Languages: {languages}") + print(f" - Speakers: {speakers}") + print(f" - Outputs: {outputs}") + print(f" - Output Dir: {output_dir}") + print(f" - Accelerator: {accelerator}") + print(f" - Allowlist: {allowlist_data if allowlist else 'None'}") + print(f" - Denylist: {denylist_data if denylist else 'None'}") + print(f" - Port: {port}") + print(f" - Share: {share}") + print(f" - Server Name: {server_name}") + ui_config_json = _load_fs2_ui_config(ui_config_file) + + from everyvoice.utils import spinner + + with spinner("Loading software"): + from everyvoice.demo.app import create_demo_app + + with spinner("Loading models"): + demo_app = create_demo_app( + text_to_spec_model_path=checkpoint, + spec_to_wav_model_path=vocoder, + languages=languages, + speakers=speakers, + outputs=outputs, + output_dir=output_dir, + accelerator=accelerator, + allowlist=allowlist_data, + denylist=denylist_data, + app_ui_config=ui_config_json, + **kwargs, + ) + + demo_app.launch( + share=share, + server_port=port, + server_name=server_name, + allowed_paths=[str(output_dir), tempfile.gettempdir()], + ) + + +@app.command( + name="demo", + short_help="Launch an interactive Gradio demo for any EveryVoice model", +) @merge_args(inference_base_command_interface) def demo( - text_to_spec_model: Annotated[ + checkpoint: Annotated[ Path, typer_file_argument( - help="The path to a trained text-to-spec (i.e., feature prediction) EveryVoice model." + help="Path to a trained EveryVoice checkpoint (.ckpt). " + "The model type is detected automatically from the checkpoint. " + "For FastSpeech2 models also pass --vocoder." ), ], - spec_to_wav_model: Annotated[ - Path, typer_file_argument(help="The path to a trained vocoder.") - ], - allowlist: Annotated[ + # ---- FastSpeech2 options ------------------------------------------------ + vocoder: Annotated[ Optional[Path], typer_file_option( - "--allowlist", - help="A plain text file containing a list of words or utterances to allow synthesizing. Words/utterances should be separated by a new line in a plain text file. All other words are disallowed.", + "--vocoder", + "-V", + help="[FastSpeech2] Path to a trained HiFiGAN vocoder checkpoint. " + "Required when the primary checkpoint is a FastSpeech2 model. " + "Not used with StyleTTS2 checkpoints.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", ), ] = None, - denylist: Annotated[ - Optional[Path], - typer_file_option( - "--denylist", - help="A plain text file containing a list of words or utterances to disallow synthesizing. Words/utterances should be separated by a new line in a plain text file. All other words are allowed. IMPORTANT: there are many ways to 'hack' the denylist that we do not protect against. We suggest using the 'allowlist' instead for maximum security if you know the full list of utterances you want to allow synthesis for.", - ), - ] = None, - languages: list[str] = typer.Option( - ["all"], - "--language", - "-l", - help="Specify languages to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --language eng --language fin", - ), speakers: list[str] = typer.Option( ["all"], "--speaker", "-s", - help="Specify speakers to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --speaker speaker_1 --speaker Sue", + help="[FastSpeech2] Speaker names to expose in the demo UI. " + "Repeat the flag to include multiple speakers. " + "Defaults to all speakers in the model. " + "Not applicable to StyleTTS2 — use --ref-speaker instead.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + languages: list[str] = typer.Option( + ["all"], + "--language", + "-l", + help="[FastSpeech2] Languages to expose in the demo UI. " + "Repeat the flag to include multiple languages. " + "Defaults to all languages in the model.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", ), outputs: list[AllowedDemoOutputFormats] = typer.Option( ["all"], "--output-format", "-O", - help="Specify output formats to be included in the demo. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --output-format wav --output-format readalong-html", + help="[FastSpeech2] Output formats to expose in the demo UI.", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + ui_config_file: Annotated[ + Optional[Path], + typer_file_option( + "--ui-config-file", + "-C", + help="[FastSpeech2] JSON file to override UI labels (app_title, app_description, " + "app_instructions, speakers, languages, input_text_label, " + "duration_multiplier_label, language_label, speaker_label, " + "output_format_label, synthesize_label, file_output_label).", + rich_help_panel="FastSpeech2 (text-to-spec) Options", + ), + ] = None, + # ---- StyleTTS2 options -------------------------------------------------- + ref_speaker: list[str] = typer.Option( + [], + "--ref-speaker", + "-R", + help="[StyleTTS2] Named speaker with reference audio, in the format " + "'Display Name=path/to/audio.wav'. " + "Repeat the flag to add multiple speakers. " + "Their style encodings are pre-computed at startup and shown in a dropdown. " + "Example: --ref-speaker 'Eric=eric.wav' --ref-speaker 'Darlene=darlene.wav'", + rich_help_panel="StyleTTS2 (text-to-wav) Options", + ), + reference: Optional[Path] = typer.Option( + None, + "--reference", + "-r", + help="[StyleTTS2] Default reference audio file that sets the initial speaker style. " + "Use this for reference-upload mode (no speaker dropdown). " + "Use --ref-speaker instead to pre-define named speakers.", + exists=True, + rich_help_panel="StyleTTS2 (text-to-wav) Options", ), + # ---- Shared options ----------------------------------------------------- + allowlist: Annotated[ + Optional[Path], + typer_file_option( + "--allowlist", + help="Plain text file with allowed words or utterances (one per line). " + "All other input is rejected. Cannot be combined with --denylist.", + ), + ] = None, + denylist: Annotated[ + Optional[Path], + typer_file_option( + "--denylist", + help="Plain text file with disallowed words or utterances (one per line). " + "All other input is allowed. Cannot be combined with --allowlist. " + "IMPORTANT: there are many ways to bypass a denylist. " + "Use --allowlist for maximum security.", + ), + ] = None, output_dir: Path = typer_directory_option( "synthesis_output", "--output-dir", "-o", exists=False, - help="The directory where your synthesized audio should be written", + help="Directory where synthesized audio files are written.", ), accelerator: str = typer.Option( "auto", "--accelerator", "-a", - help="Specify the Pytorch Lightning accelerator to use", + help="PyTorch Lightning accelerator (e.g. 'auto', 'cpu', 'gpu').", ), - port: int = typer.Option(7860, "--port", "-p", help="The port to run the demo on."), + port: int = typer.Option(7860, "--port", "-p", help="Port to serve the demo on."), share: bool = typer.Option( False, "--share", - help="Share the demo using Gradio's share feature. This will make the demo accessible from the internet.", + help="Publish the demo via Gradio's share tunnel (accessible from the internet).", ), server_name: str = typer.Option( "0.0.0.0", "--server-name", "-n", - help="The server name to run the demo on. This is useful if you want to run the demo on a specific IP address.", + help="Host/IP address to bind the demo server to.", ), - ui_config_file: Annotated[ - Optional[Path], - typer_file_option( - "--ui-config-file", - "-C", - help="""A path to a configuration file that will be used to override parts of the default configuration for the demo UI. This is useful if you want to override some of the text in the UI. - The config file should be a valid JSON FORMAT. The expected optional values and types are: - - "app_title": string, - "app_description": string, - "app_instructions": string, - "languages": dict ["id":"name"], - "speakers": dict ["id":"name"], - "input_text_label": string, - "duration_multiplier_label": string, - "language_label": string, - "speaker_label": string, - "output_format_label":string, - "synthesize_label": string, - "file_output_label": string - - """, - ), - ] = None, **kwargs, ): - if allowlist and denylist: - raise typer.BadParameter( - "You provided a value for both the allowlist and the denylist but you can only provide one." - ) + """Launch an interactive Gradio demo for any EveryVoice model. - allowlist_data = [] - denylist_data = [] - ui_config_json: dict | None = None + The model type is detected automatically from the checkpoint. + Pass a single checkpoint for **StyleTTS2** (text-to-wav) models: - if allowlist: - with open(allowlist) as f: - allowlist_data = [x.strip() for x in f] + everyvoice demo path/to/styletts2.ckpt --ref-speaker 'Eric=eric.wav' - if denylist: - with open(denylist) as f: - denylist_data = [x.strip() for x in f] + Pass a FastSpeech2 (text-to-spec) checkpoint plus a vocoder (spec-to-wav) for **FastSpeech2 + HiFiGAN** models: - # print the parameters to the console - print("INFO - Starting the demo with the following parameters:") - print(f" - Text-to-Spec Model: {text_to_spec_model}") - print(f" - Spec-to-Wav Model: {spec_to_wav_model}") - print(f" - Languages: {languages}") - print(f" - Speakers: {speakers}") - print(f" - Outputs: {outputs}") - print(f" - Output Directory: {output_dir}") - print(f" - Accelerator: {accelerator}") - print(f" - Allowlist: {allowlist_data if allowlist else 'None'}") - print(f" - Denylist: {denylist_data if denylist else 'None'}") - print(f" - Port: {port}") - print(f" - Share: {share}") - print(f" - Server Name: {server_name}") - if ui_config_file: - print(f" - UI Config file path: {ui_config_file}") - with open(ui_config_file) as f: - try: - ui_config_json = json.load(f) - print("\t config loaded") - except Exception as e: - raise typer.BadParameter( - f"Your config file {ui_config_file} has errors\n {e}" - ) - else: - print(" - UI Config file path: None") + everyvoice demo path/to/fs2.ckpt --vocoder path/to/hifigan.ckpt + """ + if allowlist and denylist: + raise typer.BadParameter( + "Provide either --allowlist or --denylist, not both.", + ) - with spinner("Loading software"): - from everyvoice.demo.app import create_demo_app + model_class = _peek_model_class(checkpoint) - with spinner("Loading models"): - demo = create_demo_app( - text_to_spec_model_path=text_to_spec_model, - spec_to_wav_model_path=spec_to_wav_model, - languages=languages, - speakers=speakers, - outputs=outputs, - output_dir=output_dir, - accelerator=accelerator, - allowlist=allowlist_data, - denylist=denylist_data, - app_ui_config=ui_config_json, - **kwargs, + if model_class in _VOCODER_CLASS_NAMES: + raise typer.BadParameter( + f"'{checkpoint}' appears to be a standalone vocoder checkpoint ({model_class}). " + "Pass your FastSpeech2 checkpoint as the CHECKPOINT argument and provide " + "this vocoder with --vocoder.", + param_hint="CHECKPOINT", + ) + if model_class not in _FS2_CLASS_NAMES | _STYLETTS2_CLASS_NAMES: + raise typer.BadParameter( + f"Unrecognized model type '{model_class}' in checkpoint '{checkpoint}'. " + "Expected a FastSpeech2 or StyleTTS2 checkpoint.", + param_hint="CHECKPOINT", ) - demo.launch( + allowlist_data = _load_list_file(allowlist) + denylist_data = _load_list_file(denylist) + + shared = dict( + allowlist=allowlist, + denylist=denylist, + allowlist_data=allowlist_data, + denylist_data=denylist_data, + output_dir=output_dir, + accelerator=accelerator, + port=port, share=share, - server_port=port, server_name=server_name, - # explicitly give permission to gradio to write to the output directory and temp directory - allowed_paths=[str(output_dir), tempfile.gettempdir()], ) + if model_class in _STYLETTS2_CLASS_NAMES: + _run_styletts2_demo( + checkpoint, ref_speaker, reference, speakers, vocoder, **shared # type: ignore[arg-type] + ) + else: + _run_fs2_demo( + checkpoint, + vocoder, + speakers, + languages, + outputs, + ui_config_file, + ref_speaker, + reference, + **shared, # type: ignore[arg-type] + **kwargs, + ) + # Deferred full initialization to optimize the CLI, but still exposed for unit testing. SCHEMAS_TO_OUTPUT: dict[str, Any] = {} # dict[str, type[BaseModel]] diff --git a/everyvoice/demo/app.py b/everyvoice/demo/app.py index 01abba4c..13ff7554 100644 --- a/everyvoice/demo/app.py +++ b/everyvoice/demo/app.py @@ -474,6 +474,244 @@ def load_model_from_checkpoint( return model, vocoder_model, vocoder_config, device +def synthesize_audio_styletts2( + text: str, + speaker: "str | None", # selected display name from dropdown, or None in reference-upload mode + user_reference, # filepath from Gradio Audio component, or None if not uploaded / cleared + diffusion_steps: int, + embedding_scale: float, + acoustic_blend: float, + prosody_blend: float, + *, + module, + mel_transform, + device, + output_dir: Path, + speaker_ref_s: "dict[str, torch.Tensor]", # pre-computed at startup; empty in reference-only mode + default_ref_s: "torch.Tensor | None", # pre-computed from --reference; None in speaker mode + allowlist: list[str], + denylist: list[str], +) -> str: + """Synthesize one utterance with StyleTTS2 and return the path to the saved WAV.""" + if not text or not text.strip(): + raise gr.Error("Please provide text to synthesize.") + + norm_text = normalize_text(text) + if allowlist and norm_text not in allowlist: + raise gr.Error( + f"The text '{text}' is not allowed to be synthesized by this model. " + "Please contact the model owner." + ) + if denylist: + for word in norm_text.split(): + if word in denylist: + raise gr.Error( + f"The text '{text}' contains a word that is not allowed. " + "Please contact the model owner." + ) + + # Determine ref_s — prefer user-uploaded audio, then pre-loaded speaker, then default + if user_reference is not None: + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.cli.synthesize import ( + load_reference_style, + ) + + try: + ref_s = load_reference_style( + module, mel_transform, Path(user_reference), device + ) + except Exception as e: + raise gr.Error(f"Could not load reference audio: {e}") + elif speaker is not None and speaker in speaker_ref_s: + ref_s = speaker_ref_s[speaker] + elif default_ref_s is not None: + ref_s = default_ref_s + else: + raise gr.Error( + "No reference audio available. Please upload a reference audio file." + ) + + from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.text_utils import ( + TextCleaner, + ) + + text_cleaner = TextCleaner() + tokens = torch.LongTensor(text_cleaner(text)).unsqueeze(0).to(device) + if tokens.numel() == 0: + raise gr.Error(f"Text produced no tokens: {text!r}") + input_lengths = torch.LongTensor([tokens.size(1)]).to(device) + + try: + audio = module._synthesize_text( + tokens, + input_lengths, + ref_s=ref_s, + diffusion_steps=diffusion_steps, + embedding_scale=embedding_scale, + acoustic_blend=acoustic_blend, + prosody_blend=prosody_blend, + ) + except Exception as e: + raise gr.Error(str(e)) + + import soundfile as sf + + out_path = output_dir / (slugify(text[:50]) + ".wav") + sf.write(str(out_path), audio, module.sr) + return str(out_path) + + +def make_gradio_display_styletts2( + synthesize_fn, + speaker_list: "GradioChoices", + default_reference: "Path | None" = None, +) -> "gr.Blocks": + """Build the Gradio Blocks for the StyleTTS2 demo. + + When ``speaker_list`` is non-empty a speaker dropdown is shown and the + reference audio widget becomes an optional style override. When it is + empty the reference audio widget is the primary input (reference-upload + mode) and the ``speaker`` argument is pre-bound as ``None``. + """ + has_speakers = bool(speaker_list) + interactive_speaker = len(speaker_list) > 1 + + with gr.Blocks() as demo: + gr.Markdown("