diff --git a/inference.py b/inference.py index 782dcfaa..1c9a515a 100644 --- a/inference.py +++ b/inference.py @@ -5,13 +5,16 @@ import soundfile as sf import torch -import torchaudio import typer import yaml from styletts2.lightning import StyleTTS2Module from styletts2.text_utils import TextCleaner -from styletts2.utils import MEL_MEAN, MEL_STD, length_to_mask, make_mel_transform +from styletts2.utils import ( + _load_reference_mel, + length_to_mask, + make_mel_transform, +) try: from phonemizer.backend import EspeakBackend @@ -40,17 +43,6 @@ def _phonemize(text, language): return result[0] if result else "" -def _load_reference_mel(path, target_sr, mel_transform): - wave, sr = torchaudio.load(path) - wave = wave.mean(0) - if sr != target_sr: - wave = torchaudio.functional.resample(wave, sr, target_sr) - wave = wave.to(next(mel_transform.buffers()).device) - mel = mel_transform(wave) - mel = (torch.log(1e-5 + mel.unsqueeze(0)) - MEL_MEAN) / MEL_STD - return mel # [1, n_mels, T] - - def load_model(config_path, checkpoint_path, mode, device): config = yaml.safe_load(open(config_path)) module = StyleTTS2Module(config, mode=mode) @@ -86,9 +78,7 @@ def synthesize( t_en = module.text_encoder(tokens, input_lengths, text_mask) ref_mel = _load_reference_mel(reference_path, module.sr, mel_transform).to(device) - ref_ss = module.style_encoder(ref_mel.unsqueeze(1)) - ref_sp = module.predictor_encoder(ref_mel.unsqueeze(1)) - ref_s = torch.cat([ref_ss, ref_sp], dim=1) + ref_s = module._encode_reference(ref_mel) noise = torch.randn((1, 256), device=device).unsqueeze(1) s_pred = module._sampler( @@ -238,8 +228,11 @@ def main( typer.echo(f"Skipping malformed line: {line!r}", err=True) continue rows.append((parts[0], parts[1])) - else: + elif text: rows = [("output", text)] + else: + typer.echo("Error: no text provided", err=True) + raise typer.Exit(code=1) for stem, raw_text in rows: if do_phonemize: diff --git a/pyproject.toml b/pyproject.toml index 06c76e48..3093529c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ files = "styletts2" ignore_missing_imports = true plugins = ["pydantic.mypy", "numpy.typing.mypy_plugin"] check_untyped_defs = false +explicit_package_bases = true # --------------------------------------------------------------------------- # pytest diff --git a/styletts2/cli/synthesize.py b/styletts2/cli/synthesize.py new file mode 100644 index 00000000..bca35690 --- /dev/null +++ b/styletts2/cli/synthesize.py @@ -0,0 +1,263 @@ +"""Core synthesis helpers for StyleTTS2. + +Shared by the `everyvoice synthesize text-to-wav` CLI command +and the `everyvoice demo text-to-wav` Gradio app. +""" + +from __future__ import annotations + +from pathlib import Path + +import typer +from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import ( + SynthesizeOutputFormats, +) +from loguru import logger + + +def load_styletts2_model(model_path: Path, device): + """Load a StyleTTS2 Lightning module and mel transform from a checkpoint.""" + import torch + + from .lightning import StyleTTS2Module + from .utils import make_mel_transform + + state = torch.load(model_path, map_location="cpu", weights_only=False) + hp = state.get("hyper_parameters", {}) + native_config = hp["config"] + mode = hp.get("mode", "second") + + module = StyleTTS2Module(native_config, mode=mode) + module.check_and_upgrade_checkpoint(state) + module.load_state_dict(state["state_dict"]) + module.eval() + module.to(device) + + mel_transform = make_mel_transform(native_config).to(device) + return module, mel_transform + + +def load_reference_style( + module, + mel_transform, + reference_path: Path, + device, +): + """Load a reference audio file and return a pre-computed style encoding. + + Runs ``_load_reference_mel`` then ``module._encode_reference``, returning + ``ref_s`` of shape ``[1, 256]`` on ``device``. Call this at startup to + avoid re-computing on every synthesis request. + """ + import torch + + from .utils import ( + _load_reference_mel, + ) + + with torch.no_grad(): + ref_mel = _load_reference_mel(reference_path, module.sr, mel_transform).to( + device + ) + return module._encode_reference(ref_mel) + + +def synthesize_one( + module, + mel_transform, + text: str, + device, + reference_path: Path, + diffusion_steps: int = 5, + embedding_scale: float = 1.0, + acoustic_blend: float = 0.3, + prosody_blend: float = 0.7, +): + """Synthesize a single utterance and return a float32 numpy waveform. + + Works only with stage-2 (or finetune) checkpoints that include the + diffusion sampler. Stage-1 checkpoints will raise an AttributeError + because ``module._sampler`` does not exist. + """ + import torch + + from .text_utils import ( + TextCleaner, + ) + from .utils import ( + _load_reference_mel, + ) + + with torch.no_grad(): + text_cleaner = TextCleaner() + tokens = torch.LongTensor(text_cleaner(text)).unsqueeze(0).to(device) + if tokens.numel() == 0: + raise ValueError(f"Text produced no tokens: {text!r}") + + input_lengths = torch.LongTensor([tokens.size(1)]).to(device) + ref_mel = _load_reference_mel(reference_path, module.sr, mel_transform).to( + device + ) + + return module._synthesize_text( + tokens, + input_lengths, + ref_mel=ref_mel, + diffusion_steps=diffusion_steps, + embedding_scale=embedding_scale, + acoustic_blend=acoustic_blend, + prosody_blend=prosody_blend, + ) + + +# --------------------------------------------------------------------------- +# CLI command +# --------------------------------------------------------------------------- + +app = typer.Typer(pretty_exceptions_show_locals=False) + + +@app.command( + name="text-to-wav", + short_help="Synthesize audio from text using a trained StyleTTS2 model", +) +def synthesize( + model_path: Path = typer.Argument( + ..., + help="Path to a trained StyleTTS2 checkpoint (.ckpt).", + exists=True, + file_okay=True, + dir_okay=False, + ), + reference: Path = typer.Option( + ..., + "--reference", + "-r", + help="Reference audio file used to extract speaker style.", + exists=True, + ), + text: list[str] = typer.Option( + ..., + "--text", + "-t", + help="Text string(s) to synthesize. Repeat the flag for multiple utterances.", + ), + output_dir: Path = typer.Option( + Path("synthesis_output"), + "--output-dir", + "-o", + help="Directory where synthesized files will be written.", + ), + output_type: list[SynthesizeOutputFormats] = typer.Option( + [SynthesizeOutputFormats.wav], + "--output-type", + help="Output format(s) to produce.", + ), + accelerator: str = typer.Option( + "auto", + "--accelerator", + help="Lightning accelerator: 'cpu', 'gpu', or 'auto'.", + ), + speaker: str = typer.Option( + "default", + "--speaker", + "-s", + help="Speaker label written into output filenames.", + ), + language: str = typer.Option( + "und", + "--language", + "-l", + help="Language tag written into output filenames.", + ), + diffusion_steps: int = typer.Option( + 5, + "--diffusion-steps", + help="Number of diffusion sampling steps (higher = slower but smoother).", + ), + embedding_scale: float = typer.Option( + 1.0, + "--embedding-scale", + help="Classifier-free guidance scale for the diffusion sampler.", + ), + acoustic_blend: float = typer.Option( + 0.3, + "--acoustic-blend", + help="Blend weight for acoustic style (0 = pure reference, 1 = pure diffusion).", + ), + prosody_blend: float = typer.Option( + 0.7, + "--prosody-blend", + help="Blend weight for prosody style (0 = pure reference, 1 = pure diffusion).", + ), +): + """Synthesize audio from text using a trained StyleTTS2 model. + + Example: + + **everyvoice synthesize text-to-wav logs_and_checkpoints/.../last.ckpt \\ + --reference path/to/reference.wav \\ + --text "Hello world" --text "How are you?"** + """ + import lightning as L + import torch + from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.utils import ( + truncate_basename, + ) + from everyvoice.utils import slugify + + from .utils_heavy import ( + StyleTTS2SynthesisDataModule, + get_styletts2_synthesis_output_callbacks, + ) + + device = torch.device( + "cuda" + if ( + accelerator == "gpu" + or (accelerator == "auto" and torch.cuda.is_available()) + ) + else "cpu" + ) + + logger.info(f"Loading StyleTTS2 model from {model_path}") + module, mel_transform = load_styletts2_model(model_path, device) + module._mel_transform = mel_transform + + state = torch.load(model_path, map_location="cpu", weights_only=False) + global_step = int(state.get("global_step", 0)) + + entries = [ + { + "raw_text": t, + "basename": truncate_basename(slugify(t)), + "speaker": speaker, + "language": language, + "reference_path": str(reference), + "diffusion_steps": diffusion_steps, + "embedding_scale": embedding_scale, + "acoustic_blend": acoustic_blend, + "prosody_blend": prosody_blend, + } + for t in text + ] + + callbacks = get_styletts2_synthesis_output_callbacks( + output_type, output_dir, global_step, module.sr + ) + if not callbacks: + logger.warning("No output format requested; nothing to do.") + return + + datamodule = StyleTTS2SynthesisDataModule(entries) + + trainer = L.Trainer( + accelerator=accelerator, + callbacks=list(callbacks.values()), + logger=False, + enable_progress_bar=True, + enable_model_summary=False, + ) + trainer.predict(module, datamodule=datamodule) + + logger.info(f"Synthesis complete. Output saved to {output_dir}") diff --git a/styletts2/cli/train.py b/styletts2/cli/train.py index a1d75672..a5dc90a7 100644 --- a/styletts2/cli/train.py +++ b/styletts2/cli/train.py @@ -38,26 +38,30 @@ def train( """Train a StyleTTS2 end-to-end TTS model.""" with spinner(): import torch + if not torch.cuda.is_available(): # device="cuda" is assumed in multiple places, so let's just tell the user up front # It's also pointless to try on CPU if it takes around a week on GPU... - sys.exit("ERROR: StyleTTS2 training requires a GPU with the cuda accellerator") + sys.exit( + "ERROR: StyleTTS2 training requires a GPU with the cuda accellerator" + ) import lightning as L - from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.ev_config import ( + from everyvoice.utils import update_config_from_cli_args + from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint + from lightning.pytorch.loggers import TensorBoardLogger + from lightning.pytorch.strategies import DDPStrategy + + from .ev_config import ( StyleTTS2Config, ) - from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.ev_config.translation import ( + from .ev_config.translation import ( to_native_config, ) - from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.lightning import ( + from .lightning import ( StyleTTS2DataModule, StyleTTS2Module, ) - from everyvoice.utils import update_config_from_cli_args - from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint - from lightning.pytorch.loggers import TensorBoardLogger - from lightning.pytorch.strategies import DDPStrategy config_file: Path = kwargs["config_file"] config_args: list[str] = kwargs.get("config_args", []) diff --git a/styletts2/cli/utils_heavy.py b/styletts2/cli/utils_heavy.py new file mode 100644 index 00000000..5a282ae6 --- /dev/null +++ b/styletts2/cli/utils_heavy.py @@ -0,0 +1,104 @@ +"""Heavy synthesis helpers for StyleTTS2 — imported lazily to keep CLI startup fast.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Sequence + +import lightning as L +import torch +import torchaudio +from everyvoice.base_cli.prediction_writing_callback import ( + BasePredictionWritingCallback, +) +from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import ( + SynthesizeOutputFormats, +) +from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.utils import ( + truncate_basename, +) +from everyvoice.utils import slugify +from loguru import logger +from torch.utils.data import DataLoader, Dataset + + +def _synthesis_collate_fn(batch): + assert len(batch) == 1, "StyleTTS2 synthesis requires batch_size=1" + return batch[0] + + +class StyleTTS2SynthesisDataset(Dataset): + def __init__(self, entries: list[dict]): + self.entries = entries + + def __len__(self): + return len(self.entries) + + def __getitem__(self, idx): + return self.entries[idx] + + +class StyleTTS2SynthesisDataModule(L.LightningDataModule): + def __init__(self, entries: list[dict]): + super().__init__() + self.entries = entries + + def predict_dataloader(self): + return DataLoader( + StyleTTS2SynthesisDataset(self.entries), + batch_size=1, + collate_fn=_synthesis_collate_fn, + shuffle=False, + num_workers=0, + ) + + +class StyleTTS2PredictionWritingWavCallback(BasePredictionWritingCallback): + def __init__(self, save_dir: Path, global_step: int): + super().__init__( + save_dir=save_dir / "wav", + file_extension="pred.wav", + global_step=global_step, + include_global_step_in_filename=True, + ) + self.last_file_written: Optional[str] = None + + def on_predict_batch_end( # pyright: ignore [reportIncompatibleMethodOverride] + self, + _trainer, + _pl_module, + outputs, + batch, + _batch_idx: int, + _dataloader_idx: int = 0, + ): + if outputs is None: + return + wav_tensor = torch.from_numpy(outputs["wav"]).unsqueeze(0) + basename = truncate_basename(slugify(outputs["raw_text"])) + filename = self.get_filename(basename, outputs["speaker"], outputs["language"]) + torchaudio.save( + filename, + wav_tensor, + outputs["sample_rate"], + format="wav", + encoding="PCM_S", + bits_per_sample=16, + ) + self.last_file_written = filename + logger.info(f"Saved WAV: {filename}") + + +def get_styletts2_synthesis_output_callbacks( + output_type: Sequence[SynthesizeOutputFormats], + output_dir: Path, + global_step: int, + sample_rate: int, +) -> dict[SynthesizeOutputFormats, BasePredictionWritingCallback]: + """Build the set of synthesis callbacks for the requested output formats. Only supports wav for now.""" + callbacks: dict[SynthesizeOutputFormats, BasePredictionWritingCallback] = {} + if SynthesizeOutputFormats.wav in output_type: + callbacks[SynthesizeOutputFormats.wav] = StyleTTS2PredictionWritingWavCallback( + output_dir, global_step + ) + return callbacks diff --git a/styletts2/lightning.py b/styletts2/lightning.py index 3bef3206..600a4c17 100644 --- a/styletts2/lightning.py +++ b/styletts2/lightning.py @@ -34,21 +34,16 @@ class StyleTTS2DataModule(L.LightningDataModule): def __init__(self, config, load_for_everyvoice=False): super().__init__() - # Accept either a native dict or a StyleTTS2Config (EveryVoice mode). - from .ev_config import ( - StyleTTS2Config, - ) - from .ev_config.translation import ( - to_native_config, - ) self.load_for_everyvoice = load_for_everyvoice if load_for_everyvoice: - self._ev_text_config = config['ev_config'].text - self._pretrained_symbols = config['ev_config'].pretrained.pretrained_symbols - self._preprocessed_dir = str(config['ev_config'].preprocessing.save_dir) - self._output_sampling_rate = config['ev_config'].preprocessing.audio.output_sampling_rate + self._ev_text_config = config["ev_config"].text + self._pretrained_symbols = config["ev_config"].pretrained.pretrained_symbols + self._preprocessed_dir = str(config["ev_config"].preprocessing.save_dir) + self._output_sampling_rate = config[ + "ev_config" + ].preprocessing.audio.output_sampling_rate self.config = config else: self._ev_text_config = None @@ -119,13 +114,25 @@ class StyleTTS2Module(L.LightningModule): mode: one of ``'first'``, ``'second'``, or ``'finetune'``. """ - def __init__(self, config: dict, mode: str = "first"): + _VERSION: str = "1.0" + + def __init__(self, config: dict | None = None, mode: str = "first"): super().__init__() + assert mode in ("first", "second", "finetune"), f"Unknown mode: {mode}" self.automatic_optimization = False self.config = config self.mode = mode + # If loading from a checkpoint, we have to first load an empty config, and then initialize in on_load_checkpoint + if self.config is not None: + self.initialize_from_config(self.config) + + # ------------------------------------------------------------------ + # Lightning hooks + # ------------------------------------------------------------------ + + def initialize_from_config(self, config): # Core hyper-parameters self.sr = config["preprocess_params"].get("sr", 24000) self.hop_length = config["preprocess_params"]["spect_params"]["hop_length"] @@ -181,11 +188,10 @@ def __init__(self, config: dict, mode: str = "first"): # Running std used for diffusion sigma estimation self._running_std: list[float] = [] - # ------------------------------------------------------------------ - # Lightning hooks - # ------------------------------------------------------------------ - def setup(self, stage=None): + if stage == "predict": + return + p = self._slmadv_params self._slmadv = SLMAdversarialLoss( self, @@ -271,6 +277,57 @@ def setup(self, stage=None): for net in (self.text_aligner, self.text_encoder, self.pitch_extractor): net.requires_grad_(False) + def on_load_checkpoint(self, checkpoint): + """Deserialize the checkpoint hyperparameters.""" + checkpoint = self.check_and_upgrade_checkpoint(checkpoint) + + hp = checkpoint.get("hyper_parameters", {}) + self.config = hp["config"] + self.mode = hp.get("mode", self.mode) + + self.initialize_from_config(self.config) + + def on_save_checkpoint(self, checkpoint): + hp = checkpoint.setdefault("hyper_parameters", {}) + hp["config"] = self.config + hp["mode"] = self.mode + checkpoint["model_info"] = { + "name": self.__class__.__name__, + "version": self._VERSION, + } + + def check_and_upgrade_checkpoint(self, checkpoint): + """ + Check model's compatibility and possibly upgrade. + """ + from packaging.version import Version + + model_info = checkpoint.get( + "model_info", + { + "name": self.__class__.__name__, + "version": "1.0", + }, + ) + + ckpt_model_type = model_info.get("name", "MISSING_TYPE") + if ckpt_model_type != self.__class__.__name__: + raise TypeError( + f"""Wrong model type ({ckpt_model_type}), we are expecting a '{self.__class__.__name__}' model""" + ) + + ckpt_version = Version(model_info.get("version", "0.0")) + if ckpt_version > Version(self._VERSION): + raise ValueError( + "Your model was created with a newer version of EveryVoice, please update your software." + ) + # Successively convert model checkpoints to newer version. + if ckpt_version < Version("1.0"): + # Upgrading from 0.0 to 1.0 requires no changes; future versions might require changes + checkpoint["model_info"]["version"] = "1.0" + + return checkpoint + def configure_optimizers(self): opt_cfg = Munch(self.config["optimizer_params"]) # OneCycleLR with pct_start=0, div_factor=1, final_div_factor=1 is a @@ -430,9 +487,9 @@ def on_validation_epoch_end(self): ref_s = None if self.multispeaker and b["ref_mels"] is not None: ref_mels = b["ref_mels"].to(self.device) - ref_ss = self.style_encoder(ref_mels.unsqueeze(1)) - ref_sp = self.predictor_encoder(ref_mels.unsqueeze(1)) - ref_s = torch.cat([ref_ss, ref_sp], dim=1) + style_enc = self.style_encoder(ref_mels.unsqueeze(1)) + predictor_enc = self.predictor_encoder(ref_mels.unsqueeze(1)) + ref_s = torch.cat([style_enc, predictor_enc], dim=1) t_en = self.text_encoder(texts, input_lengths, text_mask) @@ -523,7 +580,9 @@ def _get_clips( rs = np.random.randint(0, half_len - mel_len) en.append(asr[bib, :, rs : rs + mel_len]) gt.append(mels[bib, :, rs * 2 : (rs + mel_len) * 2]) - y = waves[bib][rs * 2 * self.hop_length : (rs + mel_len) * 2 * self.hop_length] + y = waves[bib][ + rs * 2 * self.hop_length : (rs + mel_len) * 2 * self.hop_length + ] wav.append(torch.from_numpy(y).to(device)) if p is not None: p_en.append(p[bib, :, rs : rs + mel_len]) @@ -729,9 +788,9 @@ def _train_second(self, batch, batch_idx): # noqa: C901 ref = None if self.multispeaker: - ref_ss = self.style_encoder(ref_mels.unsqueeze(1)) - ref_sp = self.predictor_encoder(ref_mels.unsqueeze(1)) - ref = torch.cat([ref_ss, ref_sp], dim=1) + style_enc = self.style_encoder(ref_mels.unsqueeze(1)) + predictor_enc = self.predictor_encoder(ref_mels.unsqueeze(1)) + ref = torch.cat([style_enc, predictor_enc], dim=1) # Per-utterance style (adaptive avgpool prevents batching) ss, gs = [], [] @@ -1001,6 +1060,140 @@ def _train_second(self, batch, batch_idx): # noqa: C901 on_epoch=False, ) + # ------------------------------------------------------------------ + # Inference helpers + # ------------------------------------------------------------------ + + @torch.no_grad() + def _encode_reference(self, ref_mel: "torch.Tensor") -> "torch.Tensor": + """Compute a combined style+predictor encoding from a normalised mel. + + ``ref_mel`` should be shape ``[1, n_mels, T]`` and already on + ``self.device``. Returns ``ref_s`` of shape ``[1, 256]``. + """ + style_enc = self.style_encoder(ref_mel.unsqueeze(1)) + predictor_enc = self.predictor_encoder(ref_mel.unsqueeze(1)) + return torch.cat([style_enc, predictor_enc], dim=1) + + @torch.no_grad() + def _synthesize_text( + self, + tokens: "torch.Tensor", + input_lengths: "torch.Tensor", + ref_mel: "torch.Tensor | None" = None, + diffusion_steps: int = 5, + embedding_scale: float = 1.0, + acoustic_blend: float = 0.3, + prosody_blend: float = 0.7, + ref_s: "torch.Tensor | None" = None, + ): + """Run a single text→waveform forward pass. + + All tensors must already be on ``self.device``. + Exactly one of ``ref_mel`` or ``ref_s`` must be supplied. + Returns a float32 numpy waveform (shape ``[T]``). + """ + if ref_s is None: + assert ref_mel is not None, "Either ref_mel or ref_s must be provided" + ref_s = self._encode_reference(ref_mel) + + text_mask = length_to_mask(input_lengths).to(self.device) + + bert_dur = self.bert(tokens, attention_mask=(~text_mask).int()) + d_en = self.bert_encoder(bert_dur).transpose(-1, -2) + t_en = self.text_encoder(tokens, input_lengths, text_mask) + + noise = torch.randn((1, 256), device=self.device).unsqueeze(1) + s_pred = self._sampler( + noise=noise, + embedding=bert_dur, + embedding_scale=embedding_scale, + num_steps=diffusion_steps, + features=ref_s, + ).squeeze(1) + + ref = acoustic_blend * s_pred[:, :128] + (1 - acoustic_blend) * ref_s[:, :128] + s = prosody_blend * s_pred[:, 128:] + (1 - prosody_blend) * ref_s[:, 128:] + + T = int(input_lengths[0].item()) + tm = text_mask[0, :T].unsqueeze(0) + d = self.predictor.text_encoder( + d_en[0, :, :T].unsqueeze(0), s, input_lengths, tm + ) + x, _ = self.predictor.lstm(d) + duration = torch.sigmoid(self.predictor.duration_proj(x)).sum(axis=-1) + pred_dur = torch.round(duration.squeeze()).clamp(min=1) + if pred_dur.ndim == 0: + pred_dur = pred_dur.unsqueeze(0) + pred_dur[-1] += 5 + + pred_aln = torch.zeros(T, int(pred_dur.sum().item()), device=self.device) + c = 0 + for i in range(T): + pred_aln[i, c : c + int(pred_dur[i].item())] = 1 + c += int(pred_dur[i].item()) + + en = d.transpose(-1, -2) @ pred_aln.unsqueeze(0) + F0_pred, N_pred = self.predictor.F0Ntrain(en, s) + out = self.decoder( + t_en[0, :, :T].unsqueeze(0) @ pred_aln.unsqueeze(0), + F0_pred, + N_pred, + ref.squeeze().unsqueeze(0), + ) + return out.cpu().numpy().squeeze() + + @torch.no_grad() + def predict_step(self, batch, batch_idx): + """Lightning predict step for batch synthesis. + + Expects ``batch`` to be a dict with keys: ``raw_text``, ``basename``, + ``speaker``, ``language``, ``reference_path``, and optional synthesis + control params (``diffusion_steps``, ``embedding_scale``, + ``acoustic_blend``, ``prosody_blend``). + """ + from .text_utils import TextCleaner + from .utils import _load_reference_mel + + device = self.device + raw_text = batch["raw_text"] + + text_cleaner = TextCleaner() + tokens = torch.LongTensor(text_cleaner(raw_text)).unsqueeze(0).to(device) + if tokens.numel() == 0: + return None + + input_lengths = torch.LongTensor([tokens.size(1)]).to(device) + + if not hasattr(self, "_mel_transform"): + from .utils import make_mel_transform + + self._mel_transform = make_mel_transform(self.config).to(device) + + ref_mel = _load_reference_mel( + batch["reference_path"], self.sr, self._mel_transform + ).to(device) + + wav = self._synthesize_text( + tokens, + input_lengths, + ref_mel=ref_mel, + diffusion_steps=batch.get("diffusion_steps", 5), + embedding_scale=batch.get("embedding_scale", 1.0), + acoustic_blend=batch.get("acoustic_blend", 0.3), + prosody_blend=batch.get("prosody_blend", 0.7), + ) + + return { + "wav": wav, + "sample_rate": self.sr, + "basename": batch["basename"], + "speaker": batch["speaker"], + "language": batch["language"], + "raw_text": raw_text, + "duration_seconds": len(wav) / self.sr, + } + # ------------------------------------------------------------------ # Validation # ------------------------------------------------------------------ diff --git a/styletts2/utils.py b/styletts2/utils.py index 83fcb9a7..4345f76e 100644 --- a/styletts2/utils.py +++ b/styletts2/utils.py @@ -21,6 +21,21 @@ def make_mel_transform(config): ) +def _load_reference_mel(path, target_sr, mel_transform): + """Load and normalise a reference audio file into a mel spectrogram. + + Returns a tensor of shape ``[1, n_mels, T]`` on the same device as + ``mel_transform``. + """ + wave, sr = torchaudio.load(path) + wave = wave.mean(0) + if sr != target_sr: + wave = torchaudio.functional.resample(wave, sr, target_sr) + wave = wave.to(next(mel_transform.buffers()).device) + mel = mel_transform(wave) + return (torch.log(1e-5 + mel.unsqueeze(0)) - MEL_MEAN) / MEL_STD + + def maximum_path(neg_cent, mask): """Cython optimized version. neg_cent: [b, t_t, t_s]