diff --git a/everyvoice/cli.py b/everyvoice/cli.py index 0d76bc15..df04ece8 100644 --- a/everyvoice/cli.py +++ b/everyvoice/cli.py @@ -15,13 +15,14 @@ from rich.panel import Panel from everyvoice._version import VERSION -from everyvoice.base_cli.checkpoint import inspect, rename_speaker +from everyvoice.base_cli.checkpoint import inspect, load_checkpoint, rename_speaker from everyvoice.base_cli.interfaces import ( inference_base_command_interface, typer_directory_option, typer_file_argument, typer_file_option, ) +from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel from everyvoice.model.aligner.wav2vec2aligner.aligner.cli import ( ALIGN_SINGLE_LONG_HELP, ALIGN_SINGLE_SHORT_HELP, @@ -62,7 +63,7 @@ ) from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.cli import train as train_hfg from everyvoice.run_tests import SUITE_NAMES, run_tests -from everyvoice.utils import spinner +from everyvoice.utils import generic_psv_filelist_reader, spinner from everyvoice.wizard import ( PREPROCESSING_CONFIG_FILENAME_PREFIX, SPEC_TO_WAV_CONFIG_FILENAME_PREFIX, @@ -784,7 +785,7 @@ def demo( print("\t config loaded") except Exception as e: raise typer.BadParameter( - f"Your config file {ui_config_file} has errors\n {e}" + f"Your config file {ui_config_file} has errors.\n {e}" ) else: print(" - UI Config file path: None") @@ -897,5 +898,151 @@ def g2p( print(g2p(line)) +def require_exactly_one_of(arg1: Any, arg1_name: str, arg2: Any, arg2_name: str): + if arg1 and arg2: + raise typer.BadParameter( + f"Please specify only one of {arg1_name} or {arg2_name}." + ) + if not arg1 and not arg2: + raise typer.BadParameter(f"One of {arg1_name} and {arg2_name} is required.") + + +def open_text_or_psv_file( + text_file: Optional[Path], psv_file: Optional[Path], language: Optional[str] +) -> list[dict[str, str]]: + """helper for check_text_config: Open a text or psv file into records. + + Language is required if not already in the psv + + raises: typer.BadParameter if something is wrong""" + if text_file: + with open(text_file, "r", encoding="utf8") as f: + text_lines = list(f) + # print(text_lines) + if language is None: + raise typer.BadParameter("--language is required with --text-file.") + records = [{"characters": line, "language": language} for line in text_lines] + elif psv_file: + records = generic_psv_filelist_reader(psv_file) + if "language" not in records[0]: + if language is None: + raise typer.BadParameter( + "--language is required for a psv file without a language column." + ) + for record in records: + record["language"] = language + else: + assert False + return records + + +def get_text_config_from_config_or_model(config: Optional[Path], model: Optional[Path]): + """Helper for chec_text_config: load a TextConfig from a config file or model file""" + from everyvoice.config.text_config import TextConfig + + if config: + text_config: TextConfig = TextConfig.load_config_from_path(config) + elif model: + with spinner("Loading model"): + checkpoint = load_checkpoint(model) + # print("Looking for text config") + model_config = checkpoint["hyper_parameters"]["config"] + if "text" in model_config: + # Question: FS2 models have text config, do any others have it? + # For other models that have it, are they in the same place in the metadata? + text_config = TextConfig(**model_config["text"]) + else: + # Models without text config, e.g., a HiFiGan Vocoder, are not accepted here + raise typer.BadParameter( + f"Model/checkpoint {model} does not have an embedded text configuration." + ) + return text_config + + +@app.command() +def check_text_config( + config: Annotated[ + Optional[Path], + typer_file_option( + help="path to text config, i.e., everyvoice-shared-text.yaml" + ), + ] = None, + model: Annotated[ + Optional[Path], + typer_file_option(help="path to a model whose text config will be used"), + ] = None, + text_file: Annotated[ + Optional[Path], + typer_file_option(help="path to a plain text file to check"), + ] = None, + psv_file: Annotated[ + Optional[Path], + typer_file_option(help="path to a psv file to check"), + ] = None, + language: Annotated[ + Optional[str], + typer.Option( + help="language id, required with --text-file, or for a psv file without a language column. " + + "Declaring the language is always required, because text normalization can be language specific, and g2p is always language specific." + ), + ] = None, +): + """ + Inspect a text configuration for compatiblity with an input file + + Test processing input_file against the text configuration provided, or the text + configuration found in model, and report any incompatibilities. + """ + require_exactly_one_of(config, "--config", model, "--model") + require_exactly_one_of(text_file, "--text-file", psv_file, "--psv-file") + records = open_text_or_psv_file(text_file, psv_file, language) + + # Expensive imports are deferred so we fail fast where we can + with spinner("Loading software"): + from everyvoice.config.text_config import TextConfig # noqa F401 + from everyvoice.preprocessor.preprocessor import Preprocessor + from everyvoice.text.text_processor import TextProcessor + from everyvoice.text.utils import guess_graphemes_in_text + + text_config = get_text_config_from_config_or_model(config, model) + # print(text_config) + + text_processor_chars_only = TextProcessor(text_config) + text_processor_all = TextProcessor(text_config) + with spinner("Analyzing text"): + for record in records: + # print(record) + # Process just the text to calculate missing characters + _ = Preprocessor.process_text( + record, + text_processor_chars_only, + specific_text_representation=TargetTrainingTextRepresentationLevel.characters, + ) + # Process all to also calculate missing phones + _ = Preprocessor.process_text(record, text_processor_all) + + missing_characters = text_processor_chars_only.missing_symbols + missing_phones = text_processor_all.missing_symbols - missing_characters + missing_symbol_groups = list(missing_characters) + for missing_symbol_group in missing_symbol_groups: + split_symbols = guess_graphemes_in_text(missing_symbol_group) + if len(split_symbols) > 1: + count = missing_characters.pop(missing_symbol_group) + for symbol in split_symbols: + missing_characters[symbol] += count + # print("Missing characters", missing_characters) + # print("Missing phones", missing_phones) + if missing_characters: + print( + "The following characters are missing from your text config:", + sorted(missing_characters), + ) + if missing_phones: + print( + "The following phones are missing from your text config:", + sorted(missing_phones), + ) + + if __name__ == "__main__": app() diff --git a/everyvoice/preprocessor/preprocessor.py b/everyvoice/preprocessor/preprocessor.py index 26e99fa8..8f15b3ad 100644 --- a/everyvoice/preprocessor/preprocessor.py +++ b/everyvoice/preprocessor/preprocessor.py @@ -767,9 +767,12 @@ def process_text( Returns: tuple[Optional[str], Optional[str], Optional[npt.NDArray[np.float32]]]|tuple[Optional[list[int]], Optional[list[int]], Optional[npt.NDArray[np.float32]]]: if encode_as_string is true, returns an optional characters string, an optional phones string, and an optional multi-hot phonological feature vector. if encode_as_string is false, returns a list of ints for characters and phones """ - if specific_text_representation is not None: + if specific_text_representation not in ( + None, + TargetTrainingTextRepresentationLevel.characters, + ): raise NotImplementedError( - "Sorry 'specific_text_representation' isn't implemented yet, please set it to None." + "Sorry 'specific_text_representation' is only implemented for characters, please set it to None or characters." ) # TODO: refactor so that you don't *need* to generate all possible representations, to make synthesis faster. if text_processor is None: raise NotImplementedError( @@ -785,6 +788,8 @@ def process_text( if ( DatasetTextRepresentation.arpabet.value in item and DatasetTextRepresentation.ipa_phones.value not in item + and specific_text_representation + != TargetTrainingTextRepresentationLevel.characters ): tokens = text_processor.encode_text( text=ARPABET_TO_IPA_TRANSDUCER( @@ -813,6 +818,8 @@ def process_text( if ( item["language"] in AVAILABLE_G2P_ENGINES and DatasetTextRepresentation.ipa_phones.value not in item + and specific_text_representation + != TargetTrainingTextRepresentationLevel.characters ): tokens = text_processor.encode_text( text=item[DatasetTextRepresentation.characters.value], @@ -824,23 +831,27 @@ def process_text( ) assert isinstance(tokens, list) phone_tokens = tokens - # if dataset is phones - if DatasetTextRepresentation.ipa_phones.value in item: - tokens = text_processor.encode_text( - text=item[DatasetTextRepresentation.ipa_phones.value], - dataset_label=dataset_label, - apply_g2p=False, - encode_as_phonological_features=False, - quiet=True, - ) - assert isinstance(tokens, list) - phone_tokens = tokens - # calculate pfs - if phone_tokens and use_pfs: - pfs = text_processor.calculate_phonological_features( - text_processor.token_sequence_to_text_sequence(phone_tokens), - apply_punctuation_rules=True, - ) + if ( + specific_text_representation + != TargetTrainingTextRepresentationLevel.characters + ): + # if dataset is phones + if DatasetTextRepresentation.ipa_phones.value in item: + tokens = text_processor.encode_text( + text=item[DatasetTextRepresentation.ipa_phones.value], + dataset_label=dataset_label, + apply_g2p=False, + encode_as_phonological_features=False, + quiet=True, + ) + assert isinstance(tokens, list) + phone_tokens = tokens + # calculate pfs + if phone_tokens and use_pfs: + pfs = text_processor.calculate_phonological_features( + text_processor.token_sequence_to_text_sequence(phone_tokens), + apply_punctuation_rules=True, + ) # encode to string if encode_as_string: if phone_tokens is not None: