Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions docs/DEVELOPERNOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,35 @@

The MUSE project supports Google Cloud's Translation LLM (TLLM) model for machine translation. This requires Google Cloud CLI (gcloud) setup and authentication.

### Prerequisites
### Installing Google Cloud CLI

1. **Install Google Cloud CLI**
Install `gcloud` using the [provided installation guide](https://cloud.google.com/sdk/docs/install).

- Follow instructions at: https://cloud.google.com/sdk/docs/install
- Verify installation: `gcloud --version`
To verify the installation run:

2. **Authenticate with Application Default Credentials**
```bash
gcloud --version
```

```bash
gcloud auth application-default login
```
### Authentication with Application Default Credentials (ADC)

3. **Set required environment variables**
For Google Cloud authentication, we will rely on the ADC file that can be generated with the following command:

```bash
export GOOGLE_CLOUD_PROJECT="cdh-muse"
```
```bash
gcloud auth application-default login
```

The ADC file is written to the following location:
`~/.config/gcloud/application_default_credentials.json`

#### Working with Multiple Google Cloud Projects

If you’ve used `gcloud` for other projects, make sure that your local ADC file corresponds to the correct project. **Switching configs within `gcloud` will not update the ADC file.** However, `gcloud` will provide a warning if the activated (quota/billing) project does not match the one in the ADC file.

To switch quota projects, run:

```bash
gcloud auth application-default set-quota-project [project id]
```

Alternatively, a different credential file may be selected by setting the `GOOGLE_APPLICATION_CREDENTIALS` environment variable. See the [Google ADC guide](https://docs.cloud.google.com/docs/authentication/application-default-credentials) for more information.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"ipython", # Required by transformers for trainer functionality
"transformers[torch, sentencepiece, tiktoken]",
"google-cloud-translate",
"google-auth",
"orjsonl",
"ftfy",
]
Expand Down
60 changes: 25 additions & 35 deletions src/muse/translation/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
from timeit import default_timer as timer

import google.auth
from google.cloud import translate_v3
from transformers import (
AutoModelForCausalLM,
Expand All @@ -25,14 +26,6 @@
# Maximum number of (new) tokens a model can generate
MAX_GEN_LEN = 2048

# Supported models for the unified translate() function.
# Maps each full model identifier (the value callers pass as `model`) to a
# short backend key; translate() looks the identifier up here and dispatches
# on the resulting key to the matching model-specific function.
SUPPORTED_MODELS = {
"tencent/HY-MT1.5-7B": "hymt",
"facebook/nllb-200-3.3B": "nllb",
"google/madlad400-7b-mt": "madlad",
"google/translation-llm": "google_cloud",
}


def hymt_translate(
src_lang: str,
Expand Down Expand Up @@ -242,15 +235,20 @@ def google_cloud_translate(
Translated text as a string

Raises:
ValueError: If GOOGLE_CLOUD_PROJECT environment variable is not set
RuntimeError: If there is an issue loading the Google Application
Default Credentials (ADC)
"""
project_id = os.environ.get("GOOGLE_CLOUD_PROJECT")
if not project_id:
raise ValueError(
"GOOGLE_CLOUD_PROJECT environment variable is not set. "
"Set it with: export GOOGLE_CLOUD_PROJECT='cdh-muse'"
# Get project id from Google Application Default Credentials
try:
_, project_id = google.auth.default()
except Exception as e:
err_msg = (
"Issue loading Application Default Credentials (ADC). "
"See developer notes for more details."
)
raise RuntimeError(err_msg) from e

# Default to us-central1 if GOOGLE_CLOUD_REGION is not set in the environment
region = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

if verbose:
Expand Down Expand Up @@ -300,10 +298,10 @@ def translate(
parameter.

Supported models:
- tencent/HY-MT1.5-7B: Tencent's Hunyuan Translation Model v1.5 (7B)
- facebook/nllb-200-3.3B: Meta's No Language Left Behind (3.3B)
- google/madlad400-7b-mt: Google's MADLAD-400 (7B)
- google/translation-llm: Google Cloud Translation LLM (TLLM)
- hymt: Tencent's Hunyuan Translation Model v1.5 (1.8B)
- madlad: Google's MADLAD-400 (3B)
- nllb: Meta's No Language Left Behind (3.3B)
- google_tllm: Google Cloud Translation LLM (TLLM)

Languages are specified using ISO 639-1 codes (e.g., "zh", "ja", "es", "en").
Language validation is delegated to the model-specific functions, so supported
Expand All @@ -324,23 +322,15 @@ def translate(
ValueError: If the specified model is not supported, or if the source/target
languages are not supported by the chosen model
"""
# Validate model
if model not in SUPPORTED_MODELS:
supported = list(SUPPORTED_MODELS.keys())
raise ValueError(f"Unsupported model: {model}. Supported models: {supported}")

# Route to appropriate model-specific function
model_type = SUPPORTED_MODELS[model]

if model_type == "hymt":
return hymt_translate(src_lang, tgt_lang, text, model, verbose)
elif model_type == "nllb":
return nllb_translate(src_lang, tgt_lang, text, model, verbose)
elif model_type == "madlad":

if model == "hymt":
return hymt_translate(src_lang, tgt_lang, text, verbose=verbose)
elif model == "nllb":
return nllb_translate(src_lang, tgt_lang, text, verbose=verbose)
elif model == "madlad":
# MADLAD does not use src_lang parameter
return madlad_translate(tgt_lang, text, model, verbose)
elif model_type == "google_cloud":
return madlad_translate(tgt_lang, text, verbose=verbose)
elif model == "google_tllm":
return google_cloud_translate(src_lang, tgt_lang, text, verbose=verbose)
else:
# This should never happen if SUPPORTED_MODELS is correctly maintained
raise ValueError(f"Unknown model type: {model_type}")
raise ValueError(f"Unsupported model: {model}")
95 changes: 21 additions & 74 deletions src/muse/translation/translate_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,14 @@
import orjsonl
from tqdm import tqdm

from muse.translation.translate import SUPPORTED_MODELS, translate
from muse.translation.translate import translate

# Required fields in input parallel corpus records
REQUIRED_FIELDS = ["id", "lang", "text", "en_tr"]

logger = logging.getLogger(__name__)


def validate_model(model: str) -> None:
    """
    Check that a model identifier is one of the supported models.

    Returns silently when the identifier is a key of SUPPORTED_MODELS;
    otherwise raises with a message listing every supported identifier.

    Args:
        model: Model identifier to check

    Raises:
        ValueError: If the identifier is not a key of SUPPORTED_MODELS
    """
    # Guard clause: a known model needs no further work.
    if model in SUPPORTED_MODELS:
        return

    known_models = [*SUPPORTED_MODELS]
    raise ValueError(
        f"Unsupported model: {model}. Supported models: {', '.join(known_models)}"
    )


def generate_translation_record(
pair_id: int,
model: str,
Expand Down Expand Up @@ -118,16 +101,19 @@ def generate_translations(
Translation record dicts with fields: tr_id, pair_id, model,
src_lang, tr_lang, src_text, ref_text, tr_text
"""
for record in orjsonl.stream(input_path):
# Validate required fields at record level
missing_fields = [field for field in REQUIRED_FIELDS if field not in record]
if missing_fields:
logger.warning(
f"Skipping record {record.get('id', 'unknown')}: "
f"missing fields {missing_fields}"
)
continue
# Count total records for progress bar
total_records = sum(1 for _ in orjsonl.stream(input_path))

logger.info(f"Found {total_records} records in input file")
logger.info(f"Starting translation with model: {model}")

progress_records = tqdm(
orjsonl.stream(input_path),
total=total_records,
desc="Translating records",
)

for record in progress_records:
# Translation 1: original language → English
try:
src_to_en = generate_translation_record(
Expand All @@ -140,11 +126,11 @@ def generate_translations(
verbose=verbose,
)
yield src_to_en
except Exception as e:
except Exception:
logger.warning(
f"Translation failed for record {record['id']} "
f"({record['lang']}→en): {e}"
f"Translation failed for record {record['id']} ({record['lang']}→en)"
)
raise

# Translation 2: English → original language
try:
Expand All @@ -158,11 +144,11 @@ def generate_translations(
verbose=verbose,
)
yield en_to_src
except Exception as e:
except Exception:
logger.warning(
f"Translation failed for record {record['id']} "
f"(en→{record['lang']}): {e}"
f"Translation failed for record {record['id']} (en→{record['lang']})"
)
raise


def save_translated_corpus(
Expand All @@ -189,31 +175,7 @@ def save_translated_corpus(
model: Model identifier
verbose: If True, print timing and token information during translation
"""
# Count total records for progress bar
total_records = sum(1 for _ in orjsonl.stream(input_path))

logger.info(f"Found {total_records} records in input file")
logger.info(f"Starting translation with model: {model}")

# Generate translations with progress bar
# Each input record produces 2 output records (bidirectional)
translations_generator = generate_translations(input_path, model, verbose)

try:
with tqdm(
total=total_records * 2, desc="Translating", unit="translation"
) as pbar:

def progress_wrapper():
for translation in translations_generator:
pbar.update(1)
yield translation

orjsonl.save(output_path, progress_wrapper())
except KeyboardInterrupt:
logger.warning("\nProcessing interrupted by user")
raise

orjsonl.save(output_path, generate_translations(input_path, model, verbose=verbose))
logger.info(f"Processing complete. Output written to: {output_path}")


Expand Down Expand Up @@ -241,13 +203,6 @@ def main():
level=log_level, format="%(levelname)s: %(message)s", stream=sys.stderr
)

# Validate model early (fail fast)
try:
validate_model(parsed.model)
except ValueError as e:
logger.error(str(e))
sys.exit(1)

# Validate input
if not parsed.input.is_file():
logger.error(f"{parsed.input} does not exist")
Expand All @@ -258,15 +213,7 @@ def main():
sys.exit(1)

# Process corpus
try:
save_translated_corpus(
parsed.input, parsed.output, parsed.model, parsed.verbose
)
except KeyboardInterrupt:
sys.exit(1)
except Exception as e:
logger.error(f"Processing failed: {e}")
sys.exit(1)
save_translated_corpus(parsed.input, parsed.output, parsed.model, parsed.verbose)


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.