Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/muse/translation/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
from muse.translation.hymt_langs import lang_idx_zh as hymt_lang_idx_zh
from muse.translation.nllb_langs import lang_index as nllb_lang_idx

# Maximum number of (new) tokens a model can generate
MAX_GEN_LEN = 2048

# Workaround to reuse loaded models / tokenizers
LOADED_MODEL = {
"model_name": None,
Expand All @@ -34,6 +31,17 @@
}


def get_max_new_tokens(input_token_len: int) -> int:
    """
    Return the cap on newly generated tokens for a model call.

    Shared by all HuggingFace translate functions to bound generation
    relative to the size of the tokenized input. The current policy is
    simply twice the input length.
    """
    return input_token_len * 2


def hymt_translate(
src_lang: str,
tgt_lang: str,
Expand Down Expand Up @@ -101,7 +109,8 @@ def hymt_translate(
if verbose:
start = timer()
outputs = model.generate(
tokenized_chat.to(model.device), max_new_tokens=MAX_GEN_LEN
tokenized_chat.to(model.device),
max_new_tokens=get_max_new_tokens(input_len),
)
if verbose:
print(f"Generated model output in {timer() - start:.0f} seconds")
Expand Down Expand Up @@ -164,7 +173,7 @@ def nllb_translate(
outputs = model.generate(
**model_inputs,
forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_lang_idx[tgt_lang]),
max_length=MAX_GEN_LEN,
max_new_tokens=get_max_new_tokens(input_len),
)
if verbose:
print(f"Generated model output in {timer() - start:.0f} seconds")
Expand Down Expand Up @@ -222,7 +231,7 @@ def madlad_translate(
start = timer()
outputs = model.generate(
**model_inputs,
max_length=MAX_GEN_LEN,
max_new_tokens=get_max_new_tokens(input_len),
)
if verbose:
print(f"Generated model output in {timer() - start:.0f} seconds")
Expand Down
129 changes: 81 additions & 48 deletions src/muse/translation/translate_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
model, src_lang, tr_lang, src_text, ref_text, tr_text.

Usage:
translate_corpus.py MODEL INPUT OUTPUT [--verbose]
translate_corpus.py MODEL INPUT OUTPUT [--resume] [--verbose]
"""

import argparse
Expand Down Expand Up @@ -37,6 +37,7 @@ def generate_translation_record(
tgt_lang: str,
src_text: str,
ref_text: str,
resume: bool = False,
verbose: bool = False,
) -> dict[str, str]:
"""
Expand Down Expand Up @@ -82,24 +83,18 @@ def generate_translation_record(
def generate_translations(
input_path: pathlib.Path,
model: str,
skip_set: set[str] | None = None,
verbose: bool = False,
) -> Iterator[dict[str, str]]:
"""
Generate translation records from parallel corpus.
Generate translation records from parallel corpus. Optionally, can provide
a set of translations to skip with a translation represented by a string
with the following form: "[pair id]:[src_lang]-[tr_lang]".

Yields translation records one at a time for memory efficiency.
For each input record, generates two translations:
1. Original language → English
2. English → Original language

Args:
input_path: Path to input parallel corpus JSONL file
model: Model identifier
verbose: If True, print timing and token information during translation

Yields:
Translation record dicts with fields: tr_id, pair_id, model,
src_lang, tr_lang, src_text, ref_text, tr_text
"""
# Count total records for progress bar
total_records = sum(1 for _ in orjsonl.stream(input_path))
Expand All @@ -115,46 +110,53 @@ def generate_translations(

for record in progress_records:
# Translation 1: original language → English
try:
src_to_en = generate_translation_record(
pair_id=record["id"],
model=model,
src_lang=record["lang"],
tgt_lang="en",
src_text=record["text"],
ref_text=record["en_tr"],
verbose=verbose,
)
yield src_to_en
except Exception:
logger.warning(
f"Translation failed for record {record['id']} ({record['lang']}→en)"
)
raise
if skip_set and f"{record['id']}:{record['lang']}-en" in skip_set:
logger.debug(f"Skipping {record['lang']}-en translation for {record['id']}")
else:
try:
src_to_en = generate_translation_record(
pair_id=record["id"],
model=model,
src_lang=record["lang"],
tgt_lang="en",
src_text=record["text"],
ref_text=record["en_tr"],
verbose=verbose,
)
yield src_to_en
except Exception:
logger.warning(
f"Translation failed for record {record['id']} ({record['lang']}→en)"
)
raise

# Translation 2: English → original language
try:
en_to_src = generate_translation_record(
pair_id=record["id"],
model=model,
src_lang="en",
tgt_lang=record["lang"],
src_text=record["en_tr"],
ref_text=record["text"],
verbose=verbose,
)
yield en_to_src
except Exception:
logger.warning(
f"Translation failed for record {record['id']} (en→{record['lang']})"
)
raise
if skip_set and f"{record['id']}:en-{record['lang']}" in skip_set:
logger.debug(f"Skipping en-{record['lang']} translation for {record['id']}")
else:
try:
en_to_src = generate_translation_record(
pair_id=record["id"],
model=model,
src_lang="en",
tgt_lang=record["lang"],
src_text=record["en_tr"],
ref_text=record["text"],
verbose=verbose,
)
yield en_to_src
except Exception:
logger.warning(
f"Translation failed for record {record['id']} (en→{record['lang']})"
)
raise


def save_translated_corpus(
input_path: pathlib.Path,
output_path: pathlib.Path,
model: str,
resume: bool = False,
verbose: bool = False,
) -> None:
"""
Expand All @@ -175,7 +177,24 @@ def save_translated_corpus(
model: Model identifier
verbose: If True, print timing and token information during translation
"""
orjsonl.save(output_path, generate_translations(input_path, model, verbose=verbose))
if resume:
# Identify completed translations
completed_tr = set()
for tr_record in orjsonl.stream(output_path):
completed_tr.add(
f"{tr_record['pair_id']}:{tr_record['src_lang']}-{tr_record['tr_lang']}"
)
# Append results to existing output corpus
orjsonl.extend(
output_path,
generate_translations(
input_path, model, skip_set=completed_tr, verbose=verbose
),
)
else:
orjsonl.save(
output_path, generate_translations(input_path, model, verbose=verbose)
)
logger.info(f"Processing complete. Output written to: {output_path}")


Expand All @@ -194,7 +213,13 @@ def main():
args.add_argument(
"output", type=pathlib.Path, help="Output machine translation corpus JSONL file"
)
args.add_argument(
"--resume",
action="store_true",
help="Resume translation. Output machine translation corpus JSONL file must exist.",
)
args.add_argument("--verbose", action="store_true", help="Enable verbose output")

parsed = args.parse_args()

# Setup logging
Expand All @@ -207,13 +232,21 @@ def main():
if not parsed.input.is_file():
logger.error(f"{parsed.input} does not exist")
sys.exit(1)

if parsed.output.is_file():
logger.error(f"{parsed.output} exists. Not overwriting")
if not parsed.resume and parsed.output.is_file():
logger.error(f"{parsed.output} exists. Not overwriting.")
sys.exit(1)
if parsed.resume and not parsed.output.is_file():
logger.error(f"{parsed.output} does not exist. Nothing to resume from.")
sys.exit(1)

# Process corpus
save_translated_corpus(parsed.input, parsed.output, parsed.model, parsed.verbose)
save_translated_corpus(
parsed.input,
parsed.output,
parsed.model,
resume=parsed.resume,
verbose=parsed.verbose,
)


if __name__ == "__main__":
Expand Down