diff --git a/styletts2/ev_config/__init__.py b/styletts2/ev_config/__init__.py index 9b7edd91..5099acea 100644 --- a/styletts2/ev_config/__init__.py +++ b/styletts2/ev_config/__init__.py @@ -32,28 +32,87 @@ def _default_pretrained_symbols() -> list[str]: # --------------------------------------------------------------------------- -class StyleTTS2PretrainedConfig(ConfigModel): - """Paths to the frozen pretrained models bundled with StyleTTS2.""" +class StyleTTS2JDCConfig(ConfigModel): + """Source for the JDC F0 extractor checkpoint.""" - f0_path: PossiblyRelativePath = Field( - default="styletts2/pretrained/jdc/bst.t7", - validate_default=True, - description="Path to the JDC F0 extractor checkpoint.", + repo_id: str = Field( + default="everyvoice/styletts2-jdc-f0", + description="HuggingFace repo ID for the JDC F0 extractor.", ) - asr_config: PossiblyRelativePath = Field( - default="styletts2/pretrained/asr/config.yml", - validate_default=True, - description="Path to the ASR model config.", + filename: str = Field( + default="bst.t7", + description="Filename of the checkpoint within the HuggingFace repo.", ) - asr_path: PossiblyRelativePath = Field( - default="styletts2/pretrained/asr/epoch_00080.pth", - validate_default=True, - description="Path to the ASR model checkpoint.", + local_path: Optional[Path] = Field( + default=None, + description="Local path to the checkpoint. If set, overrides repo_id/filename.", ) - plbert_dir: PossiblyRelativePath = Field( - default="styletts2/pretrained/plbert", - validate_default=True, - description="Directory containing the PLBERT checkpoint and config.", + + +class StyleTTS2ASRConfig(ConfigModel): + """Source for the ASR text-aligner checkpoint.""" + + repo_id: str = Field( + default="everyvoice/styletts2-asr-aligner", + description="HuggingFace repo ID for the ASR text-aligner.", + ) + checkpoint_filename: str = Field( + default="epoch_00080.pth", + description="Filename of the model checkpoint within the HuggingFace repo.", + ) + config_filename: str = Field( + default="config.yml", + description="Filename of the model config within the HuggingFace repo.", + ) + local_checkpoint: Optional[Path] = Field( + default=None, + description="Local path to the checkpoint file. If set, overrides repo_id/checkpoint_filename.", + ) + local_config: Optional[Path] = Field( + default=None, + description="Local path to the config file. If set, overrides repo_id/config_filename.", + ) + + +class StyleTTS2PLBERTConfig(ConfigModel): + """Source for the PLBERT text encoder checkpoint.""" + + repo_id: str = Field( + default="papercup-ai/multilingual-pl-bert", + description="HuggingFace repo ID for the PLBERT text encoder.", + ) + checkpoint_filename: str = Field( + default="step_1000000.t7", + description="Filename of the checkpoint within the HuggingFace repo.", + ) + config_filename: str = Field( + default="config.yml", + description="Filename of the model config within the HuggingFace repo.", + ) + local_checkpoint: Optional[Path] = Field( + default=None, + description="Local path to the checkpoint file. If set, overrides repo_id/checkpoint_filename.", + ) + local_config: Optional[Path] = Field( + default=None, + description="Local path to the config file. If set, overrides repo_id/config_filename.", + ) + + +class StyleTTS2PretrainedConfig(ConfigModel): + """Sources for the frozen pretrained models used by StyleTTS2.""" + + f0: StyleTTS2JDCConfig = Field( + default_factory=StyleTTS2JDCConfig, + description="JDC F0 extractor source (HuggingFace repo or local path).", + ) + asr: StyleTTS2ASRConfig = Field( + default_factory=StyleTTS2ASRConfig, + description="ASR text-aligner source (HuggingFace repo or local paths).", + ) + plbert: StyleTTS2PLBERTConfig = Field( + default_factory=StyleTTS2PLBERTConfig, + description="PLBERT text encoder source (HuggingFace repo or local paths).", ) pretrained_symbols: list[str] = Field( default_factory=_default_pretrained_symbols, diff --git a/styletts2/ev_config/translation.py b/styletts2/ev_config/translation.py index 3359f06a..5f75a50e 100644 --- a/styletts2/ev_config/translation.py +++ b/styletts2/ev_config/translation.py @@ -61,11 +61,10 @@ def to_native_config(config: StyleTTS2Config) -> dict: ), "second_stage_load_pretrained": tr.second_stage_load_pretrained, "load_only_params": tr.load_only_params, - # --- pretrained backbone paths --- - "F0_path": str(pre.f0_path), - "ASR_config": str(pre.asr_config), - "ASR_path": str(pre.asr_path), - "PLBERT_dir": str(pre.plbert_dir), + # --- pretrained backbone sources --- + "pretrained_f0": pre.f0.model_dump(mode="json"), + "pretrained_asr": pre.asr.model_dump(mode="json"), + "pretrained_plbert": pre.plbert.model_dump(mode="json"), # --- data --- "data_params": { "train_data": str(tr.training_filelist), diff --git a/styletts2/lightning.py b/styletts2/lightning.py index 600a4c17..b38bb029 100644 --- a/styletts2/lightning.py +++ b/styletts2/lightning.py @@ -13,10 +13,17 @@ from .dataset import build_dataloader from .losses import DiscriminatorLoss, GeneratorLoss, MultiResolutionSTFTLoss, WavLMLoss -from .models import build_model, load_ASR_models, load_checkpoint, load_F0_models +from .models import ( + build_ASR_model_shape, + build_F0_model_shape, + build_model, + load_ASR_model, + load_checkpoint, + load_F0_model, +) from .modules.diffusion.sampler import ADPM2Sampler, DiffusionSampler, KarrasSchedule from .modules.slmadv import SLMAdversarialLoss -from .pretrained.plbert.util import load_plbert +from .pretrained.plbert.util import build_plbert_shape, load_plbert from .utils import ( get_data_path_list, get_image, @@ -132,7 +139,7 @@ def __init__(self, config: dict | None = None, mode: str = "first"): # Lightning hooks # ------------------------------------------------------------------ - def initialize_from_config(self, config): + def initialize_from_config(self, config, load_pretrained_weights=True): # Core hyper-parameters self.sr = config["preprocess_params"].get("sr", 24000) self.hop_length = config["preprocess_params"]["spect_params"]["hop_length"] @@ -148,10 +155,17 @@ def initialize_from_config(self, config): self.diff_epoch = getattr(loss_params, "diff_epoch", 0) self.joint_epoch = getattr(loss_params, "joint_epoch", 0) - # Build pretrained backbones then the full model - text_aligner = load_ASR_models(config["ASR_path"], config["ASR_config"]) - pitch_extractor = load_F0_models(config["F0_path"]) - plbert = load_plbert(config["PLBERT_dir"]) + # Build pretrained backbones then the full model. + # When loading from a checkpoint, skip downloading pretrained weights — + # load_state_dict will overwrite them from the checkpoint anyway. + if load_pretrained_weights: + text_aligner = load_ASR_model(config["pretrained_asr"]) + pitch_extractor = load_F0_model(config["pretrained_f0"]) + plbert = load_plbert(config["pretrained_plbert"]) + else: + text_aligner = build_ASR_model_shape(config["pretrained_asr"]) + pitch_extractor = build_F0_model_shape() + plbert = build_plbert_shape(config["pretrained_plbert"]) # TODO: model_params passes an incorrect value for n_symbols in the text embedding nets = build_model(model_params, text_aligner, pitch_extractor, plbert) # Register every sub-network as a direct attribute so Lightning / DDP @@ -285,7 +299,7 @@ def on_load_checkpoint(self, checkpoint): self.config = hp["config"] self.mode = hp.get("mode", self.mode) - self.initialize_from_config(self.config) + self.initialize_from_config(self.config, load_pretrained_weights=False) def on_save_checkpoint(self, checkpoint): hp = checkpoint.setdefault("hyper_parameters", {}) diff --git a/styletts2/models.py b/styletts2/models.py index 992f4def..2acf8a3f 100644 --- a/styletts2/models.py +++ b/styletts2/models.py @@ -712,36 +712,55 @@ def length_to_mask(self, lengths): return mask -def load_F0_models(path): - # load F0 model +def _resolve_pretrained_file(repo_id: str, filename: str, local_path=None) -> str: + """Return a local path to a pretrained file, downloading from HuggingFace if needed.""" + if local_path is not None: + return str(local_path) + from huggingface_hub import hf_hub_download + return hf_hub_download(repo_id, filename=filename) + + +def build_F0_model_shape(): + """Construct JDCNet with uninitialized weights (for checkpoint loading).""" + return JDCNet(num_class=1, seq_len=192).train() + + +def load_F0_model(config: dict): + path = _resolve_pretrained_file( + config["repo_id"], config["filename"], config.get("local_path") + ) F0_model = JDCNet(num_class=1, seq_len=192) params = torch.load(path, map_location="cpu", weights_only=False)["net"] F0_model.load_state_dict(params) - _ = F0_model.train() + return F0_model.train() - return F0_model +def build_ASR_model_shape(config: dict): + """Construct ASRCNN from its config file only, without loading pretrained weights.""" + config_path = _resolve_pretrained_file( + config["repo_id"], config["config_filename"], config.get("local_config") + ) + with open(config_path) as f: + model_config = yaml.safe_load(f)["model_params"] + return ASRCNN(**model_config).train() -def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG): - # load ASR model - def _load_config(path): - with open(path) as f: - config = yaml.safe_load(f) - model_config = config["model_params"] - return model_config - def _load_model(model_config, model_path): - model = ASRCNN(**model_config) - params = torch.load(model_path, map_location="cpu", weights_only=False)["model"] - model.load_state_dict(params) - return model +def load_ASR_model(config: dict): + config_path = _resolve_pretrained_file( + config["repo_id"], config["config_filename"], config.get("local_config") + ) + model_path = _resolve_pretrained_file( + config["repo_id"], config["checkpoint_filename"], config.get("local_checkpoint") + ) - asr_model_config = _load_config(ASR_MODEL_CONFIG) - asr_model = _load_model(asr_model_config, ASR_MODEL_PATH) - _ = asr_model.train() + with open(config_path) as f: + model_config = yaml.safe_load(f)["model_params"] - return asr_model + model = ASRCNN(**model_config) + params = torch.load(model_path, map_location="cpu", weights_only=False)["model"] + model.load_state_dict(params) + return model.train() def build_model(args, text_aligner, pitch_extractor, bert): diff --git a/styletts2/pretrained/asr/config.yml b/styletts2/pretrained/asr/config.yml deleted file mode 100644 index ca334a10..00000000 --- a/styletts2/pretrained/asr/config.yml +++ /dev/null @@ -1,29 +0,0 @@ -log_dir: "logs/20201006" -save_freq: 5 -device: "cuda" -epochs: 180 -batch_size: 64 -pretrained_model: "" -train_data: "ASRDataset/train_list.txt" -val_data: "ASRDataset/val_list.txt" - -dataset_params: - data_augmentation: false - -preprocess_parasm: - sr: 24000 - spect_params: - n_fft: 2048 - win_length: 1200 - hop_length: 300 - mel_params: - n_mels: 80 - -model_params: - input_dim: 80 - hidden_dim: 256 - n_token: 178 - token_embedding_dim: 512 - -optimizer_params: - lr: 0.0005 \ No newline at end of file diff --git a/styletts2/pretrained/asr/epoch_00080.pth b/styletts2/pretrained/asr/epoch_00080.pth deleted file mode 100644 index 04e2d802..00000000 Binary files a/styletts2/pretrained/asr/epoch_00080.pth and /dev/null differ diff --git a/styletts2/pretrained/jdc/bst.t7 b/styletts2/pretrained/jdc/bst.t7 deleted file mode 100644 index d6cf419f..00000000 Binary files a/styletts2/pretrained/jdc/bst.t7 and /dev/null differ diff --git a/styletts2/pretrained/plbert/config.yml b/styletts2/pretrained/plbert/config.yml deleted file mode 100644 index 75f60d1e..00000000 --- a/styletts2/pretrained/plbert/config.yml +++ /dev/null @@ -1,30 +0,0 @@ -log_dir: "Checkpoint" -mixed_precision: "fp16" -data_folder: "wikipedia_20220301.en.processed" -batch_size: 192 -save_interval: 5000 -log_interval: 10 -num_process: 1 # number of GPUs -num_steps: 1000000 - -dataset_params: - tokenizer: "transfo-xl-wt103" - token_separator: " " # token used for phoneme separator (space) - token_mask: "M" # token used for phoneme mask (M) - word_separator: 3039 # token used for word separator () - token_maps: "token_maps.pkl" # token map path - - max_mel_length: 512 # max phoneme length - - word_mask_prob: 0.15 # probability to mask the entire word - phoneme_mask_prob: 0.1 # probability to mask each phoneme - replace_prob: 0.2 # probablity to replace phonemes - -model_params: - vocab_size: 178 - hidden_size: 768 - num_attention_heads: 12 - intermediate_size: 2048 - max_position_embeddings: 512 - num_hidden_layers: 12 - dropout: 0.1 \ No newline at end of file diff --git a/styletts2/pretrained/plbert/step_1000000.t7 b/styletts2/pretrained/plbert/step_1000000.t7 deleted file mode 100644 index b34fd8bc..00000000 Binary files a/styletts2/pretrained/plbert/step_1000000.t7 and /dev/null differ diff --git a/styletts2/pretrained/plbert/util.py b/styletts2/pretrained/plbert/util.py index 9605efdf..75c91f5c 100644 --- a/styletts2/pretrained/plbert/util.py +++ b/styletts2/pretrained/plbert/util.py @@ -1,4 +1,4 @@ -import os +from collections import OrderedDict import torch import yaml @@ -7,38 +7,42 @@ class CustomAlbert(AlbertModel): def forward(self, *args, **kwargs): - # Call the original forward method outputs = super().forward(*args, **kwargs) - - # Only return the last_hidden_state return outputs.last_hidden_state -def load_plbert(log_dir): - config_path = os.path.join(log_dir, "config.yml") +def _resolve(repo_id: str, filename: str, local_path=None) -> str: + if local_path is not None: + return str(local_path) + from huggingface_hub import hf_hub_download + + return hf_hub_download(repo_id, filename=filename) + + +def build_plbert_shape(config: dict): + """Construct PLBERT from its config file only, without loading pretrained weights.""" + config_path = _resolve( + config["repo_id"], config["config_filename"], config.get("local_config") + ) plbert_config = yaml.safe_load(open(config_path)) + albert_base_configuration = AlbertConfig(**plbert_config["model_params"]) + return CustomAlbert(albert_base_configuration) + +def load_plbert(config: dict): + config_path = _resolve( + config["repo_id"], config["config_filename"], config.get("local_config") + ) + ckpt_path = _resolve( + config["repo_id"], config["checkpoint_filename"], config.get("local_checkpoint") + ) + + plbert_config = yaml.safe_load(open(config_path)) albert_base_configuration = AlbertConfig(**plbert_config["model_params"]) bert = CustomAlbert(albert_base_configuration) - # files = os.listdir(log_dir) - ckpts = [] - for f in os.listdir(log_dir): - if f.startswith("step_"): - ckpts.append(f) - - iters = [ - int(f.split("_")[-1].split(".")[0]) - for f in ckpts - if os.path.isfile(os.path.join(log_dir, f)) - ] - iters = sorted(iters)[-1] - - checkpoint = torch.load( - log_dir + "/step_" + str(iters) + ".t7", map_location="cpu", weights_only=False - ) + checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) state_dict = checkpoint["net"] - from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items():