distributed llama pre-training workflow on trainium #230

Draft
samhita-alla wants to merge 3 commits into main from add-trainium-example

samhita-alla (Collaborator) commented Mar 25, 2026

This PR adds a complete Flyte 2.0 workflow for running distributed training on an EKS-deployed Union 2.0 backend.

The example demonstrates LLaMA 3.1 8B pre-training on the FineWeb dataset using AWS Trainium.

What this training pipeline enables

Distributed training on Trainium

The training task is configured with the Elastic plugin. Extending it to multi-node training only requires changing nnodes.

import os

import flyte

# Elastic, TRAINIUM_RESOURCES, NNODES, and NPROC_PER_NODE are defined elsewhere in the example.
trainium_env = flyte.TaskEnvironment(
    name="llama-trainium-training",
    # TODO: Fix the builder to support unpacking high-UID files so we can switch to a
    # from_base + with_pip_packages style.
    image=flyte.Image.from_base(image_uri=os.getenv("TRAINIUM_IMAGE_URI")),
    resources=TRAINIUM_RESOURCES,
    # nnodes * nproc_per_node worker processes are launched in total.
    plugin_config=Elastic(
        nnodes=NNODES,
        nproc_per_node=NPROC_PER_NODE,
        neuron_parallel_compile=True,
    ),
    env_vars={
        "MALLOC_ARENA_MAX": "64",
        "NEURON_CC_FLAGS": "--model-type transformer --cache_dir /tmp/neuron_compile_cache",
        "NEURON_FUSE_SOFTMAX": "1",
        "NEURON_RT_STOCHASTIC_ROUNDING_EN": "0",
        "NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS": "3",
        "NEURON_COMPILE_CACHE_URL": "/tmp/neuron_compile_cache",
    },
    # Exposes the Hugging Face token to the task as the HF_TOKEN environment variable.
    secrets=[flyte.Secret(key="samhita_hf_key", as_env_var="HF_TOKEN")],
    cache="auto",
)
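
For intuition on how nnodes and nproc_per_node combine, here is a minimal sketch. It assumes the usual torchrun-style convention (world size is nnodes times nproc_per_node, global rank is node_rank times nproc_per_node plus local_rank). The helper names are hypothetical and are not part of the plugin or this PR.

```python
def world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of worker processes across all nodes."""
    if nnodes < 1 or nproc_per_node < 1:
        raise ValueError("nnodes and nproc_per_node must be positive")
    return nnodes * nproc_per_node


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """Global rank of a worker, assuming ranks are assigned node by node."""
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    if node_rank < 0:
        raise ValueError("node_rank must be non-negative")
    return node_rank * nproc_per_node + local_rank
```

Under these assumptions, going from one node to two with the same nproc_per_node doubles the world size, and the first worker on the second node gets global rank nproc_per_node.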

Cached data preprocessing

The preprocessing task stores preprocessed PyTorch tensors in S3 and returns the path. When the same inputs are used, the task fully resolves from cache.
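
To illustrate the caching behavior described above, here is a rough sketch of input-keyed caching where the cached value is just a storage path. This is not how Flyte implements its cache; `cache_key` and `PathCache` are hypothetical names, and the in-memory dict stands in for a real cache backend.

```python
import hashlib
import json


def cache_key(task_name: str, version: str, inputs: dict) -> str:
    """Derive a deterministic key from the task identity and its inputs."""
    payload = json.dumps(
        {"task": task_name, "version": version, "inputs": inputs},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


class PathCache:
    """Maps a cache key to the output path returned by a task."""

    def __init__(self) -> None:
        self._store: dict[str, str] = {}
        self.misses = 0

    def run(self, task_name: str, version: str, inputs: dict, fn) -> str:
        key = cache_key(task_name, version, inputs)
        if key not in self._store:
            # Cache miss: run the (expensive) task and remember its output path.
            self.misses += 1
            self._store[key] = fn(**inputs)
        return self._store[key]
```

Because only the path is cached, a hit skips the preprocessing work entirely while the tensors themselves stay in object storage.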

Drop-in distributed setup with the PyTorch Neuron SDK

The training task uses the Neuron SDK and defaults to a quick setup, but any configuration can be swapped in directly from the UI if desired.

Real-time metrics streaming in the UI

Loss curves and custom dashboards update live as training progresses. CPU and memory metrics already appear in the UI; Trainium utilization metrics will be supported soon.

Full visibility across the entire pipeline

Inspect inputs/outputs, view logs for both leader and worker processes, and trace every step end-to-end.

Built-in caching, retries, and error handling

Training tasks can be cached, retried at the task level, or retried via exception handling inside the task.
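
As an illustration of the third option (retrying via exception handling inside the task), here is a minimal sketch of a bounded retry loop with exponential backoff. `run_with_retries` is a hypothetical helper, not part of Flyte or this PR, and the choice of `RuntimeError` as the retryable exception is an assumption.

```python
import time


def run_with_retries(step_fn, max_attempts=3, base_delay_s=0.0, retry_on=(RuntimeError,)):
    """Call step_fn, retrying on the given exception types up to max_attempts times."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return step_fn()
        except retry_on:
            if attempt == max_attempts:
                # Out of attempts: let the failure propagate to the task level.
                raise
            # Back off before the next attempt: base, 2x base, 4x base, ...
            time.sleep(base_delay_s * 2 ** (attempt - 1))
```

Re-raising on the final attempt keeps task-level retries usable as a second line of defense.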

Native AWS integrations

S3 for datasets + checkpoints, CloudWatch for logging, ECR for images, etc.

No manual torchrun configuration

The Elastic plugin automatically sets up and launches torch distributed; users just run the script with python train.py.
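
For context, a training script launched this way typically discovers its place in the job from environment variables. The sketch below assumes the plugin follows the torchrun convention of setting `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` for each worker; `DistInfo` and `read_dist_env` are hypothetical names, and the defaults let the same script run as a single local process.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class DistInfo:
    rank: int
    local_rank: int
    world_size: int

    @property
    def is_leader(self) -> bool:
        # By convention, global rank 0 handles logging and checkpoint writes.
        return self.rank == 0


def read_dist_env(env=None) -> DistInfo:
    """Read torchrun-style variables, falling back to single-process defaults."""
    env = os.environ if env is None else env
    return DistInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )
```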

Crash-proof training end-to-end

Checkpoints and the Neuron compilation cache are saved to blob storage every n steps. If an execution fails, users can resume from the last saved step by providing the checkpoint and cache paths.

Historical metrics automatically restore and continue rendering from the resumed point.
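
The save-every-n-steps and resume pattern can be sketched as follows. This is a simplified stand-in and not the code in this PR: a local directory plays the role of blob storage, JSON replaces real model checkpoints, and all function names are hypothetical.

```python
import json
from pathlib import Path


def save_checkpoint(ckpt_dir: Path, step: int, state: dict) -> Path:
    """Write the checkpoint to a temp file first, then rename it into place."""
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    path = ckpt_dir / f"step_{step:08d}.json"
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps({"step": step, "state": state}))
    tmp.replace(path)
    return path


def latest_checkpoint(ckpt_dir: Path):
    """Return the newest checkpoint path, or None if there is none."""
    if not ckpt_dir.exists():
        return None
    ckpts = sorted(ckpt_dir.glob("step_*.json"))  # zero-padded names sort by step
    return ckpts[-1] if ckpts else None


def train(ckpt_dir: Path, total_steps: int, save_every: int, fail_at=None):
    """Toy training loop that resumes from the latest checkpoint if one exists."""
    latest = latest_checkpoint(ckpt_dir)
    if latest is None:
        step, state = 0, {"loss_sum": 0.0}
    else:
        payload = json.loads(latest.read_text())
        step, state = payload["step"], payload["state"]
    while step < total_steps:
        if fail_at is not None and step == fail_at:
            raise RuntimeError(f"simulated failure after step {step}")
        step += 1
        state["loss_sum"] += 1.0 / step  # placeholder for a real training step
        if step % save_every == 0:
            save_checkpoint(ckpt_dir, step, state)
    return step, state
```

In this sketch, work done after the last saved checkpoint is repeated on resume, so the save interval bounds how much progress a failure can cost.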

Clear recovery guidance in logs

If training fails, the checkpoint and cache paths are surfaced in the task logs.
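
One way to surface recovery information on failure is to log it and re-raise. The sketch below uses only the standard `logging` module; `run_and_report` and the logger name are hypothetical, and the actual log format in this PR may differ.

```python
import logging

logger = logging.getLogger("training")


def run_and_report(train_fn, checkpoint_uri: str, cache_uri: str):
    """Run train_fn; on failure, log where to resume from, then re-raise."""
    try:
        return train_fn()
    except Exception:
        logger.error(
            "Training failed. To resume, provide checkpoint=%s and compile cache=%s",
            checkpoint_uri,
            cache_uri,
        )
        raise
```

Re-raising keeps the task marked as failed, so task-level retries and error handling still apply.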

Motivation

The goal of this example is to demonstrate that distributed training, whether pre-training or fine-tuning, can be both effortless to experiment with and robust enough for production. A key requirement for ML teams is the ability to run the same workflow locally, in a lightweight test environment, and at full production scale without rewriting code or reconfiguring infrastructure. Flyte/Union 2.0, combined with EKS, delivers exactly that.

By setting up Union on EKS once, ML engineers can run complex distributed training jobs on Trainium or GPUs without touching any infrastructure. This clean separation of concerns (platform setup vs. model development) ensures rapid iteration, consistent execution environments, and reliable scaling as workloads grow. The workflow in this PR showcases how seamless that experience can be.

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>