Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

import soundfile as sf
import torch
import torchaudio
import typer
import yaml

from styletts2.lightning import StyleTTS2Module
from styletts2.text_utils import TextCleaner
from styletts2.utils import MEL_MEAN, MEL_STD, length_to_mask, make_mel_transform
from styletts2.utils import (
_load_reference_mel,
length_to_mask,
make_mel_transform,
)

try:
from phonemizer.backend import EspeakBackend
Expand Down Expand Up @@ -40,17 +43,6 @@ def _phonemize(text, language):
return result[0] if result else ""


def _load_reference_mel(path, target_sr, mel_transform):
    """Read an audio file and return its normalized log-mel spectrogram.

    The loaded audio is averaged over its first dimension, resampled to
    ``target_sr`` when the file's own rate differs, moved to the device of
    ``mel_transform``'s first buffer and passed through ``mel_transform``.
    The result is ``(log(1e-5 + mel) - MEL_MEAN) / MEL_STD`` with a leading
    dimension added.
    """
    waveform, source_sr = torchaudio.load(path)
    mono = waveform.mean(0)
    if source_sr != target_sr:
        mono = torchaudio.functional.resample(mono, source_sr, target_sr)

    # Run the transform on whatever device its buffers already live on.
    transform_device = next(mel_transform.buffers()).device
    spec = mel_transform(mono.to(transform_device)).unsqueeze(0)

    # The small epsilon keeps the log finite where the spectrogram is zero.
    return (torch.log(1e-5 + spec) - MEL_MEAN) / MEL_STD  # [1, n_mels, T]


def load_model(config_path, checkpoint_path, mode, device):
config = yaml.safe_load(open(config_path))
module = StyleTTS2Module(config, mode=mode)
Expand Down Expand Up @@ -86,9 +78,7 @@ def synthesize(
t_en = module.text_encoder(tokens, input_lengths, text_mask)

ref_mel = _load_reference_mel(reference_path, module.sr, mel_transform).to(device)
ref_ss = module.style_encoder(ref_mel.unsqueeze(1))
ref_sp = module.predictor_encoder(ref_mel.unsqueeze(1))
ref_s = torch.cat([ref_ss, ref_sp], dim=1)
ref_s = module._encode_reference(ref_mel)

noise = torch.randn((1, 256), device=device).unsqueeze(1)
s_pred = module._sampler(
Expand Down Expand Up @@ -238,8 +228,11 @@ def main(
typer.echo(f"Skipping malformed line: {line!r}", err=True)
continue
rows.append((parts[0], parts[1]))
else:
elif text:
rows = [("output", text)]
else:
typer.echo("Error: no text provided", err=True)
raise typer.Exit(code=1)

for stem, raw_text in rows:
if do_phonemize:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ files = "styletts2"
ignore_missing_imports = true
plugins = ["pydantic.mypy", "numpy.typing.mypy_plugin"]
check_untyped_defs = false
explicit_package_bases = true

# ---------------------------------------------------------------------------
# pytest
Expand Down
263 changes: 263 additions & 0 deletions styletts2/cli/synthesize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
"""Core synthesis helpers for StyleTTS2.

Shared by the `everyvoice synthesize text-to-wav` CLI command
and the `everyvoice demo text-to-wav` Gradio app.
"""

from __future__ import annotations

from pathlib import Path

import typer
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import (
SynthesizeOutputFormats,
)
from loguru import logger


def load_styletts2_model(model_path: Path, device):
    """Load a StyleTTS2 Lightning module and mel transform from a checkpoint.

    Args:
        model_path: Path to the checkpoint file read with ``torch.load``.
        device: Device the module and the mel transform are moved to.

    Returns:
        A ``(module, mel_transform)`` tuple; the module is in eval mode.

    Raises:
        KeyError: If the checkpoint does not store
            ``hyper_parameters["config"]`` (or has no ``"state_dict"``).
    """
    import torch

    from .lightning import StyleTTS2Module
    from .utils import make_mel_transform

    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only load checkpoints that come from a trusted source.
    state = torch.load(model_path, map_location="cpu", weights_only=False)

    hparams = state.get("hyper_parameters", {})
    try:
        native_config = hparams["config"]
    except KeyError:
        # Same exception type as before, but say which file and which entry
        # is missing instead of surfacing a bare KeyError('config').
        raise KeyError(
            f"Checkpoint {model_path} has no hyper_parameters['config']; "
            "cannot rebuild the StyleTTS2 module from it."
        ) from None
    mode = hparams.get("mode", "second")

    module = StyleTTS2Module(native_config, mode=mode)
    # Give the module a chance to migrate the checkpoint before the weights
    # are loaded from it.
    module.check_and_upgrade_checkpoint(state)
    module.load_state_dict(state["state_dict"])
    module.eval()
    module.to(device)

    mel_transform = make_mel_transform(native_config).to(device)
    return module, mel_transform


def load_reference_style(
    module,
    mel_transform,
    reference_path: Path,
    device,
):
    """Compute the style encoding of a reference audio file once.

    The reference mel is produced by ``_load_reference_mel`` at the module's
    sample rate, moved to ``device`` and passed to
    ``module._encode_reference``; everything runs with gradients disabled.
    The original author documents the result as ``ref_s`` of shape
    ``[1, 256]`` on ``device``.  Intended to be called at startup so the
    encoding is not recomputed for every synthesis request.
    """
    import torch

    from .utils import _load_reference_mel

    with torch.no_grad():
        mel = _load_reference_mel(reference_path, module.sr, mel_transform)
        mel_on_device = mel.to(device)
        style = module._encode_reference(mel_on_device)
    return style


def synthesize_one(
    module,
    mel_transform,
    text: str,
    device,
    reference_path: Path,
    diffusion_steps: int = 5,
    embedding_scale: float = 1.0,
    acoustic_blend: float = 0.3,
    prosody_blend: float = 0.7,
):
    """Synthesize one utterance from ``text`` in the style of a reference.

    The text is tokenized with ``TextCleaner``, the reference mel is loaded
    from ``reference_path``, and both are handed to
    ``module._synthesize_text`` together with the sampling options.  The
    whole call runs with gradients disabled.

    Returns:
        Whatever ``module._synthesize_text`` returns (documented by the
        original author as a float32 numpy waveform).

    Raises:
        ValueError: If ``text`` yields no tokens.

    Per the original author, only stage-2 (or finetune) checkpoints that
    include the diffusion sampler are supported; stage-1 checkpoints raise
    an ``AttributeError`` because ``module._sampler`` does not exist.
    """
    import torch

    from .text_utils import TextCleaner
    from .utils import _load_reference_mel

    with torch.no_grad():
        token_ids = TextCleaner()(text)
        # Add a leading batch dimension of one before moving to the device.
        tokens = torch.LongTensor(token_ids).unsqueeze(0).to(device)
        if tokens.numel() == 0:
            raise ValueError(f"Text produced no tokens: {text!r}")

        input_lengths = torch.LongTensor([tokens.size(1)]).to(device)

        ref_mel = _load_reference_mel(reference_path, module.sr, mel_transform)
        ref_mel = ref_mel.to(device)

        sampling_options = {
            "diffusion_steps": diffusion_steps,
            "embedding_scale": embedding_scale,
            "acoustic_blend": acoustic_blend,
            "prosody_blend": prosody_blend,
        }
        return module._synthesize_text(
            tokens,
            input_lengths,
            ref_mel=ref_mel,
            **sampling_options,
        )


# ---------------------------------------------------------------------------
# CLI command
# ---------------------------------------------------------------------------

# Typer application that the command(s) below register themselves on;
# local variables are left out of Typer's pretty exception output.
app = typer.Typer(pretty_exceptions_show_locals=False)


@app.command(
    name="text-to-wav",
    short_help="Synthesize audio from text using a trained StyleTTS2 model",
)
def synthesize(
    model_path: Path = typer.Argument(
        ...,
        help="Path to a trained StyleTTS2 checkpoint (.ckpt).",
        exists=True,
        file_okay=True,
        dir_okay=False,
    ),
    reference: Path = typer.Option(
        ...,
        "--reference",
        "-r",
        help="Reference audio file used to extract speaker style.",
        exists=True,
        # Must be a regular file, like the checkpoint above: a directory
        # would satisfy `exists=True` and only fail later at audio load.
        file_okay=True,
        dir_okay=False,
    ),
    text: list[str] = typer.Option(
        ...,
        "--text",
        "-t",
        help="Text string(s) to synthesize. Repeat the flag for multiple utterances.",
    ),
    output_dir: Path = typer.Option(
        Path("synthesis_output"),
        "--output-dir",
        "-o",
        help="Directory where synthesized files will be written.",
    ),
    output_type: list[SynthesizeOutputFormats] = typer.Option(
        [SynthesizeOutputFormats.wav],
        "--output-type",
        help="Output format(s) to produce.",
    ),
    accelerator: str = typer.Option(
        "auto",
        "--accelerator",
        help="Lightning accelerator: 'cpu', 'gpu', or 'auto'.",
    ),
    speaker: str = typer.Option(
        "default",
        "--speaker",
        "-s",
        help="Speaker label written into output filenames.",
    ),
    language: str = typer.Option(
        "und",
        "--language",
        "-l",
        help="Language tag written into output filenames.",
    ),
    diffusion_steps: int = typer.Option(
        5,
        "--diffusion-steps",
        help="Number of diffusion sampling steps (higher = slower but smoother).",
    ),
    embedding_scale: float = typer.Option(
        1.0,
        "--embedding-scale",
        help="Classifier-free guidance scale for the diffusion sampler.",
    ),
    acoustic_blend: float = typer.Option(
        0.3,
        "--acoustic-blend",
        help="Blend weight for acoustic style (0 = pure reference, 1 = pure diffusion).",
    ),
    prosody_blend: float = typer.Option(
        0.7,
        "--prosody-blend",
        help="Blend weight for prosody style (0 = pure reference, 1 = pure diffusion).",
    ),
):
    """Synthesize audio from text using a trained StyleTTS2 model.

    Example:

    **everyvoice synthesize text-to-wav logs_and_checkpoints/.../last.ckpt \\
        --reference path/to/reference.wav \\
        --text "Hello world" --text "How are you?"**
    """
    # Heavy dependencies are imported here rather than at module import time.
    import lightning as L
    import torch
    from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.utils import (
        truncate_basename,
    )
    from everyvoice.utils import slugify

    from .utils_heavy import (
        StyleTTS2SynthesisDataModule,
        get_styletts2_synthesis_output_callbacks,
    )

    # "gpu" always selects CUDA; "auto" selects it only when it is available;
    # any other accelerator value loads the model on the CPU.
    device = torch.device(
        "cuda"
        if (
            accelerator == "gpu"
            or (accelerator == "auto" and torch.cuda.is_available())
        )
        else "cpu"
    )

    logger.info(f"Loading StyleTTS2 model from {model_path}")
    module, mel_transform = load_styletts2_model(model_path, device)
    # Attach the transform to the module; presumably read back during
    # prediction -- confirm against the module's predict step.
    module._mel_transform = mel_transform

    # NOTE: the checkpoint is read a second time here only to recover
    # `global_step`, which load_styletts2_model does not return.
    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # use checkpoints from a trusted source.
    state = torch.load(model_path, map_location="cpu", weights_only=False)
    global_step = int(state.get("global_step", 0))

    # One entry per --text value; every entry shares the same reference and
    # sampling settings.
    entries = [
        {
            "raw_text": t,
            "basename": truncate_basename(slugify(t)),
            "speaker": speaker,
            "language": language,
            "reference_path": str(reference),
            "diffusion_steps": diffusion_steps,
            "embedding_scale": embedding_scale,
            "acoustic_blend": acoustic_blend,
            "prosody_blend": prosody_blend,
        }
        for t in text
    ]

    callbacks = get_styletts2_synthesis_output_callbacks(
        output_type, output_dir, global_step, module.sr
    )
    if not callbacks:
        logger.warning("No output format requested; nothing to do.")
        return

    datamodule = StyleTTS2SynthesisDataModule(entries)

    trainer = L.Trainer(
        accelerator=accelerator,
        callbacks=list(callbacks.values()),
        logger=False,
        enable_progress_bar=True,
        enable_model_summary=False,
    )
    trainer.predict(module, datamodule=datamodule)

    logger.info(f"Synthesis complete. Output saved to {output_dir}")
20 changes: 12 additions & 8 deletions styletts2/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,26 +38,30 @@ def train(
"""Train a StyleTTS2 end-to-end TTS model."""
with spinner():
import torch

if not torch.cuda.is_available():
# device="cuda" is assumed in multiple places, so let's just tell the user up front
# It's also pointless to try on CPU if it takes around a week on GPU...
sys.exit("ERROR: StyleTTS2 training requires a GPU with the cuda accellerator")
sys.exit(
"ERROR: StyleTTS2 training requires a GPU with the cuda accellerator"
)

import lightning as L
from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.ev_config import (
from everyvoice.utils import update_config_from_cli_args
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy

from .ev_config import (
StyleTTS2Config,
)
from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.ev_config.translation import (
from .ev_config.translation import (
to_native_config,
)
from everyvoice.model.e2e.StyleTTS2_lightning.styletts2.lightning import (
from .lightning import (
StyleTTS2DataModule,
StyleTTS2Module,
)
from everyvoice.utils import update_config_from_cli_args
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy

config_file: Path = kwargs["config_file"]
config_args: list[str] = kwargs.get("config_args", [])
Expand Down
Loading