distributed llama pre-training workflow on trainium #230
Draft
samhita-alla wants to merge 3 commits into
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
This PR adds a complete Flyte 2.0 workflow for running distributed training on an EKS-deployed Union 2.0 backend.
The example demonstrates LLaMA 3.1 8B pre-training on the FineWeb dataset using AWS Trainium.
What this training pipeline enables
Distributed training on Trainium
Configured with the Elastic plugin. Extending to multi-node training is as simple as setting nnodes.
Cached data preprocessing
The preprocessing task stores preprocessed PyTorch tensors in S3 and returns the path. When the same inputs are used, the task fully resolves from cache.
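To illustrate the idea behind input-keyed caching, here is a minimal sketch in plain Python. It is not how Flyte or Union implements caching internally: the `cache_key` helper, the `PreprocessCache` class, the `_SUCCESS` marker file, and the local directory standing in for S3 are all hypothetical names chosen for this example.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def cache_key(inputs: dict) -> str:
    """Derive a stable key from the task inputs (hypothetical helper)."""
    payload = json.dumps(inputs, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


class PreprocessCache:
    """Stores preprocessing output under a directory named after the input hash.

    A local directory stands in for the S3 bucket in this sketch.
    """

    def __init__(self, root: Path):
        self.root = root
        self.calls = 0  # number of times real preprocessing ran

    def preprocess(self, inputs: dict) -> Path:
        out_dir = self.root / cache_key(inputs)
        marker = out_dir / "_SUCCESS"
        if marker.exists():
            # Same inputs as a previous run: return the stored path, skip the work.
            return out_dir
        out_dir.mkdir(parents=True, exist_ok=True)
        self.calls += 1
        # Stand-in for writing tokenized tensor shards.
        (out_dir / "tokens.json").write_text(json.dumps(inputs))
        # The marker is written last so a partial run is never treated as a hit.
        marker.write_text("ok")
        return out_dir


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        cache = PreprocessCache(Path(tmp))
        first = cache.preprocess({"dataset": "fineweb", "seq_len": 2048})
        second = cache.preprocess({"seq_len": 2048, "dataset": "fineweb"})
        print(first == second, cache.calls)
```

Sorting the keys before hashing means the order in which inputs are supplied does not change the key, so logically identical runs share a cache entry.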
Drop-in distributed setup with the PyTorch Neuron SDK
The training task uses the Neuron SDK and defaults to a quick setup, but any configuration can be swapped in directly from the UI if desired.
Real-time metrics streaming in the UI
Loss curves and custom dashboards update live as training progresses. CPU and memory metrics already appear in the UI; Trainium utilization metrics will be supported soon.
Full visibility across the entire pipeline
Inspect inputs/outputs, view logs for both leader and worker processes, and trace every step end-to-end.
Built-in caching, retries, and error handling
Training tasks can be cached, retried at the task level, or retried via exception handling inside the task.
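The "retried via exception handling inside the task" option can be pictured with a small, framework-independent sketch. Everything here is an assumption for illustration: `TransientError`, `run_with_retries`, and the backoff policy are hypothetical and do not come from the PR or from the Flyte SDK.

```python
import time


class TransientError(RuntimeError):
    """Hypothetical marker for failures that are worth retrying."""


def run_with_retries(step_fn, max_attempts=3, base_delay=0.0, sleep=time.sleep):
    """Call step_fn(attempt) until it succeeds or attempts are exhausted.

    Only TransientError is retried; any other exception propagates at once,
    so task-level retries (or the caller) can decide what to do with it.
    """
    last_exc = None
    for attempt in range(1, max_attempts + 1):
        try:
            return step_fn(attempt)
        except TransientError as exc:
            last_exc = exc
            if attempt < max_attempts:
                # Exponential backoff: base_delay, 2x, 4x, ...
                sleep(base_delay * 2 ** (attempt - 1))
    raise last_exc


if __name__ == "__main__":
    def flaky(attempt):
        if attempt < 3:
            raise TransientError(f"attempt {attempt} failed")
        return "trained"

    print(run_with_retries(flaky))
```

Keeping the retry loop inside the task avoids rescheduling the whole task for short-lived failures, while non-transient errors still surface to task-level retries.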
Native AWS integrations
S3 for datasets + checkpoints, CloudWatch for logging, ECR for images, etc.
No manual torchrun configuration
The Elastic plugin automatically sets up and launches torch distributed; users just run the script with python train.py.
Crash-proof training end-to-end
Checkpoints and the Neuron compilation cache are saved to blob storage every n steps. If an execution fails, users can resume from the last saved step by simply providing the checkpoint and cache paths.
Historical metrics automatically restore and continue rendering from the resumed point.
Clear recovery guidance in logs
If training fails, checkpoint + cache paths are surfaced in the task logs.
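The checkpoint-and-resume behaviour described above can be sketched in a few lines of plain Python. This is a simplified stand-in, not the training code from this PR: the `train` and `latest_checkpoint` functions, the `step_N.json` file naming, the fake loss values, and the local directory used in place of blob storage are all assumptions made for the example.

```python
import json
import tempfile
from pathlib import Path


def latest_checkpoint(ckpt_dir: Path):
    """Return the checkpoint file with the highest step number, or None."""
    ckpts = sorted(
        ckpt_dir.glob("step_*.json"),
        key=lambda p: int(p.stem.split("_")[1]),
    )
    return ckpts[-1] if ckpts else None


def train(ckpt_dir: Path, total_steps: int, save_every: int, fail_at=None):
    """Toy training loop that saves state every `save_every` steps.

    If a checkpoint already exists, the step counter and loss history are
    restored from it, so metrics continue from the resumed point.
    """
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    state = {"step": 0, "loss_history": []}
    last = latest_checkpoint(ckpt_dir)
    if last is not None:
        state = json.loads(last.read_text())

    for step in range(state["step"] + 1, total_steps + 1):
        if fail_at is not None and step == fail_at:
            # Surface the recovery path in the error, as the logs would.
            raise RuntimeError(
                f"simulated failure at step {step}; "
                f"resume from {latest_checkpoint(ckpt_dir)}"
            )
        state["step"] = step
        state["loss_history"].append(1.0 / step)  # placeholder loss value
        if step % save_every == 0:
            (ckpt_dir / f"step_{step}.json").write_text(json.dumps(state))
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpts = Path(tmp) / "ckpts"
        try:
            train(ckpts, total_steps=6, save_every=2, fail_at=5)
        except RuntimeError as exc:
            print(exc)
        final = train(ckpts, total_steps=6, save_every=2)
        print(final["step"], len(final["loss_history"]))
```

In this sketch the second call picks up after the last saved step rather than the failing step, so the work done since the most recent checkpoint is repeated; a smaller save interval reduces that repeated work at the cost of more writes.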
Motivation
The goal of this example is to demonstrate that distributed training, whether pre-training or fine-tuning, can be both effortless to experiment with and robust enough for production. A key requirement for ML teams is the ability to run the same workflow locally, in a lightweight test environment, and at full production scale without rewriting code or reconfiguring infrastructure. Flyte/Union 2.0, combined with EKS, delivers exactly that.
By setting up Union on EKS once, ML engineers can run complex distributed training jobs on Trainium or GPUs without touching any infrastructure. This clean separation of concerns (platform setup vs. model development) ensures rapid iteration, consistent execution environments, and reliable scaling as workloads grow. The workflow in this PR showcases how seamless that experience can be.