50 changes: 48 additions & 2 deletions docs/DEVELOPERNOTES.md
@@ -1,8 +1,8 @@
# Developer Notes

## Google Cloud Translation Setup
## Google Cloud Translation LLM (TLLM) Setup

The MUSE project supports Google Cloud's Translation LLM (TLLM) model for machine translation. This requires Google Cloud CLI (gcloud) setup and authentication.
The MuSE project supports Google Cloud's Translation LLM (TLLM) model for machine translation. This requires Google Cloud CLI (gcloud) setup and authentication.

### Installing Google Cloud CLI

@@ -36,3 +36,49 @@ gcloud auth application-default set-quota-project [project id]
```

Alternatively, a different credential file may be selected by setting the `GOOGLE_APPLICATION_CREDENTIALS` environment variable. See the [Google ADC guide](https://docs.cloud.google.com/docs/authentication/application-default-credentials) for more information.
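
Once ADC is configured, the Translation client picks up credentials automatically. The snippet below is a minimal sketch of that behavior (not part of MuSE itself); it assumes the `google-cloud-translate` package is installed, and `your-project-id` is a placeholder.

```python
from google.cloud import translate_v3

# The client finds Application Default Credentials on its own; no key file is passed.
client = translate_v3.TranslationServiceClient()

# Replace the placeholder project ID with the quota project configured above.
parent = "projects/your-project-id/locations/global"
response = client.get_supported_languages(parent=parent)
print(f"ADC is working; {len(response.languages)} languages supported")
```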

## TranslateGemma Setup

The MuSE project supports Google's TranslateGemma model for machine translation. This requires HuggingFace authentication and license acceptance.

### Installing HuggingFace CLI

Install the HuggingFace CLI using pip:

```bash
pip install huggingface-hub
```

Verify installation:

```bash
hf --version
```

### Generating an Access Token

To generate an access token:

1. Visit [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
2. Click "New token" and select "Read" access type
3. Copy the token and use it with `hf auth login`

For more information on security tokens, see the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
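
If the token needs to be supplied non-interactively (for example on a headless or CI machine), it can also be passed from Python instead of the CLI. This is a hedged sketch assuming the `huggingface-hub` package installed above; the token string is a placeholder.

```python
from huggingface_hub import login

# Equivalent to `hf auth login`: validates the token and caches it locally.
login(token="hf_xxx")  # placeholder; use a real "Read" token from the settings page
```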

### Authentication with HuggingFace CLI

Use the HuggingFace CLI to log in with your access token:

```bash
hf auth login
```

The token is stored in the following location:
`~/.cache/huggingface/token`
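
To confirm that the cached token is valid, run a small check with `huggingface_hub` (a sketch; it simply reads the stored token and queries the Hub):

```python
from huggingface_hub import whoami

# Raises an error if no valid token is found; otherwise reports the account name.
info = whoami()
print("Logged in as:", info["name"])
```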

#### Accepting Model License

For gated models like TranslateGemma, you must accept the license:

1. Visit the model page: [https://huggingface.co/google/translategemma-4b-it](https://huggingface.co/google/translategemma-4b-it)
2. Click "Acknowledge license" to accept the license terms
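
To verify that both authentication and license acceptance worked, the sketch below tries to fetch a single small file from the gated repository (assuming the repo ships a `config.json`, as transformers model repositories do); a `GatedRepoError` or 401 here usually means the license step is still missing.

```python
from huggingface_hub import hf_hub_download

repo = "google/translategemma-4b-it"
try:
    # Downloading even one file requires a valid token and an accepted license.
    path = hf_hub_download(repo_id=repo, filename="config.json")
    print(f"Access to {repo} confirmed: {path}")
except Exception as err:  # e.g. GatedRepoError if the license has not been accepted
    print(f"Cannot access {repo}: {err}")
```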
101 changes: 100 additions & 1 deletion src/muse/translation/translate.py
@@ -5,7 +5,8 @@

The translate() function provides a unified interface for translating text across
multiple models. Model-specific functions (hymt_translate, nllb_translate,
madlad_translate, google_cloud_translate) are also available for direct use.
madlad_translate, gemma_translate, google_cloud_translate) are also available for
direct use.
"""

import os
@@ -15,6 +16,7 @@
from google.cloud import translate_v3
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
@@ -244,6 +246,100 @@ def madlad_translate(
    return tr_text


def gemma_translate(
    src_lang: str,
    tgt_lang: str,
    text: str,
    model_name: str = "google/translategemma-4b-it",
    verbose: bool = False,
) -> str:
    """
    Translate text written in source language to target language with Google's
    TranslateGemma model. Languages are specified with their ISO 639-1 codes.

    By default, the 4B instruction-tuned model (google/translategemma-4b-it) is used,
    but an alternative model may be specified via `model_name`.

    Note: This model requires HuggingFace authentication. See docs/DEVELOPERNOTES.md
    for setup instructions.
    """
    # Get tokenizer and model
    # Load model and tokenizer if it's not the currently loaded model
    if model_name != LOADED_MODEL["model_name"]:
        if verbose:
            start = timer()
        try:
            LOADED_MODEL["model_name"] = model_name
            LOADED_MODEL["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
            LOADED_MODEL["model"] = AutoModelForImageTextToText.from_pretrained(
                model_name
            )
        except Exception as e:
            # Check if error is related to authentication
            error_str = str(e).lower()
            if (
                "401" in error_str
                or "authentication" in error_str
                or "gated" in error_str
            ):
                err_msg = (
                    f"Failed to load model '{model_name}'. This model requires "
                    "HuggingFace authentication. See docs/DEVELOPERNOTES.md for "
                    "setup instructions."
                )
                raise RuntimeError(err_msg) from e
            # Re-raise original exception if not authentication-related
            raise
        if verbose:
            print(f"Loaded tokenizer & model in {timer() - start:.0f} seconds")
    tokenizer = LOADED_MODEL["tokenizer"]
    model = LOADED_MODEL["model"]

    # Generate model input using chat template
    # TranslateGemma requires a specific message format with source/target language codes
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "source_lang_code": src_lang,
                    "target_lang_code": tgt_lang,
                    "text": text,
                }
            ],
        }
    ]
    tokenized_chat = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    input_len = len(tokenized_chat["input_ids"][0])
    if verbose:
        print(f"Input length: {input_len} tokens")

    # Generate translation
    if verbose:
        start = timer()
    outputs = model.generate(
        **tokenized_chat,
        max_new_tokens=get_max_new_tokens(input_len),
    )
    if verbose:
        print(f"Generated model output in {timer() - start:.0f} seconds")
    # Model output begins with initial prompt
    tr_tokens = outputs[0][input_len:]
    if verbose:
        # Report generated output length excluding the prompt prefix
        print(f"Output length: {outputs[0].size()[0] - input_len} tokens")
    tr_text = tokenizer.decode(tr_tokens, skip_special_tokens=True)

    return tr_text
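
# Illustrative usage sketch (comment only, not executed): assuming HuggingFace
# authentication, an accepted license, and enough memory for the ~4B-parameter model,
# the function can be called directly, e.g.
#   gemma_translate("en", "es", "The library opens at nine.")
# or selected through the translate() dispatcher below with model="gemma".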


def google_cloud_translate(
    src_lang: str,
    tgt_lang: str,
@@ -332,6 +428,7 @@ def translate(
    - hymt: Tencent's Hunyuan Translation Model v1.5 (1.8B)
    - madlad: Google's MADLAD-400 (3B)
    - nllb: Meta's No Language Left Behind (3.3B)
    - gemma: Google's TranslateGemma (4B)
    - google_tllm: Google Cloud Translation LLM (TLLM)

    Languages are specified using ISO 639-1 codes (e.g., "zh", "ja", "es", "en").
@@ -361,6 +458,8 @@
elif model == "madlad":
# MADLAD does not use src_lang parameter
return madlad_translate(tgt_lang, text, verbose=verbose)
elif model == "gemma":
return gemma_translate(src_lang, tgt_lang, text, verbose=verbose)
elif model == "google_tllm":
return google_cloud_translate(src_lang, tgt_lang, text, verbose=verbose)
else: