Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions styletts2/ev_config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,87 @@ def _default_pretrained_symbols() -> list[str]:
# ---------------------------------------------------------------------------


class StyleTTS2PretrainedConfig(ConfigModel):
"""Paths to the frozen pretrained models bundled with StyleTTS2."""
class StyleTTS2JDCConfig(ConfigModel):
"""Source for the JDC F0 extractor checkpoint."""

f0_path: PossiblyRelativePath = Field(
default="styletts2/pretrained/jdc/bst.t7",
validate_default=True,
description="Path to the JDC F0 extractor checkpoint.",
repo_id: str = Field(
default="everyvoice/styletts2-jdc-f0",
description="HuggingFace repo ID for the JDC F0 extractor.",
)
asr_config: PossiblyRelativePath = Field(
default="styletts2/pretrained/asr/config.yml",
validate_default=True,
description="Path to the ASR model config.",
filename: str = Field(
default="bst.t7",
description="Filename of the checkpoint within the HuggingFace repo.",
)
asr_path: PossiblyRelativePath = Field(
default="styletts2/pretrained/asr/epoch_00080.pth",
validate_default=True,
description="Path to the ASR model checkpoint.",
local_path: Optional[Path] = Field(
default=None,
description="Local path to the checkpoint. If set, overrides repo_id/filename.",
)
plbert_dir: PossiblyRelativePath = Field(
default="styletts2/pretrained/plbert",
validate_default=True,
description="Directory containing the PLBERT checkpoint and config.",


class StyleTTS2ASRConfig(ConfigModel):
"""Source for the ASR text-aligner checkpoint."""

repo_id: str = Field(
default="everyvoice/styletts2-asr-aligner",
description="HuggingFace repo ID for the ASR text-aligner.",
)
checkpoint_filename: str = Field(
default="epoch_00080.pth",
description="Filename of the model checkpoint within the HuggingFace repo.",
)
config_filename: str = Field(
default="config.yml",
description="Filename of the model config within the HuggingFace repo.",
)
local_checkpoint: Optional[Path] = Field(
default=None,
description="Local path to the checkpoint file. If set, overrides repo_id/checkpoint_filename.",
)
local_config: Optional[Path] = Field(
default=None,
description="Local path to the config file. If set, overrides repo_id/config_filename.",
)


class StyleTTS2PLBERTConfig(ConfigModel):
"""Source for the PLBERT text encoder checkpoint."""

repo_id: str = Field(
default="papercup-ai/multilingual-pl-bert",
description="HuggingFace repo ID for the PLBERT text encoder.",
)
checkpoint_filename: str = Field(
default="step_1000000.t7",
description="Filename of the checkpoint within the HuggingFace repo.",
)
config_filename: str = Field(
default="config.yml",
description="Filename of the model config within the HuggingFace repo.",
)
local_checkpoint: Optional[Path] = Field(
default=None,
description="Local path to the checkpoint file. If set, overrides repo_id/checkpoint_filename.",
)
local_config: Optional[Path] = Field(
default=None,
description="Local path to the config file. If set, overrides repo_id/config_filename.",
)


class StyleTTS2PretrainedConfig(ConfigModel):
"""Sources for the frozen pretrained models used by StyleTTS2."""

f0: StyleTTS2JDCConfig = Field(
default_factory=StyleTTS2JDCConfig,
description="JDC F0 extractor source (HuggingFace repo or local path).",
)
asr: StyleTTS2ASRConfig = Field(
default_factory=StyleTTS2ASRConfig,
description="ASR text-aligner source (HuggingFace repo or local paths).",
)
plbert: StyleTTS2PLBERTConfig = Field(
default_factory=StyleTTS2PLBERTConfig,
description="PLBERT text encoder source (HuggingFace repo or local paths).",
)
pretrained_symbols: list[str] = Field(
default_factory=_default_pretrained_symbols,
Expand Down
9 changes: 4 additions & 5 deletions styletts2/ev_config/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def to_native_config(config: StyleTTS2Config) -> dict:
),
"second_stage_load_pretrained": tr.second_stage_load_pretrained,
"load_only_params": tr.load_only_params,
# --- pretrained backbone paths ---
"F0_path": str(pre.f0_path),
"ASR_config": str(pre.asr_config),
"ASR_path": str(pre.asr_path),
"PLBERT_dir": str(pre.plbert_dir),
# --- pretrained backbone sources ---
"pretrained_f0": pre.f0.model_dump(mode="json"),
"pretrained_asr": pre.asr.model_dump(mode="json"),
"pretrained_plbert": pre.plbert.model_dump(mode="json"),
# --- data ---
"data_params": {
"train_data": str(tr.training_filelist),
Expand Down
30 changes: 22 additions & 8 deletions styletts2/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@

from .dataset import build_dataloader
from .losses import DiscriminatorLoss, GeneratorLoss, MultiResolutionSTFTLoss, WavLMLoss
from .models import build_model, load_ASR_models, load_checkpoint, load_F0_models
from .models import (
build_ASR_model_shape,
build_F0_model_shape,
build_model,
load_ASR_model,
load_checkpoint,
load_F0_model,
)
from .modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule
from .modules.slmadv import SLMAdversarialLoss
from .pretrained.plbert.util import load_plbert
from .pretrained.plbert.util import build_plbert_shape, load_plbert
from .utils import (
get_data_path_list,
get_image,
Expand Down Expand Up @@ -132,7 +139,7 @@ def __init__(self, config: dict | None = None, mode: str = "first"):
# Lightning hooks
# ------------------------------------------------------------------

def initialize_from_config(self, config):
def initialize_from_config(self, config, load_pretrained_weights=True):
# Core hyper-parameters
self.sr = config["preprocess_params"].get("sr", 24000)
self.hop_length = config["preprocess_params"]["spect_params"]["hop_length"]
Expand All @@ -148,10 +155,17 @@ def initialize_from_config(self, config):
self.diff_epoch = getattr(loss_params, "diff_epoch", 0)
self.joint_epoch = getattr(loss_params, "joint_epoch", 0)

# Build pretrained backbones then the full model
text_aligner = load_ASR_models(config["ASR_path"], config["ASR_config"])
pitch_extractor = load_F0_models(config["F0_path"])
plbert = load_plbert(config["PLBERT_dir"])
# Build pretrained backbones then the full model.
# When loading from a checkpoint, skip downloading pretrained weights —
# load_state_dict will overwrite them from the checkpoint anyway.
if load_pretrained_weights:
text_aligner = load_ASR_model(config["pretrained_asr"])
pitch_extractor = load_F0_model(config["pretrained_f0"])
plbert = load_plbert(config["pretrained_plbert"])
else:
text_aligner = build_ASR_model_shape(config["pretrained_asr"])
pitch_extractor = build_F0_model_shape()
plbert = build_plbert_shape(config["pretrained_plbert"])
# TODO: model_params passes an incorrect value for n_symbols in the text embedding
nets = build_model(model_params, text_aligner, pitch_extractor, plbert)
# Register every sub-network as a direct attribute so Lightning / DDP
Expand Down Expand Up @@ -285,7 +299,7 @@ def on_load_checkpoint(self, checkpoint):
self.config = hp["config"]
self.mode = hp.get("mode", self.mode)

self.initialize_from_config(self.config)
self.initialize_from_config(self.config, load_pretrained_weights=False)

def on_save_checkpoint(self, checkpoint):
hp = checkpoint.setdefault("hyper_parameters", {})
Expand Down
59 changes: 39 additions & 20 deletions styletts2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,36 +712,55 @@ def length_to_mask(self, lengths):
return mask


def load_F0_models(path):
# load F0 model
def _resolve_pretrained_file(repo_id: str, filename: str, local_path=None) -> str:
"""Return a local path to a pretrained file, downloading from HuggingFace if needed."""
if local_path is not None:
return str(local_path)
from huggingface_hub import hf_hub_download

return hf_hub_download(repo_id, filename=filename)


def build_F0_model_shape():
"""Construct JDCNet with uninitialized weights (for checkpoint loading)."""
return JDCNet(num_class=1, seq_len=192).train()


def load_F0_model(config: dict):
path = _resolve_pretrained_file(
config["repo_id"], config["filename"], config.get("local_path")
)
F0_model = JDCNet(num_class=1, seq_len=192)
params = torch.load(path, map_location="cpu", weights_only=False)["net"]
F0_model.load_state_dict(params)
_ = F0_model.train()
return F0_model.train()

return F0_model

def build_ASR_model_shape(config: dict):
"""Construct ASRCNN from its config file only, without loading pretrained weights."""
config_path = _resolve_pretrained_file(
config["repo_id"], config["config_filename"], config.get("local_config")
)
with open(config_path) as f:
model_config = yaml.safe_load(f)["model_params"]
return ASRCNN(**model_config).train()

def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
# load ASR model
def _load_config(path):
with open(path) as f:
config = yaml.safe_load(f)
model_config = config["model_params"]
return model_config

def _load_model(model_config, model_path):
model = ASRCNN(**model_config)
params = torch.load(model_path, map_location="cpu", weights_only=False)["model"]
model.load_state_dict(params)
return model
def load_ASR_model(config: dict):
config_path = _resolve_pretrained_file(
config["repo_id"], config["config_filename"], config.get("local_config")
)
model_path = _resolve_pretrained_file(
config["repo_id"], config["checkpoint_filename"], config.get("local_checkpoint")
)

asr_model_config = _load_config(ASR_MODEL_CONFIG)
asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
_ = asr_model.train()
with open(config_path) as f:
model_config = yaml.safe_load(f)["model_params"]

return asr_model
model = ASRCNN(**model_config)
params = torch.load(model_path, map_location="cpu", weights_only=False)["model"]
model.load_state_dict(params)
return model.train()


def build_model(args, text_aligner, pitch_extractor, bert):
Expand Down
29 changes: 0 additions & 29 deletions styletts2/pretrained/asr/config.yml

This file was deleted.

Binary file removed styletts2/pretrained/asr/epoch_00080.pth
Binary file not shown.
Binary file removed styletts2/pretrained/jdc/bst.t7
Binary file not shown.
30 changes: 0 additions & 30 deletions styletts2/pretrained/plbert/config.yml

This file was deleted.

Binary file removed styletts2/pretrained/plbert/step_1000000.t7
Binary file not shown.
50 changes: 27 additions & 23 deletions styletts2/pretrained/plbert/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from collections import OrderedDict

import torch
import yaml
Expand All @@ -7,38 +7,42 @@

class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
# Call the original forward method
outputs = super().forward(*args, **kwargs)

# Only return the last_hidden_state
return outputs.last_hidden_state


def load_plbert(log_dir):
config_path = os.path.join(log_dir, "config.yml")
def _resolve(repo_id: str, filename: str, local_path=None) -> str:
if local_path is not None:
return str(local_path)
from huggingface_hub import hf_hub_download

return hf_hub_download(repo_id, filename=filename)


def build_plbert_shape(config: dict):
"""Construct PLBERT from its config file only, without loading pretrained weights."""
config_path = _resolve(
config["repo_id"], config["config_filename"], config.get("local_config")
)
plbert_config = yaml.safe_load(open(config_path))
albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
return CustomAlbert(albert_base_configuration)


def load_plbert(config: dict):
config_path = _resolve(
config["repo_id"], config["config_filename"], config.get("local_config")
)
ckpt_path = _resolve(
config["repo_id"], config["checkpoint_filename"], config.get("local_checkpoint")
)

plbert_config = yaml.safe_load(open(config_path))
albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
bert = CustomAlbert(albert_base_configuration)

# files = os.listdir(log_dir)
ckpts = []
for f in os.listdir(log_dir):
if f.startswith("step_"):
ckpts.append(f)

iters = [
int(f.split("_")[-1].split(".")[0])
for f in ckpts
if os.path.isfile(os.path.join(log_dir, f))
]
iters = sorted(iters)[-1]

checkpoint = torch.load(
log_dir + "/step_" + str(iters) + ".t7", map_location="cpu", weights_only=False
)
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = checkpoint["net"]
from collections import OrderedDict

new_state_dict = OrderedDict()
for k, v in state_dict.items():
Expand Down