feat(awex): FSDP colocate weight update via CUDA IPC #1361
Open
guozhihao-224 wants to merge 1 commit into
Code Review
This pull request implements FSDP colocate weight transfer in the AwexFSDPAdapter, enabling FSDP-trained models to update SGLang inference weights via CUDA IPC on shared GPUs. The changes include adding methods for initializing and executing colocate updates, as well as memory management functions (release_memory and resume_memory) to offload weights to the CPU. Comprehensive unit and integration tests have also been added. Feedback was provided regarding the robustness of the _execute_colocate_weight_update_locked method, specifically recommending the use of a try...finally block to ensure weights are re-offloaded even if an exception occurs, preventing potential GPU OOM issues.
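The try/finally recommendation above can be sketched as follows. This is a minimal illustration, not the adapter's actual code: `FakeAdapter`, `execute_update`, and the `transfer` callable are hypothetical stand-ins, and only the `resume_memory(["weights"])` / `release_memory(["weights"])` call shapes are taken from the PR description.

```python
class FakeAdapter:
    """Hypothetical stand-in for an adapter whose weights can be offloaded."""

    def __init__(self):
        self.weights_on_gpu = False
        self.events = []

    def resume_memory(self, tags):
        # Pretend to reload the tagged buffers onto the GPU.
        self.weights_on_gpu = True
        self.events.append(("resume", tuple(tags)))

    def release_memory(self, tags):
        # Pretend to offload the tagged buffers back to the CPU.
        self.weights_on_gpu = False
        self.events.append(("release", tuple(tags)))


def execute_update(adapter, transfer):
    """Run `transfer` with weights resident; re-offload even if it raises."""
    adapter.resume_memory(["weights"])
    try:
        return transfer()
    finally:
        # Runs on success and on exception, so weights never stay on the GPU.
        adapter.release_memory(["weights"])
```

With this shape, a failure inside the transfer still propagates to the caller, but the release step has already run by the time the exception is observed.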
f81950b to ee70752
Implement the full colocate weight-update lifecycle for AwexFSDPAdapter:
- Add colocate state fields and init_colocate_weight_update to publish
per-rank local shards via CUDA IPC handles
- Implement execute_colocate_weight_update with all-gathered metadata
alignment and try/finally release semantics
- Implement release_memory/resume_memory('weights') for memory management
during colocated inference
- Wrap save_parameters with colocate resume/release for checkpoint safety
- Lazy-import Megatron in adapter factory for FSDP-only users
- Add DP e2e test and unit tests for the colocate path
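The lazy-import item in the list above follows a common pattern: resolve the engine module only when that backend is requested. The sketch below is an assumption-laden illustration, not the factory in this PR. The registry, the function name `load_engine_class`, and the stand-in modules are all hypothetical; a stdlib class plays the FSDP engine and a deliberately missing module plays the heavy Megatron dependency.

```python
import importlib

# Hypothetical registry: backend name -> (module path, attribute name).
# Stand-ins only: a stdlib class for the light backend, a missing module
# for the backend whose dependencies are not installed.
_BACKENDS = {
    "fsdp": ("collections", "OrderedDict"),
    "megatron": ("heavy_optional_dependency_that_is_not_installed", "Engine"),
}


def load_engine_class(backend: str):
    """Return the engine class for `backend`, importing its module only now."""
    try:
        module_path, attr = _BACKENDS[backend]
    except KeyError:
        raise ValueError(f"unknown backend: {backend!r}") from None
    # The import happens here, at call time, not when this file is imported,
    # so a missing optional dependency only affects the backend that needs it.
    module = importlib.import_module(module_path)
    return getattr(module, attr)
```

Requesting the light backend succeeds even though the heavy backend's module cannot be imported; the `ImportError` surfaces only if the heavy backend is actually selected.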
ee70752 to d4b9d34
59ecf48 to d4b9d34
Description
Adds FSDP colocate weight transfer in `AwexFSDPAdapter` so FSDP-trained models can update SGLang inference weights via CUDA IPC on shared GPUs, mirroring the existing Megatron colocate path.
What changed:
- 4 colocate methods on `AwexFSDPAdapter`: `init_colocate_weight_update`, `execute_colocate_weight_update`, `release_memory(["weights"])`, `resume_memory(["weights"])`.
- `_iter_hf_params_local` helper: yields each train rank's DTensor `_local_tensor` (its `Shard(0)` chunk), or the plain tensor if the param isn't a DTensor; reloads CPU-offloaded tensors to GPU. Skips `lm_head.weight` when `tie_word_embeddings=True` so the train-side key set matches inference (SGLang/vLLM collapse the tied head into `model.embed_tokens.weight`).
- `get_weight_metadata` reports each train rank's truthful `Shard(0)` metadata: `shape` = local shape, `global_offset` = where this chunk starts in the global tensor. The colocate IPC payload from `_iter_hf_params_local` matches that contract exactly, so awex's standard `slice_tensor` (shard-relative `train_slices`) indexes correctly into each rank's payload, and cross-engine P2P slices that reassemble the full tensor on the infer side are computed against truthful per-rank ownership.
- `save_parameters` wrapped with resume/release so the gateway debug `/awex/debug/get_parameters` path works after colocate offloads training weights.
- `_create_training_adapter` lazy-imports `MegatronEngine`. The eager import previously transitively pulled in `megatron.bridge` → `transformer_engine`, which `pyproject.toml` deliberately marks never-install; FSDP-only deployments couldn't start the awex worker before this fix.
- 13 mocked unit tests in `test_fsdp_colocate_unit.py` covering protocol-level correctness without GPU.
- New multi-GPU e2e test `test_awex_fsdp_colocate_dp_e2e_weight_update` (gated by `multi_gpu and sglang and slow`).
Related Issue
Follow-up to #1310 (colocated CUDA IPC weight transfer for Megatron).
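To make the `Shard(0)` metadata contract from the Description concrete, here is a small pure-Python sketch of how per-rank `shape` and `global_offset` could be derived, and how a receiver could use them to rebuild the full tensor. It is illustrative only: the helper names are hypothetical, rows are modeled as list items rather than tensors, and the ceil-sized chunking rule is an assumption modeled on `torch.chunk`-style splitting, not something this PR states.

```python
from math import ceil


def shard0_metadata(global_shape, world_size):
    """Per-rank local shape and global offset for a split along dim 0.

    Assumption: each rank gets a ceil(rows / world_size)-sized chunk and
    trailing ranks may receive fewer (or zero) rows.
    """
    rows = global_shape[0]
    chunk = ceil(rows / world_size)
    meta = []
    for rank in range(world_size):
        start = min(rank * chunk, rows)
        stop = min(start + chunk, rows)
        meta.append({
            "rank": rank,
            "shape": (stop - start, *global_shape[1:]),
            "global_offset": (start, *([0] * (len(global_shape) - 1))),
        })
    return meta


def reassemble_rows(shards, meta, total_rows):
    """Place each rank's rows at its global_offset to rebuild dim 0."""
    full = [None] * total_rows
    for rows, m in zip(shards, meta):
        start = m["global_offset"][0]
        full[start:start + len(rows)] = rows
    return full
```

For the DP=2 case mentioned under Additional Context, `shard0_metadata((2048, 1024), 2)` gives two `(1024, 1024)` shards, with the second starting at row 1024 of the global tensor.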
Type of Change
Checklist
- (`pre-commit run --all-files`)
- `main`
- `/review-pr` command
- `/create-pr`
Additional Context
GPU verification (manual, not in CI)
Verified manually on a 2-GPU host with `flash-attn 2.8.3` matching `torch 2.9.1+cu129` / py3.12.
DP=2 splits `q_proj.weight` `[2048, 1024]` into `[1024, 1024]` chunks across train ranks; both halves match bit-exactly (`rtol=0, atol=0`), confirming train rank 1's local-shard IPC payload reaches the right region of infer rank 0's full tensor via awex's standard transfer plan. `[4gpu]` and `[8gpu]` deselected due to GPU count.
Open verification items (need 4+ GPU host)
- `test_awex_fsdp_colocate_dp_e2e_weight_update[4gpu]` / `[8gpu]`
- `test_awex_fsdp_e2e_weight_update[4gpu]` / `[8gpu]` (separated NCCL P2P path; metadata-data contract is now self-consistent for this path too, but not GPU-verified yet)
Out of scope (file separately)
- `pyproject.toml`'s flash-attn-4-only install is incomplete for `transformers`'s flash-attn-2 detection; needs an additional pre-built `flash-attn==2.8.3` wheel (already done by `Dockerfile`). `uv sync --extra cuda` alone leaves `transformers` thinking fa2 is available but the import fails.
- `nightly.yml` workflow is a placeholder; the `slow + multi_gpu` tests, including this PR's e2e, never get auto-run.