Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 150 additions & 3 deletions everyvoice/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from rich.panel import Panel

from everyvoice._version import VERSION
from everyvoice.base_cli.checkpoint import inspect, rename_speaker
from everyvoice.base_cli.checkpoint import inspect, load_checkpoint, rename_speaker
from everyvoice.base_cli.interfaces import (
inference_base_command_interface,
typer_directory_option,
typer_file_argument,
typer_file_option,
)
from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel
from everyvoice.model.aligner.wav2vec2aligner.aligner.cli import (
ALIGN_SINGLE_LONG_HELP,
ALIGN_SINGLE_SHORT_HELP,
Expand Down Expand Up @@ -62,7 +63,7 @@
)
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.cli import train as train_hfg
from everyvoice.run_tests import SUITE_NAMES, run_tests
from everyvoice.utils import spinner
from everyvoice.utils import generic_psv_filelist_reader, spinner
from everyvoice.wizard import (
PREPROCESSING_CONFIG_FILENAME_PREFIX,
SPEC_TO_WAV_CONFIG_FILENAME_PREFIX,
Expand Down Expand Up @@ -784,7 +785,7 @@ def demo(
print("\t config loaded")
except Exception as e:
raise typer.BadParameter(
f"Your config file {ui_config_file} has errors\n {e}"
f"Your config file {ui_config_file} has errors.\n {e}"
)
else:
print(" - UI Config file path: None")
Expand Down Expand Up @@ -897,5 +898,151 @@ def g2p(
print(g2p(line))


def require_exactly_one_of(arg1: Any, arg1_name: str, arg2: Any, arg2_name: str):
if arg1 and arg2:
raise typer.BadParameter(
f"Please specify only one of {arg1_name} or {arg2_name}."
)
if not arg1 and not arg2:
raise typer.BadParameter(f"One of {arg1_name} and {arg2_name} is required.")


def open_text_or_psv_file(
text_file: Optional[Path], psv_file: Optional[Path], language: Optional[str]
) -> list[dict[str, str]]:
"""helper for check_text_config: Open a text or psv file into records.

Language is required if not already in the psv

raises: typer.BadParameter if something is wrong"""
if text_file:
with open(text_file, "r", encoding="utf8") as f:
text_lines = list(f)
# print(text_lines)
if language is None:
raise typer.BadParameter("--language is required with --text-file.")
records = [{"characters": line, "language": language} for line in text_lines]
elif psv_file:
records = generic_psv_filelist_reader(psv_file)
if "language" not in records[0]:
if language is None:
raise typer.BadParameter(
"--language is required for a psv file without a language column."
)
for record in records:
record["language"] = language
else:
assert False
return records


def get_text_config_from_config_or_model(config: Optional[Path], model: Optional[Path]):
"""Helper for chec_text_config: load a TextConfig from a config file or model file"""
from everyvoice.config.text_config import TextConfig

if config:
text_config: TextConfig = TextConfig.load_config_from_path(config)
elif model:
with spinner("Loading model"):
checkpoint = load_checkpoint(model)
# print("Looking for text config")
model_config = checkpoint["hyper_parameters"]["config"]
if "text" in model_config:
# Question: FS2 models have text config, do any others have it?
# For other models that have it, are they in the same place in the metadata?
text_config = TextConfig(**model_config["text"])
else:
# Models without text config, e.g., a HiFiGan Vocoder, are not accepted here
raise typer.BadParameter(
f"Model/checkpoint {model} does not have an embedded text configuration."
)
return text_config


@app.command()
def check_text_config(
config: Annotated[
Optional[Path],
typer_file_option(
help="path to text config, i.e., everyvoice-shared-text.yaml"
),
] = None,
model: Annotated[
Optional[Path],
typer_file_option(help="path to a model whose text config will be used"),
] = None,
text_file: Annotated[
Optional[Path],
typer_file_option(help="path to a plain text file to check"),
] = None,
psv_file: Annotated[
Optional[Path],
typer_file_option(help="path to a psv file to check"),
] = None,
language: Annotated[
Optional[str],
typer.Option(
help="language id, required with --text-file, or for a psv file without a language column. "
+ "Declaring the language is always required, because text normalization can be language specific, and g2p is always language specific."
),
] = None,
):
"""
Inspect a text configuration for compatiblity with an input file

Test processing input_file against the text configuration provided, or the text
configuration found in model, and report any incompatibilities.
"""
require_exactly_one_of(config, "--config", model, "--model")
require_exactly_one_of(text_file, "--text-file", psv_file, "--psv-file")
records = open_text_or_psv_file(text_file, psv_file, language)

# Expensive imports are deferred so we fail fast where we can
with spinner("Loading software"):
from everyvoice.config.text_config import TextConfig # noqa F401
from everyvoice.preprocessor.preprocessor import Preprocessor
from everyvoice.text.text_processor import TextProcessor
from everyvoice.text.utils import guess_graphemes_in_text

text_config = get_text_config_from_config_or_model(config, model)
# print(text_config)

text_processor_chars_only = TextProcessor(text_config)
text_processor_all = TextProcessor(text_config)
with spinner("Analyzing text"):
for record in records:
# print(record)
# Process just the text to calculate missing characters
_ = Preprocessor.process_text(
record,
text_processor_chars_only,
specific_text_representation=TargetTrainingTextRepresentationLevel.characters,
)
# Process all to also calculate missing phones
_ = Preprocessor.process_text(record, text_processor_all)

missing_characters = text_processor_chars_only.missing_symbols
missing_phones = text_processor_all.missing_symbols - missing_characters
missing_symbol_groups = list(missing_characters)
for missing_symbol_group in missing_symbol_groups:
split_symbols = guess_graphemes_in_text(missing_symbol_group)
if len(split_symbols) > 1:
count = missing_characters.pop(missing_symbol_group)
for symbol in split_symbols:
missing_characters[symbol] += count
# print("Missing characters", missing_characters)
# print("Missing phones", missing_phones)
if missing_characters:
print(
"The following characters are missing from your text config:",
sorted(missing_characters),
)
if missing_phones:
print(
"The following phones are missing from your text config:",
sorted(missing_phones),
)


if __name__ == "__main__":
app()
49 changes: 30 additions & 19 deletions everyvoice/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,12 @@ def process_text(
Returns:
tuple[Optional[str], Optional[str], Optional[npt.NDArray[np.float32]]]|tuple[Optional[list[int]], Optional[list[int]], Optional[npt.NDArray[np.float32]]]: if encode_as_string is true, returns an optional characters string, an optional phones string, and an optional multi-hot phonological feature vector. if encode_as_string is false, returns a list of ints for characters and phones
"""
if specific_text_representation is not None:
if specific_text_representation not in (
None,
TargetTrainingTextRepresentationLevel.characters,
):
raise NotImplementedError(
"Sorry 'specific_text_representation' isn't implemented yet, please set it to None."
"Sorry 'specific_text_representation' is only implemented for characters, please set it to None or characters."
) # TODO: refactor so that you don't *need* to generate all possible representations, to make synthesis faster.
if text_processor is None:
raise NotImplementedError(
Expand All @@ -785,6 +788,8 @@ def process_text(
if (
DatasetTextRepresentation.arpabet.value in item
and DatasetTextRepresentation.ipa_phones.value not in item
and specific_text_representation
!= TargetTrainingTextRepresentationLevel.characters
):
tokens = text_processor.encode_text(
text=ARPABET_TO_IPA_TRANSDUCER(
Expand Down Expand Up @@ -813,6 +818,8 @@ def process_text(
if (
item["language"] in AVAILABLE_G2P_ENGINES
and DatasetTextRepresentation.ipa_phones.value not in item
and specific_text_representation
!= TargetTrainingTextRepresentationLevel.characters
):
tokens = text_processor.encode_text(
text=item[DatasetTextRepresentation.characters.value],
Expand All @@ -824,23 +831,27 @@ def process_text(
)
assert isinstance(tokens, list)
phone_tokens = tokens
# if dataset is phones
if DatasetTextRepresentation.ipa_phones.value in item:
tokens = text_processor.encode_text(
text=item[DatasetTextRepresentation.ipa_phones.value],
dataset_label=dataset_label,
apply_g2p=False,
encode_as_phonological_features=False,
quiet=True,
)
assert isinstance(tokens, list)
phone_tokens = tokens
# calculate pfs
if phone_tokens and use_pfs:
pfs = text_processor.calculate_phonological_features(
text_processor.token_sequence_to_text_sequence(phone_tokens),
apply_punctuation_rules=True,
)
if (
specific_text_representation
!= TargetTrainingTextRepresentationLevel.characters
):
# if dataset is phones
if DatasetTextRepresentation.ipa_phones.value in item:
tokens = text_processor.encode_text(
text=item[DatasetTextRepresentation.ipa_phones.value],
dataset_label=dataset_label,
apply_g2p=False,
encode_as_phonological_features=False,
quiet=True,
)
assert isinstance(tokens, list)
phone_tokens = tokens
# calculate pfs
if phone_tokens and use_pfs:
pfs = text_processor.calculate_phonological_features(
text_processor.token_sequence_to_text_sequence(phone_tokens),
apply_punctuation_rules=True,
)
# encode to string
if encode_as_string:
if phone_tokens is not None:
Expand Down
Loading