Skip to content

feat(awex): FSDP colocate weight update via CUDA IPC#1361

Open
guozhihao-224 wants to merge 1 commit into
areal-project:mainfrom
guozhihao-224:feat/awex-fsdp-colocate
Open

feat(awex): FSDP colocate weight update via CUDA IPC#1361
guozhihao-224 wants to merge 1 commit into
areal-project:mainfrom
guozhihao-224:feat/awex-fsdp-colocate

Conversation

@guozhihao-224
Copy link
Copy Markdown
Collaborator

@guozhihao-224 guozhihao-224 commented May 22, 2026

Description

Adds FSDP colocate weight transfer in AwexFSDPAdapter so FSDP-trained models can update SGLang inference weights via CUDA IPC on shared GPUs, mirroring the existing Megatron colocate path.

What changed:

  • 4 colocate methods on AwexFSDPAdapter: init_colocate_weight_update, execute_colocate_weight_update, release_memory(["weights"]), resume_memory(["weights"]).

  • _iter_hf_params_local helper: yields each train rank's DTensor _local_tensor (its Shard(0) chunk), or the plain tensor if the param isn't a DTensor; reloads CPU-offloaded tensors to GPU. Skips lm_head.weight when tie_word_embeddings=True so the train-side key set matches inference (SGLang/vLLM collapse the tied head into model.embed_tokens.weight).

  • get_weight_metadata reports each train rank's truthful Shard(0) metadata: shape = local shape, global_offset = where this chunk starts in the global tensor. The colocate IPC payload from _iter_hf_params_local matches that contract exactly, so awex's standard slice_tensor (shard-relative train_slices) indexes correctly into each rank's payload, and cross-engine P2P slices that reassemble the full tensor on the infer side are computed against truthful per-rank ownership.

  • save_parameters wrapped with resume/release so the gateway debug /awex/debug/get_parameters path works after colocate offloads training weights.

  • _create_training_adapter lazy-imports MegatronEngine. The eager import previously transitively pulled in megatron.bridgetransformer_engine, which pyproject.toml deliberately marks never-install; FSDP-only deployments couldn't start the awex worker before this fix.

  • 13 mocked unit tests in test_fsdp_colocate_unit.py covering protocol-level correctness without GPU.

  • New multi-GPU e2e test test_awex_fsdp_colocate_dp_e2e_weight_update (gated by multi_gpu and sglang and slow).

Related Issue

Follow-up to #1310 (colocated CUDA IPC weight transfer for Megatron).

Type of Change

  • ✨ New feature
  • 🐛 Bug fix
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated — N/A, no user-facing doc changes
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Additional Context

GPU verification (manual, not in CI)

Verified manually on a 2-GPU host with flash-attn 2.8.3 matching torch 2.9.1+cu129 / py3.12.

$ uv run pytest tests/experimental/weight_update/test_nccl_integration.py::test_awex_fsdp_colocate_dp_e2e_weight_update \
    -v -s -m "multi_gpu and sglang" -k "2gpu"
================= 1 passed, 2 deselected in 126.98s (0:02:06) ==================

[weight-validation] Comparing 7 parameters …
[weight-validation]   model.layers.0.self_attn.q_proj.weight:  OK (shape=[2048, 1024], dtype=torch.bfloat16)
[weight-validation]   model.layers.0.self_attn.k_proj.weight:  OK (shape=[1024, 1024], dtype=torch.bfloat16)
[weight-validation]   model.layers.0.self_attn.v_proj.weight:  OK (shape=[1024, 1024], dtype=torch.bfloat16)
[weight-validation]   model.layers.0.mlp.gate_proj.weight:     OK (shape=[3072, 1024], dtype=torch.bfloat16)
[weight-validation]   model.layers.0.mlp.up_proj.weight:       OK (shape=[3072, 1024], dtype=torch.bfloat16)
[weight-validation]   model.layers.27.self_attn.q_proj.weight: OK (shape=[2048, 1024], dtype=torch.bfloat16)
[weight-validation]   model.norm.weight:                       OK (shape=[1024], dtype=torch.bfloat16)
[weight-validation] All 7 parameters match between training and inference ✓

DP=2 splits q_proj.weight [2048, 1024] into [1024, 1024] chunks across train ranks; both halves match bit-exactly (rtol=0, atol=0), confirming train rank 1's local-shard IPC payload reaches the right region of infer rank 0's full tensor via awex's standard transfer plan.

[4gpu] and [8gpu] deselected due to GPU count.

Open verification items (need 4+ GPU host)

  • test_awex_fsdp_colocate_dp_e2e_weight_update[4gpu] / [8gpu]
  • test_awex_fsdp_e2e_weight_update[4gpu] / [8gpu] (separated NCCL P2P path; metadata-data contract is now self-consistent for this path too, but not GPU-verified yet)

Out of scope (file separately)

  • pyproject.toml's flash-attn-4-only install is incomplete for transformers's flash-attn-2 detection; needs an additional pre-built flash-attn==2.8.3 wheel (already done by Dockerfile). uv sync --extra cuda alone leaves transformers thinking fa2 is available but the import fails.

  • nightly.yml workflow is a placeholder; the slow + multi_gpu tests, including this PR's e2e, never get auto-run.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements FSDP colocate weight transfer in the AwexFSDPAdapter, enabling FSDP-trained models to update SGLang inference weights via CUDA IPC on shared GPUs. The changes include adding methods for initializing and executing colocate updates, as well as memory management functions (release_memory and resume_memory) to offload weights to the CPU. Comprehensive unit and integration tests have also been added. Feedback was provided regarding the robustness of the _execute_colocate_weight_update_locked method, specifically recommending the use of a try...finally block to ensure weights are re-offloaded even if an exception occurs, preventing potential GPU OOM issues.

Comment thread areal/experimental/weight_update/awex/fsdp_adapter.py Outdated
@guozhihao-224 guozhihao-224 force-pushed the feat/awex-fsdp-colocate branch 3 times, most recently from f81950b to ee70752 Compare May 25, 2026 12:37
@guozhihao-224 guozhihao-224 marked this pull request as ready for review May 25, 2026 12:44
Implement the full colocate weight-update lifecycle for AwexFSDPAdapter:

- Add colocate state fields and init_colocate_weight_update to publish
  per-rank local shards via CUDA IPC handles
- Implement execute_colocate_weight_update with all-gathered metadata
  alignment and try/finally release semantics
- Implement release_memory/resume_memory('weights') for memory management
  during colocated inference
- Wrap save_parameters with colocate resume/release for checkpoint safety
- Lazy-import Megatron in adapter factory for FSDP-only users
- Add DP e2e test and unit tests for the colocate path
@guozhihao-224 guozhihao-224 force-pushed the feat/awex-fsdp-colocate branch from ee70752 to d4b9d34 Compare May 26, 2026 11:24
@guozhihao-224 guozhihao-224 reopened this May 26, 2026
@guozhihao-224 guozhihao-224 force-pushed the feat/awex-fsdp-colocate branch from 59ecf48 to d4b9d34 Compare May 26, 2026 11:43
@guozhihao-224 guozhihao-224 reopened this May 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant