From 5e79b10cf49aef81e7c8b8ab2601966a6f2af0d4 Mon Sep 17 00:00:00 2001 From: Haichuan Hu Date: Wed, 6 May 2026 16:56:07 +0800 Subject: [PATCH] feat(experimental): integrate Ray RDT for weight syncing Signed-off-by: Haichuan Hu --- .../inference_service/sglang/launch_server.py | 16 +- .../inference_service/sglang/rdt.py | 117 ++ .../inference_service/sglang/scheduler.py | 127 ++- .../training_service/controller/controller.py | 19 +- .../training_service/worker/app.py | 27 +- .../training_service/worker/rdt.py | 242 ++++ areal/experimental/weight_update/__init__.py | 21 + .../weight_update/awex/fsdp_adapter.py | 22 + .../weight_update/awex/megatron_adapter.py | 22 + .../weight_update/awex/sglang_adapter.py | 23 + .../weight_update/controller/controller.py | 3 +- .../experimental/weight_update/gateway/app.py | 231 +++- .../weight_update/gateway/config.py | 10 +- .../weight_update/rdt/__init__.py | 128 +++ .../weight_update/rdt/fsdp_adapter.py | 393 +++++++ .../weight_update/rdt/megatron_adapter.py | 273 +++++ .../weight_update/rdt/sglang_adapter.py | 583 ++++++++++ .../rdt/weight_transport_actor.py | 217 ++++ .../weight_update/test_rdt_integration.py | 1000 +++++++++++++++++ .../weight_update/test_sglang_integration.py | 118 +- .../weight_update/test_wu_controller.py | 35 +- .../torchrun/run_rdt_weight_transfer.py | 304 +++++ 22 files changed, 3889 insertions(+), 42 deletions(-) create mode 100644 areal/experimental/inference_service/sglang/rdt.py create mode 100644 areal/experimental/training_service/worker/rdt.py create mode 100644 areal/experimental/weight_update/rdt/__init__.py create mode 100644 areal/experimental/weight_update/rdt/fsdp_adapter.py create mode 100644 areal/experimental/weight_update/rdt/megatron_adapter.py create mode 100644 areal/experimental/weight_update/rdt/sglang_adapter.py create mode 100644 areal/experimental/weight_update/rdt/weight_transport_actor.py create mode 100644 tests/experimental/weight_update/test_rdt_integration.py create mode 
100644 tests/experimental/weight_update/torchrun/run_rdt_weight_transfer.py diff --git a/areal/experimental/inference_service/sglang/launch_server.py b/areal/experimental/inference_service/sglang/launch_server.py index 1894398674..30f3a66b35 100644 --- a/areal/experimental/inference_service/sglang/launch_server.py +++ b/areal/experimental/inference_service/sglang/launch_server.py @@ -24,18 +24,21 @@ def areal_launch_server(server_args) -> None: from sglang.srt.managers.detokenizer_manager import run_detokenizer_process # ---- BEGIN AREAL ---- - from areal.experimental.inference_service.sglang.awex import ( - register_awex_endpoints, - ) + from areal.experimental.inference_service.sglang.awex import register_awex_endpoints + from areal.experimental.inference_service.sglang.rdt import register_rdt_endpoints from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy from areal.experimental.inference_service.sglang.scheduler import ( areal_run_scheduler_process, create_result_ipc, + get_weight_update_backend, ) # ---- END AREAL ---- # ---- BEGIN AREAL ---- - result_ipc = create_result_ipc() + backend = getattr(server_args, "weight_update_backend", None) + if backend is None: + backend = get_weight_update_backend() + result_ipc = create_result_ipc(backend) # ---- END AREAL ---- ( @@ -60,7 +63,10 @@ def areal_launch_server(server_args) -> None: # ---- BEGIN AREAL ---- rpc_proxy = RpcProxy(port_args, result_ipc) - register_awex_endpoints(app, rpc_proxy) + if backend == "awex": + register_awex_endpoints(app, rpc_proxy) + elif backend == "rdt": + register_rdt_endpoints(app, rpc_proxy) # ---- END AREAL ---- try: diff --git a/areal/experimental/inference_service/sglang/rdt.py b/areal/experimental/inference_service/sglang/rdt.py new file mode 100644 index 0000000000..742882e08f --- /dev/null +++ b/areal/experimental/inference_service/sglang/rdt.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RDT HTTP endpoints for IW weight update. 
+ +Reference: areal.experimental.inference_service.sglang.awex +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse + +from areal.utils import logging + +if TYPE_CHECKING: + from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy + +logger = logging.getLogger("RDTIWEndpoints") + + +def register_rdt_endpoints(app: FastAPI, rpc_proxy: RpcProxy) -> None: + """Register ``/rdt/*`` weight-update endpoints on IW's FastAPI app. + + Each endpoint dispatches to all scheduler processes via RpcProxy, + using collective_rpc_with_result or collective_rpc. + + Args: + app: FastAPI application + rpc_proxy: RpcProxy for scheduler subprocess communication + """ + + @app.get("/rdt/report_parallelism") + async def report_parallelism() -> JSONResponse: + """Report IW parallelism strategy for TransferPlan building.""" + try: + result = rpc_proxy.collective_rpc_with_result("rdt_report_parallelism") + if not isinstance(result, dict): + err_msg = f"Expected dict from rdt_report_parallelism, got {type(result).__name__}" + logger.error(err_msg) + return JSONResponse(status_code=500, content={"error": err_msg}) + return JSONResponse(content=result) + except Exception as e: + logger.error("Failed to report parallelism: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/rdt/report_weight_meta") + async def report_weight_meta() -> JSONResponse: + """Report IW weight metadata for TransferPlan building.""" + try: + result = rpc_proxy.collective_rpc_with_result("rdt_report_weight_meta") + return JSONResponse(content={"status": "ok", "meta": result}) + except Exception as e: + logger.error("Failed to report weight meta: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/rdt/init_weight_update_group") + async def init_weight_update_group(request: Request) -> JSONResponse: + """Initialize RDT 
weight update group. + + Args passed via JSON body: + pair_name: TW-IW pair identifier + kv_store_url: Gateway KV store URL + tw_actor_bytes_b64_list: Base64-encoded TW actor handle bytes + infer_world_size: Total IW world size + train_world_size: Total TW world size + num_engines: Number of IW engines + transfer_rank: IW's transfer rank + """ + try: + data = await request.json() + rpc_proxy.collective_rpc("rdt_init_weight_update_group", **data) + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to init RDT weight update group: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/rdt/update_weights") + async def update_weights(request: Request) -> JSONResponse: + """Execute RDT weight update - pull from TW via Ray RPC. + + Args passed via JSON body: + version: Weight version number (optional, default 0) + """ + try: + data = await request.json() + version = data.get("version", 0) + rpc_proxy.collective_rpc("rdt_execute_weight_update", version=version) + return JSONResponse(content={"status": "ok", "version": version}) + except Exception as e: + logger.error("Failed to execute RDT weight update: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + # --------------------------------------------------------------------------- + # Debug endpoints for E2E testing + # --------------------------------------------------------------------------- + + @app.post("/rdt/debug/randomize_parameters") + async def randomize_parameters() -> JSONResponse: + """Randomize model parameters for testing.""" + try: + rpc_proxy.collective_rpc("rdt_randomize_parameters") + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to randomize parameters: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) + + @app.post("/rdt/debug/get_parameters") + async def get_parameters(request: Request) -> JSONResponse: + """Save parameters 
to disk for validation.""" + try: + data = await request.json() + rpc_proxy.collective_rpc("rdt_get_parameters", **data) + return JSONResponse(content={"status": "ok"}) + except Exception as e: + logger.error("Failed to get parameters: %s", e) + return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/areal/experimental/inference_service/sglang/scheduler.py b/areal/experimental/inference_service/sglang/scheduler.py index 34149a6fc9..cd9a816ebb 100644 --- a/areal/experimental/inference_service/sglang/scheduler.py +++ b/areal/experimental/inference_service/sglang/scheduler.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""AwexSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler.""" +"""AwexSchedulerBridge/RDTSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler.""" from __future__ import annotations @@ -12,9 +12,16 @@ import zmq from sglang.srt.server_args import PortArgs, ServerArgs +from areal.experimental.weight_update import ( + BACKEND_AWEX, + BACKEND_RDT, + WEIGHT_UPDATE_BACKEND_ENV, + get_weight_update_backend, +) from areal.infra.rpc.serialization import serialize_value RESULT_IPC_ENV = "AREAL_AWEX_RESULT_IPC" +RDT_RESULT_IPC_ENV = "AREAL_RDT_RESULT_IPC" class AwexSchedulerBridge: @@ -134,6 +141,96 @@ def awex_resume_memory(self, tags: list[str] | None = None) -> None: self._require_adapter().resume_memory(tags) +class RDTSchedulerBridge: + """Compose RDT weight-update capabilities onto a plain Scheduler instance. + + Lifecycle: + 1. Created after ``Scheduler.__init__()`` in :func:`areal_run_scheduler_process` + 2. :meth:`bind` attaches ``rdt_*`` methods to the scheduler via ``setattr`` + 3. ``handle_rpc_request`` dispatches via ``getattr(self, method)`` and finds them + 4. Methods delegate to :class:`RDTSGLangAdapter` for actual work + 5. Data-returning methods push results via ZMQ PUSH (tp_rank 0, dp_rank 0 only) + + No inheritance. No monkey-patch. 
The scheduler instance remains a plain + ``sglang.srt.managers.scheduler.Scheduler``. + """ + + def __init__(self, scheduler: Any) -> None: + self._scheduler = scheduler + self._adapter: Any | None = None + self._result_push: zmq.Socket | None = None + + result_ipc = os.environ.get(RDT_RESULT_IPC_ENV) + if ( + result_ipc + and scheduler.tp_rank == 0 + and (getattr(scheduler, "dp_rank", None) is None or scheduler.dp_rank == 0) + ): + ctx = zmq.Context(1) + self._result_push = ctx.socket(zmq.PUSH) + self._result_push.connect(result_ipc) + + def bind(self) -> None: + """Attach ``rdt_*`` methods to the scheduler instance.""" + methods = [ + "rdt_report_weight_meta", + "rdt_report_parallelism", + "rdt_init_weight_update_group", + "rdt_execute_weight_update", + "rdt_randomize_parameters", + "rdt_get_parameters", + ] + for name in methods: + setattr(self._scheduler, name, getattr(self, name)) + + def _require_adapter(self) -> Any: + if self._adapter is None: + from areal.experimental.weight_update.rdt.sglang_adapter import ( + RDTSGLangAdapter, + ) + + self._adapter = RDTSGLangAdapter(self._scheduler) + return self._adapter + + def _push_result(self, result: Any) -> None: + if self._result_push is not None: + self._result_push.send_pyobj(result) + + def rdt_report_weight_meta(self) -> None: + adapter = self._require_adapter() + local_meta = adapter.get_weight_metadata() + s = self._scheduler + + if s.tp_size > 1: + gathered: list[list] = [[] for _ in range(s.tp_size)] + dist.all_gather_object(gathered, local_meta, group=s.tp_cpu_group) + all_meta: list = [] + for rank_meta in gathered: + all_meta.extend(rank_meta) + self._push_result(serialize_value(all_meta)) + else: + self._push_result(serialize_value(local_meta)) + + def rdt_report_parallelism(self) -> None: + self._push_result(self._require_adapter().parallelism_strategy) + + def rdt_init_weight_update_group(self, **kwargs: Any) -> None: + self._require_adapter().rdt_init_weight_update_group(**kwargs) + + def 
rdt_execute_weight_update(self, version: int = 0) -> None: + self._require_adapter().rdt_execute_weight_update(version) + + def rdt_randomize_parameters(self) -> None: + """Randomize model parameters for testing.""" + self._require_adapter().randomize_parameters() + + def rdt_get_parameters( + self, save_path: str, names: list[str] | None = None + ) -> None: + """Save parameters to disk for validation.""" + self._require_adapter().save_parameters(save_path, names) + + # --------------------------------------------------------------------------- # Duplicated from sglang.srt.managers.scheduler.run_scheduler_process # (SGLang commit pinned in this repo). @@ -232,7 +329,11 @@ def areal_run_scheduler_process( ) # ---- BEGIN AREAL ---- - AwexSchedulerBridge(scheduler).bind() + backend = get_weight_update_backend() + if backend == BACKEND_AWEX: + AwexSchedulerBridge(scheduler).bind() + elif backend == BACKEND_RDT: + RDTSchedulerBridge(scheduler).bind() PPSchedulerBridge(scheduler, server_args).bind() # ---- END AREAL ---- @@ -245,7 +346,23 @@ def areal_run_scheduler_process( parent_process.send_signal(signal.SIGQUIT) -def create_result_ipc() -> str: - path = f"ipc://{tempfile.mktemp(prefix='areal_result_')}" - os.environ[RESULT_IPC_ENV] = path +def create_result_ipc(backend: str) -> str: + """Create result IPC path for given backend. + + Sets environment variable for scheduler subprocess to read. 
+ + Args: + backend: "awex" or "rdt" + + Returns: + IPC path string + """ + path = f"ipc://{tempfile.mktemp(prefix=f'areal_{backend}_result_')}" + + if backend == BACKEND_AWEX: + os.environ[RESULT_IPC_ENV] = path + elif backend == BACKEND_RDT: + os.environ[RDT_RESULT_IPC_ENV] = path + + os.environ[WEIGHT_UPDATE_BACKEND_ENV] = backend return path diff --git a/areal/experimental/training_service/controller/controller.py b/areal/experimental/training_service/controller/controller.py index 8a498630f6..8541ebc5eb 100644 --- a/areal/experimental/training_service/controller/controller.py +++ b/areal/experimental/training_service/controller/controller.py @@ -916,21 +916,32 @@ def _graceful_shutdown_workers(self) -> None: if not self._worker_addrs: return + # Get backend from scheduling_spec env_vars (worker process config) + backend = "awex" # default + if self.config.scheduling_spec: + env_vars = self.config.scheduling_spec[0].env_vars or {} + backend = env_vars.get("AREAL_WEIGHT_UPDATE_BACKEND", "awex") + + teardown_endpoint = f"/{backend}/teardown" + async def _shutdown_all() -> None: timeout = aiohttp.ClientTimeout(total=30) async with aiohttp.ClientSession(timeout=timeout) as session: tasks = [] for addr in self._worker_addrs: - tasks.append(_shutdown_one(session, addr)) + tasks.append(_shutdown_one(session, addr, teardown_endpoint)) await asyncio.gather(*tasks, return_exceptions=True) - async def _shutdown_one(session: aiohttp.ClientSession, addr: str) -> None: + async def _shutdown_one( + session: aiohttp.ClientSession, addr: str, endpoint: str + ) -> None: try: - async with session.post(f"{addr}/awex/teardown") as resp: + async with session.post(f"{addr}{endpoint}") as resp: resp.raise_for_status() except Exception as e: logger.warning( - "Graceful shutdown: failed to call /awex/teardown on %s: %s", + "Graceful shutdown: failed to call %s on %s: %s", + endpoint, addr, e, ) diff --git a/areal/experimental/training_service/worker/app.py 
b/areal/experimental/training_service/worker/app.py index ba5ac5bd45..db61c6aa20 100644 --- a/areal/experimental/training_service/worker/app.py +++ b/areal/experimental/training_service/worker/app.py @@ -14,6 +14,8 @@ from areal.experimental.training_service.worker.awex import create_awex_blueprint from areal.experimental.training_service.worker.config import TrainWorkerConfig from areal.experimental.training_service.worker.engine import create_engine_module +from areal.experimental.training_service.worker.rdt import create_rdt_blueprint +from areal.experimental.weight_update import get_weight_update_backend from areal.infra.platforms import current_platform from areal.infra.rpc.serialization import deserialize_value, serialize_value from areal.utils import logging @@ -198,14 +200,25 @@ def _get_node_addr() -> str: ) ) - app.register_blueprint( - create_awex_blueprint( - flask_module=flask, - get_engine=_get_engine, - submit_to_engine_thread=_submit_to_engine_thread, - run_endpoint=_run_endpoint, + backend = get_weight_update_backend() + if backend == "awex": + app.register_blueprint( + create_awex_blueprint( + flask_module=flask, + get_engine=_get_engine, + submit_to_engine_thread=_submit_to_engine_thread, + run_endpoint=_run_endpoint, + ) + ) + elif backend == "rdt": + app.register_blueprint( + create_rdt_blueprint( + flask_module=flask, + get_engine=_get_engine, + submit_to_engine_thread=_submit_to_engine_thread, + run_endpoint=_run_endpoint, + ) ) - ) from areal.infra.rpc.guard.data_blueprint import data_bp diff --git a/areal/experimental/training_service/worker/rdt.py b/areal/experimental/training_service/worker/rdt.py new file mode 100644 index 0000000000..16f05db425 --- /dev/null +++ b/areal/experimental/training_service/worker/rdt.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RDT HTTP endpoints for training worker.""" + +from __future__ import annotations + +import base64 +import os +from typing import TYPE_CHECKING, Any + +import ray 
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy +from torch.multiprocessing.reductions import reduce_tensor + +if TYPE_CHECKING: + from flask import Blueprint + +from areal.utils import logging + +logger = logging.getLogger("RDTTWBlueprint") + +# Module-level globals for lifecycle (prevent GC) +_rdt_actor: Any = None +_rdt_adapter: Any = None + + +def create_rdt_blueprint( + *, + flask_module: Any, + get_engine: Any, + submit_to_engine_thread: Any, + run_endpoint: Any, +) -> Blueprint: + """Create Flask blueprint for RDT weight update endpoints.""" + bp = flask_module.Blueprint("rdt", __name__, url_prefix="/rdt") + + def _ensure_actor(): + """Ensure WeightTransportActor is created and return handle.""" + global _rdt_actor + if _rdt_actor is None: + if not ray.is_initialized(): + ray.init(address="auto", ignore_reinit_error=True) + engine = get_engine() + if engine is None: + raise RuntimeError("Engine not initialized") + + current_node_id = ray.get_runtime_context().get_node_id() + current_visible_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + actor_name = f"weight-transport-{engine.rank}" + + try: + _rdt_actor = ray.get_actor(actor_name) + logger.info(f"Reused existing WeightTransportActor: {actor_name}") + except ValueError: + from areal.experimental.weight_update.rdt.weight_transport_actor import ( + WeightTransportActor, + ) + + _rdt_actor = WeightTransportActor.options( + name=actor_name, + num_gpus=0.0001, + max_concurrency=8, + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=current_node_id, + soft=False, + ), + runtime_env={ + "env_vars": {"CUDA_VISIBLE_DEVICES": current_visible_gpus} + }, + ).remote() + logger.info(f"Created new WeightTransportActor: {actor_name}") + return _rdt_actor + + def _get_adapter(): + """Create RDT adapter based on engine type (cached).""" + global _rdt_adapter + if _rdt_adapter is None: + engine = get_engine() + if engine is None: + raise RuntimeError("Engine not initialized") + + 
from areal.engine.fsdp_engine import FSDPEngine + from areal.engine.megatron_engine import MegatronEngine + from areal.experimental.weight_update.rdt.fsdp_adapter import RDTFSDPAdapter + from areal.experimental.weight_update.rdt.megatron_adapter import ( + RDTMegatronAdapter, + ) + + if isinstance(engine, FSDPEngine): + _rdt_adapter = RDTFSDPAdapter(engine) + elif isinstance(engine, MegatronEngine): + _rdt_adapter = RDTMegatronAdapter(engine) + else: + raise TypeError(f"Unsupported engine type: {type(engine).__name__}") + return _rdt_adapter + + @bp.route("/get_actor_handle", methods=["GET"]) + def get_actor_handle(): + """Return serialized WeightTransportActor handle.""" + try: + actor = _ensure_actor() + handle_bytes = ray.cloudpickle.dumps(actor) + return flask_module.jsonify( + {"actor_bytes_b64": base64.b64encode(handle_bytes).decode()} + ) + except RuntimeError as e: + return flask_module.jsonify({"error": str(e)}), 400 + + @bp.route("/report_parallelism", methods=["GET"]) + def report_parallelism(): + """Return parallelism strategy.""" + try: + adapter = _get_adapter() + return flask_module.jsonify(adapter.parallelism_strategy) + except RuntimeError as e: + return flask_module.jsonify({"error": str(e)}), 400 + + @bp.route("/report_weight_meta", methods=["POST"]) + def report_weight_meta(): + """Return parameter metadata.""" + + def action(): + adapter = _get_adapter() + return adapter.get_weight_metadata() + + return run_endpoint( + "report_weight_meta", + lambda: submit_to_engine_thread("report_weight_meta", action), + ) + + @bp.route("/init_weight_update_group", methods=["POST"]) + def init_weight_update_group(): + """Initialize TransferPlan.""" + data = flask_module.request.get_json(force=True) + + def action(): + adapter = _get_adapter() + adapter.init_weight_update_group(**data) + + return run_endpoint( + "init_weight_update_group", + lambda: submit_to_engine_thread("init_weight_update_group", action), + ) + + @bp.route("/update_weights", 
methods=["POST"]) + def update_weights(): + """Slice tensors → create IPC handles → actor.store_ipc_handles.""" + data = flask_module.request.get_json(force=True) + pair_name = data["pair_name"] + version = data.get("version", 0) + + def action(): + import time + + t0 = time.monotonic() + actor = _ensure_actor() + adapter = _get_adapter() + + plan = adapter._transfer_plans.get(pair_name) + if not plan: + raise RuntimeError(f"TransferPlan not found for {pair_name}") + + t1 = time.monotonic() + local_params = adapter.get_local_shard_parameters() + t2 = time.monotonic() + + # Build IPC handles dict for each infer_rank + for send_rank, operations in plan.inter_operations.items(): + infer_rank = operations[0].recv_shard_meta.global_rank + ipc_handles = {} + + for op in operations: + full_tensor = local_params.get(op.send_shard_meta.name) + if full_tensor is None: + logger.warning(f"Tensor not found: {op.send_shard_meta.name}") + continue + + sliced = full_tensor[op.train_slices] + sliced.share_memory_() + rebuild_fn, tensor_meta = reduce_tensor(sliced) + ipc_handles[op.recv_shard_meta.name] = { + "rebuild_fn": rebuild_fn, + "tensor_meta": tensor_meta, + } + + t3 = time.monotonic() + ray.get( + actor.store_ipc_handles.remote( + pair_name, infer_rank, version, ipc_handles + ) + ) + t4 = time.monotonic() + + logger.info( + f"[RDT-TW-Timing] get_params={1000 * (t2 - t1):.1f}ms | " + f"slice_ipc={1000 * (t3 - t2):.1f}ms | " + f"store_handles={1000 * (t4 - t3):.1f}ms | " + f"total={1000 * (t4 - t0):.1f}ms" + ) + + logger.info(f"[RDT-TW] Prepared weights for pair '{pair_name}' v{version}") + + return run_endpoint( + "update_weights", + lambda: submit_to_engine_thread("update_weights", action), + ) + + @bp.route("/teardown", methods=["POST"]) + def teardown(): + """Clear all TransferPlans and cleanup adapter state.""" + global _rdt_adapter + if _rdt_adapter is None: + return flask_module.jsonify({"status": "success"}) + + def action(): + global _rdt_adapter + 
_rdt_adapter.teardown_weight_update_group() + _rdt_adapter = None + + return run_endpoint( + "rdt_teardown", + lambda: submit_to_engine_thread("rdt_teardown", action), + return_result=False, + ) + + @bp.route("/debug/get_parameters", methods=["POST"]) + def get_parameters(): + """Save local shard parameters to file.""" + data = flask_module.request.get_json(force=True) + save_path = data["save_path"] + names = data.get("names") + + def action(): + adapter = _get_adapter() + adapter.save_parameters(save_path, names) + + return run_endpoint( + "get_parameters", + lambda: submit_to_engine_thread("get_parameters", action), + return_result=False, + ) + + return bp diff --git a/areal/experimental/weight_update/__init__.py b/areal/experimental/weight_update/__init__.py index 14fd6c00a5..bf78bd6421 100644 --- a/areal/experimental/weight_update/__init__.py +++ b/areal/experimental/weight_update/__init__.py @@ -1,12 +1,33 @@ # SPDX-License-Identifier: Apache-2.0 """Weight update protocol adapters for training and inference.""" +import os + from areal.experimental.weight_update.controller import ( WeightUpdateController, WeightUpdateControllerConfig, ) +WEIGHT_UPDATE_BACKEND_ENV = "AREAL_WEIGHT_UPDATE_BACKEND" +BACKEND_AWEX = "awex" +BACKEND_RDT = "rdt" +BACKEND_DISK = "disk" + + +def get_weight_update_backend() -> str: + """Get weight update backend from env or default to awex.""" + backend = os.environ.get(WEIGHT_UPDATE_BACKEND_ENV, BACKEND_AWEX) + if backend not in (BACKEND_AWEX, BACKEND_RDT): + raise ValueError(f"Invalid backend: {backend}, must be awex or rdt") + return backend + + __all__ = [ "WeightUpdateController", "WeightUpdateControllerConfig", + "WEIGHT_UPDATE_BACKEND_ENV", + "BACKEND_AWEX", + "BACKEND_RDT", + "BACKEND_DISK", + "get_weight_update_backend", ] diff --git a/areal/experimental/weight_update/awex/fsdp_adapter.py b/areal/experimental/weight_update/awex/fsdp_adapter.py index aff173e06a..6e46c4213d 100644 --- 
a/areal/experimental/weight_update/awex/fsdp_adapter.py +++ b/areal/experimental/weight_update/awex/fsdp_adapter.py @@ -151,7 +151,11 @@ def init_weight_update_group( ) def execute_weight_update(self, version: int) -> None: + import time + del version + t0 = time.monotonic() + if self._transfer_plan is None: raise RuntimeError("Transfer plan is not initialized") if self._weights_update_group is None: @@ -159,15 +163,33 @@ def execute_weight_update(self, version: int) -> None: if self._transfer_rank is None: raise RuntimeError("Transfer rank is not initialized") + t1 = time.monotonic() + params = self.get_local_shard_parameters() + t2 = time.monotonic() + send_ops, _, _ = nccl_build_send_ops( params, self._transfer_plan, self._weights_update_group, copy_rank=self._transfer_rank, ) + t3 = time.monotonic() + batch_send_recv(send_ops=send_ops, recv_ops=[], blocking=True) + t4 = time.monotonic() + torch.distributed.barrier(group=self._weights_update_group) + t5 = time.monotonic() + + logger.info( + f"[Awex-FSDP-TW-Timing] prep={1000 * (t1 - t0):.1f}ms | " + f"get_params={1000 * (t2 - t1):.1f}ms | " + f"build_ops={1000 * (t3 - t2):.1f}ms | " + f"nccl_send={1000 * (t4 - t3):.1f}ms | " + f"barrier={1000 * (t5 - t4):.1f}ms | " + f"total={1000 * (t5 - t0):.1f}ms" + ) def batch_isend_irecv(self, **kwargs) -> None: setup_kwargs = {k: v for k, v in kwargs.items() if k != "world_size"} diff --git a/areal/experimental/weight_update/awex/megatron_adapter.py b/areal/experimental/weight_update/awex/megatron_adapter.py index 6002a7fa8b..777d3dceb6 100644 --- a/areal/experimental/weight_update/awex/megatron_adapter.py +++ b/areal/experimental/weight_update/awex/megatron_adapter.py @@ -183,7 +183,11 @@ def init_weight_update_group( ) def execute_weight_update(self, version: int) -> None: + import time + del version + t0 = time.monotonic() + if self._transfer_plan is None: raise RuntimeError("Transfer plan is not initialized") if self._weights_update_group is None: @@ -191,15 +195,33 @@ 
def execute_weight_update(self, version: int) -> None: if self._transfer_rank is None: raise RuntimeError("Transfer rank is not initialized") + t1 = time.monotonic() + params = self.get_local_shard_parameters() + t2 = time.monotonic() + send_ops, _, _ = nccl_build_send_ops( params, self._transfer_plan, self._weights_update_group, copy_rank=self._transfer_rank, ) + t3 = time.monotonic() + batch_send_recv(send_ops=send_ops, recv_ops=[], blocking=True) + t4 = time.monotonic() + dist.barrier(group=self._weights_update_group) + t5 = time.monotonic() + + logger.info( + f"[Awex-TW-Timing] prep={1000 * (t1 - t0):.1f}ms | " + f"get_params={1000 * (t2 - t1):.1f}ms | " + f"build_ops={1000 * (t3 - t2):.1f}ms | " + f"nccl_send={1000 * (t4 - t3):.1f}ms | " + f"barrier={1000 * (t5 - t4):.1f}ms | " + f"total={1000 * (t5 - t0):.1f}ms" + ) def batch_isend_irecv(self, **kwargs) -> None: setup_kwargs = {k: v for k, v in kwargs.items() if k != "world_size"} diff --git a/areal/experimental/weight_update/awex/sglang_adapter.py b/areal/experimental/weight_update/awex/sglang_adapter.py index 4ebcf4c8c5..2c4931e441 100644 --- a/areal/experimental/weight_update/awex/sglang_adapter.py +++ b/areal/experimental/weight_update/awex/sglang_adapter.py @@ -381,24 +381,47 @@ def init_weight_update_group( ) def execute_weight_update(self, version: int) -> None: + import time + del version + t0 = time.monotonic() + if self._transfer_plan is None: raise RuntimeError("Transfer plan is not initialized") if self._weights_update_group is None: raise RuntimeError("Weight update group is not initialized") + t1 = time.monotonic() + params = self.get_local_shard_parameters() + t2 = time.monotonic() + recv_ops, non_contiguous_pairs, _ = nccl_build_recv_ops( params, self._transfer_plan, self._weights_update_group, ) + t3 = time.monotonic() + batch_send_recv(send_ops=[], recv_ops=recv_ops, blocking=True) + t4 = time.monotonic() for original, contiguous in non_contiguous_pairs: original.copy_(contiguous) + t5 = 
time.monotonic() dist.barrier(group=self._weights_update_group) + t6 = time.monotonic() + + logger.info( + f"[Awex-IW-Timing] prep={1000 * (t1 - t0):.1f}ms | " + f"get_params={1000 * (t2 - t1):.1f}ms | " + f"build_ops={1000 * (t3 - t2):.1f}ms | " + f"nccl_recv={1000 * (t4 - t3):.1f}ms | " + f"copy_non_contiguous={1000 * (t5 - t4):.1f}ms | " + f"barrier={1000 * (t6 - t5):.1f}ms | " + f"total={1000 * (t6 - t0):.1f}ms" + ) def batch_isend_irecv(self, **kwargs) -> None: setup_kwargs = {k: v for k, v in kwargs.items() if k != "world_size"} diff --git a/areal/experimental/weight_update/controller/controller.py b/areal/experimental/weight_update/controller/controller.py index a119e1b59f..182d2591c7 100644 --- a/areal/experimental/weight_update/controller/controller.py +++ b/areal/experimental/weight_update/controller/controller.py @@ -9,6 +9,7 @@ import httpx +from areal.experimental.weight_update import BACKEND_AWEX from areal.experimental.weight_update.controller.config import ( WeightUpdateControllerConfig, ) @@ -108,7 +109,7 @@ def connect( pair_name: str, train_worker_urls: list[str], inference_worker_urls: list[str], - mode: str = "awex", + mode: str = BACKEND_AWEX, save_path: str = "", use_lora: bool = False, lora_name: str = "", diff --git a/areal/experimental/weight_update/gateway/app.py b/areal/experimental/weight_update/gateway/app.py index d780355d0b..8ba8029ddf 100644 --- a/areal/experimental/weight_update/gateway/app.py +++ b/areal/experimental/weight_update/gateway/app.py @@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse # pyright: ignore[reportMissingImports] from pydantic import BaseModel # pyright: ignore[reportMissingImports] +from areal.experimental.weight_update import BACKEND_AWEX, BACKEND_DISK, BACKEND_RDT from areal.experimental.weight_update.gateway.auth import require_admin_key from areal.experimental.weight_update.gateway.config import ( PairInfo, @@ -34,7 +35,7 @@ class ConnectRequest(BaseModel): inference_worker_urls: list[str] 
nccl_master_addr: str = "" nccl_master_port: int = 0 - mode: str = "awex" # "awex" or "disk" + mode: str = BACKEND_AWEX # "awex", BACKEND_DISK, or BACKEND_RDT save_path: str = "" use_lora: bool = False lora_name: str = "" @@ -203,7 +204,7 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: request, pair_name, train_urls, inference_urls ) - if body.mode == "disk": + if body.mode == BACKEND_DISK: if not body.save_path: return JSONResponse( status_code=400, @@ -225,7 +226,7 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: pair_name=pair_name, train_worker_urls=train_urls, inference_worker_urls=inference_urls, - mode="disk", + mode=BACKEND_DISK, save_path=body.save_path, use_lora=body.use_lora, lora_name=body.lora_name, @@ -236,6 +237,19 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: ) return ConnectResponse(pair_name=pair_name) + if body.mode == BACKEND_RDT: + session = request.app.state.http_session + init_timeout_s = config.init_timeout_s + await _rdt_connect( + pair_name, + train_urls, + inference_urls, + session, + init_timeout_s, + config, + ) + return ConnectResponse(pair_name=pair_name) + session = request.app.state.http_session init_timeout_s = config.init_timeout_s @@ -301,8 +315,9 @@ async def connect(request: Request, body: ConnectRequest) -> ConnectResponse: kv_store.put(pair_name, "training_params_meta", training_params_meta) kv_store.put(pair_name, "infer_params_meta", infer_params_meta) - master_addr = body.nccl_master_addr - master_port = body.nccl_master_port + # Auto-generate master_addr/port if not provided + master_addr = body.nccl_master_addr or _get_own_ip() + master_port = body.nccl_master_port or find_free_ports(1)[0] # Use the bound host for kv_store_url so workers can reach the # gateway. 
When bound to 0.0.0.0 any interface works, so fall @@ -684,6 +699,202 @@ async def _disk_transfer_weights( ] ) + async def _rdt_connect( + pair_name: str, + train_urls: list[str], + inference_urls: list[str], + session: aiohttp.ClientSession, + timeout_s: float, + config: WeightUpdateConfig, + ) -> PairInfo: + """Connect TW-IW pair for RDT mode. + + Flow: + 1. Get TW actor handles (Base64 encoded) from each TW + 2. Get IW parallelism info + 3. Get weight metadata from TW and IW (for TransferPlan building) + 4. Store metadata in KV store + 5. Init IW with TW handles via /rdt/init_weight_update_group + + No NCCL process group init needed for RDT. + """ + + # 1. Get TW actor handles + tw_actor_handle_tasks = [ + _get_json(session, f"{url}/rdt/get_actor_handle", timeout_s) + for url in train_urls + ] + tw_actor_handle_resps = await asyncio.gather(*tw_actor_handle_tasks) + tw_actor_bytes_b64_list = [ + resp["actor_bytes_b64"] for resp in tw_actor_handle_resps + ] + + # 2. Get parallelism info from TW and IW + train_par, infer_par = await asyncio.gather( + _get_json( + session, + f"{train_urls[0]}/rdt/report_parallelism", + timeout_s, + ), + _get_json( + session, + f"{inference_urls[0]}/rdt/report_parallelism", + timeout_s, + ), + ) + + train_world_size = train_par["world_size"] + infer_world_size = infer_par["world_size"] + num_engines = len(inference_urls) + total_infer_ranks = infer_world_size * num_engines + + # 3. Get weight metadata (for TransferPlan building) + train_meta_resps, infer_meta_resps = await asyncio.gather( + asyncio.gather( + *[ + _post_json(session, f"{url}/rdt/report_weight_meta", timeout_s) + for url in train_urls + ] + ), + asyncio.gather( + *[ + _post_json(session, f"{url}/rdt/report_weight_meta", timeout_s) + for url in inference_urls + ] + ), + ) + + # 4. 
Store metadata in KV store + training_params_meta = [] + for result in train_meta_resps: + meta = result.get("result", result.get("meta", result)) + if isinstance(meta, list): + training_params_meta.extend(meta) + else: + training_params_meta.append(meta) + training_params_meta = _merge_training_meta_by_name(training_params_meta) + + infer_params_meta = [] + for result in infer_meta_resps: + meta = result.get("result", result.get("meta", result)) + if isinstance(meta, list): + infer_params_meta.extend(meta) + else: + infer_params_meta.append(meta) + + kv_store.put(pair_name, "training_params_meta", training_params_meta) + kv_store.put(pair_name, "infer_params_meta", infer_params_meta) + + # 5. Build kv_store_url (same as awex) + master_addr = _get_own_ip() + gateway_addr = master_addr if config.host in ("0.0.0.0", "::") else config.host + kv_store_url = f"http://{gateway_addr}:{config.gateway_port}" + + # 6. Init TW with TransferPlan (TW needs to know what to send to each IW) + tw_init_tasks = [] + for i, url in enumerate(train_urls): + tw_init_tasks.append( + _post( + session, + f"{url}/rdt/init_weight_update_group", + timeout_s, + json_data={ + "pair_name": pair_name, + "kv_store_url": kv_store_url, + "infer_world_size": total_infer_ranks, + "train_world_size": train_world_size, + "num_engines": num_engines, + "transfer_rank": total_infer_ranks + + i, # TW ranks start after IW + }, + ) + ) + await asyncio.gather(*tw_init_tasks) + + # 7. 
Init IW with TW handles + iw_init_payload_base = { + "pair_name": pair_name, + "kv_store_url": kv_store_url, + "tw_actor_bytes_b64_list": tw_actor_bytes_b64_list, + "infer_world_size": total_infer_ranks, + "train_world_size": train_world_size, + "num_engines": num_engines, + } + + iw_init_tasks = [] + for i, url in enumerate(inference_urls): + iw_init_tasks.append( + _post( + session, + f"{url}/rdt/init_weight_update_group", + timeout_s, + json_data={**iw_init_payload_base, "transfer_rank": i}, + ) + ) + await asyncio.gather(*iw_init_tasks) + + # 8. Create PairInfo + pair_info = PairInfo( + pair_name=pair_name, + train_worker_urls=train_urls, + inference_worker_urls=inference_urls, + train_world_size=train_world_size, + inference_world_size=infer_world_size, + mode=BACKEND_RDT, + ) + registry.register(pair_info) + + logger.info("Connected RDT pair '%s'", pair_name) + return pair_info + + async def _rdt_transfer_weights( + pair_info: PairInfo, + version: int, + session: aiohttp.ClientSession, + timeout_s: float, + ) -> dict[str, float]: + """Execute RDT weight update. + + Flow: + 1. TW prepares IPC handles via /rdt/update_weights + 2. IW pulls weights via /rdt/update_weights (calls TW actor via Ray RPC) + + Returns timing dict for performance analysis. + """ + timings: dict[str, float] = {} + + # 1. TW prepares IPC handles + tw_start = time.monotonic() + await asyncio.gather( + *[ + _post( + session, + f"{url}/rdt/update_weights", + timeout_s, + json_data={"pair_name": pair_info.pair_name, "version": version}, + ) + for url in pair_info.train_worker_urls + ] + ) + timings["tw_prepare_ipc_handles"] = (time.monotonic() - tw_start) * 1000 + + # 2. 
IW pulls weights + iw_start = time.monotonic() + await asyncio.gather( + *[ + _post( + session, + f"{url}/rdt/update_weights", + timeout_s, + json_data={"version": version}, + ) + for url in pair_info.inference_worker_urls + ] + ) + timings["iw_pull_weights"] = (time.monotonic() - iw_start) * 1000 + + return timings + @app.post("/update_weights") async def update_weights( request: Request, body: UpdateWeightsRequest @@ -712,10 +923,18 @@ async def update_weights( await _colocate_transfer_weights( pair_info, body.version, session, timeout_s ) - elif pair_info.mode == "disk": + elif pair_info.mode == BACKEND_DISK: await _disk_transfer_weights( pair_info, body.version, session, timeout_s ) + elif pair_info.mode == BACKEND_RDT: + timings = await _rdt_transfer_weights( + pair_info, body.version, session, timeout_s + ) + logger.info( + "RDT timing breakdown: %s", + ", ".join(f"{k}={v:.1f}ms" for k, v in timings.items()), + ) else: await _awex_transfer_weights( pair_info, body.version, session, timeout_s diff --git a/areal/experimental/weight_update/gateway/config.py b/areal/experimental/weight_update/gateway/config.py index db564d8c77..dc8e6ab5c8 100644 --- a/areal/experimental/weight_update/gateway/config.py +++ b/areal/experimental/weight_update/gateway/config.py @@ -6,6 +6,8 @@ from pydantic import BaseModel # pyright: ignore[reportMissingImports] +from areal.experimental.weight_update import BACKEND_AWEX, BACKEND_DISK, BACKEND_RDT + @dataclass class WeightUpdateConfig: @@ -55,7 +57,7 @@ class PairInfo: last_version: int = 0 # Disk-mode fields (used when mode="disk") - mode: str = "awex" # "awex" or "disk" + mode: str = BACKEND_AWEX # "awex", "disk", or "rdt" save_path: str = "" use_lora: bool = False lora_name: str = "" @@ -66,5 +68,7 @@ class PairInfo: def __post_init__(self): if not self.pair_name: raise ValueError("pair_name must not be empty") - if self.mode not in ("awex", "disk"): - raise ValueError(f"mode must be 'awex' or 'disk', got '{self.mode}'") + if 
self.mode not in (BACKEND_AWEX, BACKEND_DISK, BACKEND_RDT): + raise ValueError( + f"mode must be '{BACKEND_AWEX}', '{BACKEND_DISK}', or '{BACKEND_RDT}', got '{self.mode}'" + ) diff --git a/areal/experimental/weight_update/rdt/__init__.py b/areal/experimental/weight_update/rdt/__init__.py new file mode 100644 index 0000000000..2ed7787f72 --- /dev/null +++ b/areal/experimental/weight_update/rdt/__init__.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RDT weight update backend using one-sided RDMA (YR/NIXL). + +See docs/rfc/rdt_weight_update_backend.md for details. +""" + +from __future__ import annotations + +import asyncio +import base64 +from typing import Any + +import aiohttp # pyright: ignore[reportMissingImports] + +from areal.infra.rpc.serialization import deserialize_value +from areal.infra.utils.concurrent import run_async_task +from areal.utils import logging + +logger = logging.getLogger("RDTWeightUpdate") + + +async def _fetch_kv_metadata_async( + kv_store_url: str, + pair_name: str, +) -> tuple[Any, Any]: + """Fetch infer and training parameter metadata from gateway KV store. + + Args: + kv_store_url: Gateway URL (e.g., "http://10.0.0.1:7080") + pair_name: Unique identifier for the TW-IW pair + + Returns: + tuple[Any, Any]: (infer_params_meta, training_params_meta) + """ + infer_url = f"{kv_store_url}/weight_meta/{pair_name}/infer_params_meta" + train_url = f"{kv_store_url}/weight_meta/{pair_name}/training_params_meta" + + async with aiohttp.ClientSession() as session: + + async def _get(url: str) -> Any: + async with session.get(url) as resp: + resp.raise_for_status() + data = await resp.json() + return data.get("value", data) + + infer_json, train_json = await asyncio.gather(_get(infer_url), _get(train_url)) + + return deserialize_value(infer_json), deserialize_value(train_json) + + +def fetch_kv_metadata(kv_store_url: str, pair_name: str) -> tuple[Any, Any]: + """Sync wrapper around :func:`_fetch_kv_metadata_async`. 
+ + Args: + kv_store_url: Gateway URL + pair_name: TW-IW pair identifier + + Returns: + tuple[Any, Any]: (infer_params_meta, training_params_meta) + """ + return run_async_task(_fetch_kv_metadata_async, kv_store_url, pair_name) + + +def serialize_actor_handle_bytes(actor_handle: Any) -> str: + """Serialize actor handle to Base64-encoded cloudpickle bytes. + + Args: + actor_handle: Ray actor handle + + Returns: + str: Base64-encoded string + """ + import ray + + handle_bytes = ray.cloudpickle.dumps(actor_handle) + return base64.b64encode(handle_bytes).decode() + + +def deserialize_actor_handle_bytes(actor_bytes_b64: str) -> Any: + """Deserialize Base64-encoded actor handle bytes. + + Args: + actor_bytes_b64: Base64-encoded cloudpickle bytes + + Returns: + Any: Ray actor handle + """ + import ray + + actor_bytes = base64.b64decode(actor_bytes_b64) + return ray.cloudpickle.loads(actor_bytes) + + +def get_tensor_transport() -> str: + """Get appropriate tensor transport based on device type. 
+ + Returns: + str: "YR" for NPU, "NIXL" for CUDA GPU + """ + from areal.infra.platforms import current_platform + + device_type = current_platform.device_type + if device_type == "npu": + return "YR" + elif device_type == "cuda": + return "NIXL" + else: + raise RuntimeError(f"Unsupported device type for RDT: {device_type}") + + +__all__ = [ + "fetch_kv_metadata", + "serialize_actor_handle_bytes", + "deserialize_actor_handle_bytes", + "get_tensor_transport", + "WeightTransportActor", +] + + +# WeightTransportActor - lazy import to avoid circular dependencies +def __getattr__(name: str): + if name == "WeightTransportActor": + from areal.experimental.weight_update.rdt.weight_transport_actor import ( + WeightTransportActor, + ) + + return WeightTransportActor + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/areal/experimental/weight_update/rdt/fsdp_adapter.py b/areal/experimental/weight_update/rdt/fsdp_adapter.py new file mode 100644 index 0000000000..65af75dbdf --- /dev/null +++ b/areal/experimental/weight_update/rdt/fsdp_adapter.py @@ -0,0 +1,393 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RDT FSDP Adapter for training-side weight update.""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import torch +from awex.meta.weight_meta import ( + ParameterMeta, + ParameterReplicaMeta, + ParameterShardMeta, +) +from awex.sharding.param_sharding import ShardingType +from awex.sharding.rank_info import RankInfo +from awex.transfer.transfer_plan import TransferPlan, TransferPlanBuilder +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.placement_types import Shard + +from areal.engine.core.model import is_qwen_vl_model +from areal.experimental.weight_update.rdt import fetch_kv_metadata +from areal.utils import logging + +if TYPE_CHECKING: + from areal.engine.fsdp_engine import FSDPEngine + +logger = logging.getLogger("RDTFSDPAdapter") + + +class RDTFSDPAdapter: + """RDT 
training adapter for FSDPEngine.""" + + def __init__(self, engine: FSDPEngine): + """Initialize adapter with FSDPEngine reference. + + Args: + engine: FSDPEngine instance holding the model + """ + self._engine = engine + self._transfer_plans: dict[str, TransferPlan] = {} + self._transfer_ranks: dict[str, int] = {} + + @property + def parallelism_strategy(self) -> dict: + """Parallelism strategy for TransferPlan building. + + Returns: + dict: Contains world_size, tp_size, pp_size, dp_size, ep_size + """ + mesh = self._engine.world_mesh + dim_names = tuple(mesh.mesh_dim_names or ()) + tp_size = mesh.size(dim_names.index("sp_tp")) if "sp_tp" in dim_names else 1 + + return { + "world_size": self._engine.world_size, + "tp_size": tp_size, + "pp_size": 1, + "dp_size": self._engine.data_parallel_world_size, + "ep_size": 1, + "dp_replicated": False, + } + + def get_weight_metadata(self) -> list[ParameterMeta]: + """Extract parameter shard metadata for TransferPlan building. + + Returns: + list[ParameterMeta]: Parameter metadata for all model parameters + """ + rank_info = self._build_rank_info() + metadata: list[ParameterMeta] = [] + + # Skip lm_head.weight if tie_word_embeddings=True (shared with embed_tokens) + tie_word_embeddings = getattr( + self._engine.model_config, "tie_word_embeddings", False + ) + + for raw_name, param in self._engine.model.named_parameters(): + name = self._to_hf_name(raw_name) + if tie_word_embeddings and name == "lm_head.weight": + continue + tensor = param.data + if isinstance(tensor, DTensor): + shard_meta = self._extract_dtensor_shard_meta(name, tensor, rank_info) + global_shape = tuple(tensor.shape) + global_numel = int(tensor.numel()) + dtype = tensor.dtype + else: + shard_meta = self._extract_plain_shard_meta(name, tensor, rank_info) + global_shape = tuple(tensor.shape) + global_numel = int(tensor.numel()) + dtype = tensor.dtype + + replica = ParameterReplicaMeta(shards=[shard_meta]) + metadata.append( + ParameterMeta( + name=name, + 
global_numel=global_numel, + global_shape=global_shape, + dtype=dtype, + shards=[shard_meta], + replicas=[replica], + ) + ) + + return metadata + + def get_local_shard_parameters( + self, required_names: list[str] | None = None + ) -> dict[str, torch.Tensor]: + """Return local shard tensors in HF naming. + + Args: + required_names: Optional filter for specific parameters + + Returns: + dict[str, torch.Tensor]: Local shard tensors by HF name + """ + required = set(required_names) if required_names else None + local_params: dict[str, torch.Tensor] = {} + + # Skip lm_head.weight if tie_word_embeddings=True (shared with embed_tokens) + tie_word_embeddings = getattr( + self._engine.model_config, "tie_word_embeddings", False + ) + + for raw_name, param in self._engine.model.named_parameters(): + name = self._to_hf_name(raw_name) + if tie_word_embeddings and name == "lm_head.weight": + continue + if required is not None and name not in required: + continue + + tensor = param.data + if isinstance(tensor, DTensor): + local_params[name] = tensor._local_tensor + else: + local_params[name] = tensor + + return local_params + + def save_parameters(self, save_path: str, names: list[str] | None = None) -> None: + """Save local shard parameters to file for debugging. + + Args: + save_path: File path to save parameters + names: Optional filter for specific parameters + """ + params = self.get_local_shard_parameters(names) + cpu_params = {k: v.detach().cpu().clone() for k, v in params.items()} + torch.save(cpu_params, save_path) + + def init_weight_update_group( + self, + pair_name: str, + kv_store_url: str, + infer_world_size: int, + train_world_size: int, + num_engines: int, + transfer_rank: int, + ) -> None: + """Initialize RDT weight update group for TW. 
+ + Args: + pair_name: TW-IW pair identifier + kv_store_url: Gateway KV store URL + infer_world_size: Total IW world size + train_world_size: Total TW world size + num_engines: Number of IW engines + transfer_rank: TW's transfer rank + """ + self._transfer_ranks[pair_name] = transfer_rank + + infer_meta, train_meta = fetch_kv_metadata(kv_store_url, pair_name) + + builder = TransferPlanBuilder( + infer_world_size=infer_world_size, + train_world_size=train_world_size, + num_infer_engines=num_engines, + ) + self._transfer_plans[pair_name] = builder.build_local_transfer_plan( + infer_meta, train_meta, global_transfer_rank=transfer_rank + ) + logger.info( + f"RDT TW init: Built TransferPlan for pair '{pair_name}' transfer_rank={transfer_rank}" + ) + + def teardown_weight_update_group(self, pair_name: str | None = None) -> None: + """Clear stored TransferPlans. + + Args: + pair_name: Optional specific pair to teardown; clears all if None + """ + if pair_name: + self._transfer_plans.pop(pair_name, None) + self._transfer_ranks.pop(pair_name, None) + else: + self._transfer_plans.clear() + self._transfer_ranks.clear() + + def _to_hf_name(self, name: str) -> str: + """Convert to HuggingFace canonical format for Qwen-VL. + + Args: + name: Internal parameter name + + Returns: + str: HF canonical name + """ + if self._engine.is_vision_model and is_qwen_vl_model( + self._engine.model_config.model_type + ): + new_name = name + if new_name.startswith("model.model."): + new_name = new_name.replace("model.model.", "model.", 1) + if new_name.startswith("model.visual."): + new_name = new_name.replace("model.", "", 1) + return new_name + return name + + def _build_rank_info(self) -> RankInfo: + """Build RankInfo for shard metadata extraction. 
+ + Returns: + RankInfo: Rank information for current worker + """ + mesh = self._engine.world_mesh + dim_names = tuple(mesh.mesh_dim_names or ()) + + tp_size = mesh.size(dim_names.index("sp_tp")) if "sp_tp" in dim_names else 1 + tp_rank = ( + mesh.get_local_rank(dim_names.index("sp_tp")) if "sp_tp" in dim_names else 0 + ) + cp_size = mesh.size(dim_names.index("sp")) if "sp" in dim_names else 1 + cp_rank = mesh.get_local_rank(dim_names.index("sp")) if "sp" in dim_names else 0 + local_rank = int(os.environ.get("LOCAL_RANK", self._engine.rank)) + + return RankInfo( + tp_rank=tp_rank, + tp_size=tp_size, + pp_rank=0, + pp_size=1, + dp_size=self._engine.data_parallel_world_size, + dp_rank=self._engine.dp_rank, + ep_rank=0, + ep_size=1, + ep_tp_rank=0, + ep_tp_size=1, + attn_tp_rank=tp_rank, + attn_tp_size=tp_size, + attn_dp_rank=self._engine.dp_rank, + world_size=self._engine.world_size, + global_rank=self._engine.rank, + local_rank=local_rank, + engine_rank=0, + is_infer=False, + cp_rank=cp_rank, + cp_size=cp_size, + cp_mode="none", + ) + + @staticmethod + def _compute_dtensor_offset(dtensor: DTensor) -> tuple[int, ...]: + """Compute global offset for DTensor shard. + + Args: + dtensor: Distributed tensor + + Returns: + tuple[int, ...]: Global offset for local shard + """ + global_shape = tuple(dtensor.shape) + placements = dtensor.placements + mesh = dtensor.device_mesh + + offset = [0] * len(global_shape) + remaining_shape = list(global_shape) + + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + mesh_size = mesh.size(mesh_dim) + chunk_size = remaining_shape[shard_dim] // mesh_size + coord = mesh.get_local_rank(mesh_dim) + offset[shard_dim] += coord * chunk_size + remaining_shape[shard_dim] = chunk_size + + return tuple(offset) + + @staticmethod + def _extract_dtensor_sharding(dtensor: DTensor) -> tuple[int, int]: + """Extract sharding dimension and num_shards from DTensor. 
+ + Args: + dtensor: Distributed tensor + + Returns: + tuple[int, int]: (sharding_dim, num_shards) + """ + shard_info: dict[int, int] = {} + for mesh_dim, placement in enumerate(dtensor.placements): + if isinstance(placement, Shard): + dim = placement.dim + mesh_size = dtensor.device_mesh.size(mesh_dim) + shard_info[dim] = shard_info.get(dim, 1) * mesh_size + + if not shard_info: + return 0, 1 + + primary_dim = max(shard_info.items(), key=lambda item: item[1])[0] + return primary_dim, shard_info[primary_dim] + + def _extract_dtensor_shard_meta( + self, + name: str, + dtensor: DTensor, + rank_info: RankInfo, + ) -> ParameterShardMeta: + """Extract ParameterShardMeta for DTensor. + + Args: + name: Parameter name + dtensor: Distributed tensor + rank_info: Rank information + + Returns: + ParameterShardMeta: Shard metadata + """ + local_tensor = dtensor._local_tensor + sharding_dim, num_shards = self._extract_dtensor_sharding(dtensor) + sharding_type = ( + ShardingType.TP_SHARDING if num_shards > 1 else ShardingType.NO_SHARDING + ) + + return ParameterShardMeta( + tp_rank=rank_info.tp_rank, + attn_tp_rank=rank_info.attn_tp_rank, + pp_rank=rank_info.pp_rank, + ep_rank=rank_info.ep_rank, + ep_tp_rank=rank_info.ep_tp_rank, + global_rank=rank_info.global_rank, + world_size=rank_info.world_size, + engine_rank=rank_info.engine_rank, + cp_rank=rank_info.cp_rank, + cp_size=rank_info.cp_size, + cp_mode=rank_info.cp_mode, + name=name, + shape=tuple(local_tensor.shape), + numel=int(local_tensor.numel()), + dtype=local_tensor.dtype, + global_offset=self._compute_dtensor_offset(dtensor), + sharding_type=sharding_type, + num_shards=num_shards, + sharding_dim=sharding_dim, + ) + + def _extract_plain_shard_meta( + self, + name: str, + tensor: torch.Tensor, + rank_info: RankInfo, + ) -> ParameterShardMeta: + """Extract ParameterShardMeta for plain tensor. 
+ + Args: + name: Parameter name + tensor: Plain tensor (non-DTensor) + rank_info: Rank information + + Returns: + ParameterShardMeta: Shard metadata + """ + return ParameterShardMeta( + tp_rank=rank_info.tp_rank, + attn_tp_rank=rank_info.attn_tp_rank, + pp_rank=rank_info.pp_rank, + ep_rank=rank_info.ep_rank, + ep_tp_rank=rank_info.ep_tp_rank, + global_rank=rank_info.global_rank, + world_size=rank_info.world_size, + engine_rank=rank_info.engine_rank, + cp_rank=rank_info.cp_rank, + cp_size=rank_info.cp_size, + cp_mode=rank_info.cp_mode, + name=name, + shape=tuple(tensor.shape), + numel=int(tensor.numel()), + dtype=tensor.dtype, + global_offset=tuple([0] * len(tuple(tensor.shape))), + sharding_type=ShardingType.NO_SHARDING, + num_shards=1, + sharding_dim=0, + ) diff --git a/areal/experimental/weight_update/rdt/megatron_adapter.py b/areal/experimental/weight_update/rdt/megatron_adapter.py new file mode 100644 index 0000000000..62f213ca91 --- /dev/null +++ b/areal/experimental/weight_update/rdt/megatron_adapter.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +"""RDT Megatron Adapter for training-side weight update.""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import torch +from awex.meta.weight_meta import ( + ParameterMeta, + ParameterReplicaMeta, + ParameterShardMeta, +) +from awex.sharding.param_sharding import ShardingType +from awex.sharding.rank_info import RankInfo +from awex.transfer.transfer_plan import TransferPlan, TransferPlanBuilder + +from areal.experimental.weight_update.rdt import fetch_kv_metadata +from areal.utils import logging + +if TYPE_CHECKING: + from areal.engine.megatron_engine import MegatronEngine + +logger = logging.getLogger("RDTMegatronAdapter") + + +class RDTMegatronAdapter: + """RDT training adapter for MegatronEngine supporting DP, TP, and PP. 
+ + PP: get_named_parameters already yields only the current stage's layers + (with globally-correct HF layer indices via get_transformer_layer_offset), + so each rank naturally reports only its own subset of parameters. + + TP: all_gather_param gathers the full tensor on every TP rank before + convert_to_hf. IW pulls via Ray RPC, so all TP ranks have identical tensors. + + Unlike awex (NCCL P2P push), RDT uses pull-based Ray RPC: + - TW stores TransferPlan and knows what each IW needs + - IW tells TW its infer_rank, TW returns corresponding weights + """ + + def __init__(self, engine: MegatronEngine): + self._engine = engine + self._transfer_plans: dict[str, TransferPlan] = {} + self._transfer_ranks: dict[str, int] = {} + + @property + def parallelism_strategy(self) -> dict: + from megatron.core import parallel_state as mpu + + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + return { + "world_size": self._engine.world_size, + "tp_size": tp_size, + "pp_size": mpu.get_pipeline_model_parallel_world_size(), + "dp_size": self._engine.data_parallel_world_size, + "ep_size": mpu.get_expert_model_parallel_world_size(), + "dp_replicated": tp_size > 1 or cp_size > 1, + } + + def get_weight_metadata(self) -> list[ParameterMeta]: + """Extract parameter metadata for TransferPlan building. 
+ + Returns: + list[ParameterMeta]: Parameter metadata for all model parameters + """ + rank_info = self._build_rank_info() + metadata: list[ParameterMeta] = [] + + for hf_name, tensor in self._iter_hf_params(): + shape = tuple(tensor.shape) + numel = int(tensor.numel()) + shard_meta = ParameterShardMeta( + tp_rank=rank_info.tp_rank, + attn_tp_rank=rank_info.attn_tp_rank, + pp_rank=rank_info.pp_rank, + ep_rank=rank_info.ep_rank, + ep_tp_rank=rank_info.ep_tp_rank, + global_rank=rank_info.global_rank, + world_size=rank_info.world_size, + engine_rank=rank_info.engine_rank, + cp_rank=rank_info.cp_rank, + cp_size=rank_info.cp_size, + cp_mode=rank_info.cp_mode, + name=hf_name, + shape=shape, + numel=numel, + dtype=tensor.dtype, + global_offset=tuple([0] * len(shape)), + sharding_type=ShardingType.NO_SHARDING, + num_shards=1, + sharding_dim=0, + ) + replica = ParameterReplicaMeta(shards=[shard_meta]) + metadata.append( + ParameterMeta( + name=hf_name, + global_numel=numel, + global_shape=shape, + dtype=tensor.dtype, + shards=[shard_meta], + replicas=[replica], + ) + ) + + return metadata + + def get_local_shard_parameters( + self, required_names: list[str] | None = None + ) -> dict[str, torch.Tensor]: + """Return local tensors in HF naming. + + Args: + required_names: Optional filter for specific parameters + + Returns: + dict[str, torch.Tensor]: Local tensors by HF name + """ + required = set(required_names) if required_names else None + result: dict[str, torch.Tensor] = {} + for hf_name, tensor in self._iter_hf_params(): + if required is not None and hf_name not in required: + continue + result[hf_name] = tensor + return result + + def save_parameters(self, save_path: str, names: list[str] | None = None) -> None: + """Save local shard parameters to file for debugging. 
def teardown_weight_update_group(self, pair_name: str | None = None) -> None:
    """Drop stored TransferPlans and transfer ranks.

    Args:
        pair_name: Pair to tear down. Only ``None`` clears every pair; any
            string (including the empty string) removes just that pair and
            is a no-op when the pair is unknown.
    """
    # Compare against None explicitly: a truthiness test would treat "" as
    # "clear everything" and silently wipe unrelated pairs.
    if pair_name is None:
        self._transfer_plans.clear()
        self._transfer_ranks.clear()
        return

    self._transfer_plans.pop(pair_name, None)
    self._transfer_ranks.pop(pair_name, None)
def _iter_hf_params(self):
    """Lazily yield ``(hf_name, tensor)`` pairs for this rank's parameters.

    Each Megatron parameter from ``get_named_parameters`` is passed through
    ``all_gather_param`` and then ``convert_to_hf``, which may expand one
    Megatron parameter into several HF-named tensors (e.g. per-expert names
    such as ``experts.0.gate_proj.weight``). The SGLang adapter's
    ``_unfuse_params`` produces the same per-expert naming from SGLang's
    fused w13/w2 tensors, so both sides agree for the transfer plan.

    ``lm_head.weight`` is skipped when the HF config ties word embeddings.
    Yielded tensors are detached.
    """
    from areal.engine.megatron_utils.megatron import (
        all_gather_param,
        convert_to_hf,
        get_named_parameters,
    )

    expert_count = getattr(self._engine.tf_config, "num_moe_experts", None)
    hf_model_type = self._engine.hf_config.model_type
    tied_embeddings = getattr(self._engine.hf_config, "tie_word_embeddings", False)

    named_params = get_named_parameters(self._engine.model, expert_count)
    for mcore_name, param in named_params:
        full_param = all_gather_param(
            mcore_name,
            param,
            fp8_direct_convert=False,
            quantization_config=None,
            duplicated_param_names=self._engine._duplicated_param_names,
        )
        # all_gather_param may hand back a wrapper; unwrap to the raw tensor.
        if not isinstance(full_param, torch.Tensor):
            full_param = full_param.data

        converted = convert_to_hf(
            self._engine.tf_config,
            hf_model_type,
            mcore_name,
            full_param,
        )
        for hf_name, tensor in converted:
            is_tied_head = tied_embeddings and hf_name == "lm_head.weight"
            if not is_tied_head:
                yield hf_name, tensor.detach()
logging.getLogger("RDTSGLangAdapter") + + +class RDTSGLangAdapter: + """RDT inference adapter for in-process SGLang schedulers. + + Handles one-sided RDMA weight pull via Ray RPC from TW actors. + """ + + def __init__(self, scheduler: Any) -> None: + self._scheduler = scheduler + self._tw_handles: dict[str, list] = {} # pair_name -> TW actor handles + self._transfer_plans: dict[str, TransferPlan] = {} # pair_name -> plan + self._infer_world_sizes: dict[str, int] = {} # pair_name -> infer_world_size + self._tensor_transport: str | None = None + self._ray_initialized: bool = False + self._rank_info: RankInfo | None = None + + def _get_model(self) -> torch.nn.Module: + return self._scheduler.tp_worker.model_runner.model + + def _get_model_context(self) -> dict[str, Any]: + server_args = self._scheduler.server_args + tp_size = int(getattr(server_args, "tp_size", 1)) + pp_size = int(getattr(server_args, "pp_size", 1)) + dp_size = int(getattr(server_args, "dp_size", 1)) + + if dist.is_available() and dist.is_initialized(): + world_size = int(dist.get_world_size()) + global_rank = int(dist.get_rank()) + else: + world_size = int(tp_size * pp_size) + global_rank = int(getattr(self._scheduler, "tp_rank", 0)) + + local_rank = int( + getattr( + self._scheduler, + "local_rank", + os.environ.get("LOCAL_RANK", getattr(self._scheduler, "gpu_id", 0)), + ) + ) + + return { + "scheduler": self._scheduler, + "tp_rank": int(getattr(self._scheduler, "tp_rank", 0)), + "tp_size": tp_size, + "pp_rank": int(getattr(self._scheduler, "pp_rank", 0)), + "pp_size": pp_size, + "dp_size": dp_size, + "world_size": world_size, + "global_rank": global_rank, + "local_rank": local_rank, + "attn_tp_rank": int( + getattr( + self._scheduler, + "attn_tp_rank", + getattr(self._scheduler, "tp_rank", 0), + ) + ), + "attn_tp_size": int(getattr(self._scheduler, "attn_tp_size", tp_size)), + "attn_dp_rank": int(getattr(self._scheduler, "attn_dp_rank", 0)), + } + + @property + def parallelism_strategy(self) -> 
dict: + model_context = self._get_model_context() + server_args = self._scheduler.server_args + tp_size = int(getattr(server_args, "tp_size", model_context["tp_size"])) + pp_size = int(getattr(server_args, "pp_size", model_context["pp_size"])) + dp_size = int(getattr(server_args, "dp_size", model_context["dp_size"])) + ep_size = int(getattr(server_args, "ep_size", 1)) + + return { + "world_size": int(model_context["world_size"]), + "tp_size": tp_size, + "pp_size": pp_size, + "dp_size": dp_size, + "ep_size": ep_size, + "num_engines": 1, + } + + def _build_rank_info(self) -> RankInfo: + model_context = self._get_model_context() + return get_sglang_rank_info(model_context, engine_rank=0) + + def _build_sharding_strategy(self, rank_info: RankInfo): + model = self._get_model() + model_name = None + model_config = getattr(model, "config", None) + if model_config is not None: + architectures = getattr(model_config, "architectures", None) + if architectures and len(architectures) > 0: + model_name = architectures[0] + + if model_name is None: + model_name = type(model).__name__ + + infer_engine_config = self._scheduler.server_args + return get_sglang_sharding_strategy(model_name, infer_engine_config, rank_info) + + def _get_expert_prefix( + self, prefix: str, expert_idx: int, num_routed: int, total_experts: int + ) -> str: + if expert_idx < num_routed: + return f"{prefix}.{expert_idx}" + + shared_idx = expert_idx - num_routed + num_shared = total_experts - num_routed + if num_shared > 1: + return prefix.replace("experts", f"shared_experts.{shared_idx}") + return prefix.replace("experts", "shared_experts") + + def _unfuse_params( + self, name: str, tensor: torch.Tensor + ) -> list[tuple[str, torch.Tensor]]: + if "qkv_proj" in name: + cfg = self._get_model().config + num_heads = cfg.num_attention_heads + num_kv_heads = getattr(cfg, "num_key_value_heads", num_heads) + total_head_units = num_heads + 2 * num_kv_heads + dim0 = tensor.shape[0] + q_size = dim0 * num_heads // 
total_head_units + kv_size = dim0 * num_kv_heads // total_head_units + return [ + (name.replace("qkv_proj", "q_proj"), tensor.narrow(0, 0, q_size)), + ( + name.replace("qkv_proj", "k_proj"), + tensor.narrow(0, q_size, kv_size), + ), + ( + name.replace("qkv_proj", "v_proj"), + tensor.narrow(0, q_size + kv_size, kv_size), + ), + ] + if "gate_up_proj" in name: + half = tensor.shape[0] // 2 + return [ + (name.replace("gate_up_proj", "gate_proj"), tensor.narrow(0, 0, half)), + (name.replace("gate_up_proj", "up_proj"), tensor.narrow(0, half, half)), + ] + if "shared_experts" in name and "gate_up_weight" in name: + half = tensor.shape[0] // 2 + return [ + ( + name.replace("gate_up_weight", "gate_proj.weight"), + tensor.narrow(0, 0, half), + ), + ( + name.replace("gate_up_weight", "up_proj.weight"), + tensor.narrow(0, half, half), + ), + ] + if "shared_experts" in name and name.endswith("down_weight"): + return [(name.replace("down_weight", "down_proj.weight"), tensor)] + if ".experts.w13_weight" in name: + cfg = self._get_model().config + num_routed = getattr(cfg, "num_experts", None) or cfg.n_routed_experts + prefix = name.replace(".w13_weight", "") + result = [] + ffn_hidden = tensor.shape[1] // 2 + for i in range(tensor.shape[0]): + expert_tensor = tensor[i] + expert_prefix = self._get_expert_prefix( + prefix, i, num_routed, tensor.shape[0] + ) + result.append( + (f"{expert_prefix}.gate_proj.weight", expert_tensor[:ffn_hidden]) + ) + result.append( + (f"{expert_prefix}.up_proj.weight", expert_tensor[ffn_hidden:]) + ) + return result + if ".experts.w2_weight" in name: + cfg = self._get_model().config + num_routed = getattr(cfg, "num_experts", None) or cfg.n_routed_experts + prefix = name.replace(".w2_weight", "") + result = [] + for i in range(tensor.shape[0]): + expert_prefix = self._get_expert_prefix( + prefix, i, num_routed, tensor.shape[0] + ) + result.append((f"{expert_prefix}.down_proj.weight", tensor[i])) + return result + return [(name, tensor)] + + def 
def get_weight_metadata(self) -> list[ParameterMeta]:
    """Build one ``ParameterMeta`` per un-fused local parameter.

    For every HF-named tensor produced by ``_unfuse_params`` the sharding
    strategy is queried; when the parameter is sharded along a valid
    dimension, the global shape is the local extent times ``num_shards`` and
    the global offset is this rank's position times the local extent.

    Side effect: caches the freshly built rank info on ``self._rank_info``.

    Returns:
        list[ParameterMeta]: One entry per HF-named local tensor, each with
        a single shard and a single replica describing this rank.
    """
    rank_info = self._build_rank_info()
    strategy = self._build_sharding_strategy(rank_info)
    self._rank_info = rank_info

    collected: list[ParameterMeta] = []

    for fused_name, param in self._get_model().named_parameters():
        for hf_name, shard in self._unfuse_params(fused_name, param.data):
            shard_shape = tuple(shard.shape)
            sharding_type, sharding_dim, num_shards = (
                strategy.get_sharding_strategy(hf_name)
            )

            # Which rank coordinate indexes this parameter's shards.
            if sharding_type == ShardingType.TP_SHARDING:
                position = rank_info.tp_rank
            elif sharding_type == ShardingType.DP_TP_SHARDING:
                position = rank_info.attn_tp_rank
            elif sharding_type == ShardingType.EP_SHARDING:
                position = rank_info.ep_rank
            elif sharding_type == ShardingType.EP_TP_SHARDING:
                position = rank_info.ep_tp_rank
            else:
                position = 0

            offset = [0] * len(shard_shape)
            full_shape = list(shard_shape)
            sharded_on_valid_dim = (
                sharding_type != ShardingType.NO_SHARDING
                and 0 <= sharding_dim < len(shard_shape)
            )
            if sharded_on_valid_dim:
                extent = int(shard_shape[sharding_dim])
                offset[sharding_dim] = int(position) * extent
                full_shape[sharding_dim] = extent * int(num_shards)

            shard_meta = ParameterShardMeta(
                tp_rank=rank_info.tp_rank,
                attn_tp_rank=rank_info.attn_tp_rank,
                pp_rank=rank_info.pp_rank,
                ep_rank=rank_info.ep_rank,
                ep_tp_rank=rank_info.ep_tp_rank,
                global_rank=rank_info.global_rank,
                world_size=rank_info.world_size,
                engine_rank=rank_info.engine_rank,
                cp_rank=rank_info.cp_rank,
                cp_size=rank_info.cp_size,
                cp_mode=rank_info.cp_mode,
                name=hf_name,
                shape=shard_shape,
                numel=int(shard.numel()),
                dtype=shard.dtype,
                global_offset=tuple(offset),
                sharding_type=sharding_type,
                num_shards=int(num_shards),
                sharding_dim=int(sharding_dim),
            )

            collected.append(
                ParameterMeta(
                    name=hf_name,
                    global_numel=math.prod(full_shape) if full_shape else 1,
                    global_shape=tuple(full_shape),
                    dtype=shard.dtype,
                    shards=[shard_meta],
                    replicas=[ParameterReplicaMeta(shards=[shard_meta])],
                )
            )

    return collected
def rdt_execute_weight_update(self, version: int = 0) -> None:
    """Pull the current pair's weights from the TW actors and apply them.

    Steps: resolve the single registered pair, select the TW actors named by
    the TransferPlan's send ranks, issue one tensor-transport RPC per actor,
    split each returned merged buffer back into named tensors, copy them
    into the local model via ``_apply_transfer_plan_to_model``, and finally
    ask every contacted actor to drop its IPC handles.

    Args:
        version: Weight version to pull; forwarded to the TW actors.

    Raises:
        RuntimeError: If no pair / TW handles are registered, the plan refers
            to a TW index that is not registered, or the configured tensor
            transport is unsupported.
    """
    import time

    t0 = time.monotonic()
    pair_name = self._get_current_pair_name()
    if pair_name not in self._tw_handles:
        raise RuntimeError(f"TW handles not initialized for pair '{pair_name}'")

    tw_handles = self._tw_handles[pair_name]
    plan = self._transfer_plans[pair_name]

    infer_world_size = self._infer_world_sizes.get(pair_name, 1)

    # Same key order that _get_required_tw_indices iterates, so results can
    # be zipped back to their send ranks below.
    send_ranks = list(plan.inter_operations.keys())
    required_indices = self._get_required_tw_indices(plan, infer_world_size)
    # A negative index would silently wrap around to the wrong actor.
    out_of_range = [i for i in required_indices if not 0 <= i < len(tw_handles)]
    if out_of_range:
        raise RuntimeError(
            f"TransferPlan references TW indices {out_of_range} but only "
            f"{len(tw_handles)} TW handles are registered for pair '{pair_name}'"
        )
    required_handles = [tw_handles[i] for i in required_indices]

    # IW's own global rank (infer_rank). Build it on demand instead of
    # defaulting to 0: after teardown (which resets _rank_info) a rank-0
    # fallback would make every IW rank request rank 0's slices.
    if self._rank_info is None:
        self._rank_info = self._build_rank_info()
    infer_rank = self._rank_info.global_rank

    t1 = time.monotonic()
    logger.info(
        f"[RDT-IW] Pulling from TW shards {required_indices} for pair '{pair_name}' v{version}"
    )

    transport = self._tensor_transport or get_tensor_transport()
    if transport == "YR":
        raise RuntimeError("YR backend not implemented yet")
    elif transport == "NIXL":
        method_name = "get_weights_tensor_nixl"
    else:
        raise RuntimeError(f"Unsupported tensor transport: {transport}")

    # IW only passes infer_rank; TW uses TransferPlan to determine what to send
    refs = [
        getattr(handle, method_name).remote(pair_name, infer_rank, version)
        for handle in required_handles
    ]
    t2 = time.monotonic()

    try:
        logger.info(f"[RDT-IW] Submitted {len(refs)} RPCs, calling ray.get...")
        raw_results = ray.get(refs)
        t3 = time.monotonic()

        # Unpack merged buffers back to tensor dicts
        dtype_cache: dict[str, torch.dtype | None] = {}
        shard_tensor_by_rank: dict[int, dict[str, torch.Tensor]] = {}
        total_bytes = 0
        for result, send_rank in zip(raw_results, send_ranks):
            buffer = result["buffer"]
            metadata = result["metadata"]

            tensor_dict: dict[str, torch.Tensor] = {}
            for name, offset, shape, dtype_str in zip(
                metadata["names"],
                metadata["offsets"],
                metadata["shapes"],
                metadata["dtypes"],
            ):
                # Offsets are element offsets into the flat merged buffer.
                numel = math.prod(shape)
                tensor = buffer[offset : offset + numel].reshape(shape)

                # Resolve "torch.<dtype>" generically (once per distinct
                # string) so every dtype round-trips, not just a fixed few.
                if dtype_str not in dtype_cache:
                    resolved = getattr(torch, dtype_str.rpartition(".")[2], None)
                    dtype_cache[dtype_str] = (
                        resolved if isinstance(resolved, torch.dtype) else None
                    )
                target_dtype = dtype_cache[dtype_str]
                if target_dtype is not None and tensor.dtype != target_dtype:
                    tensor = tensor.to(target_dtype)

                tensor_dict[name] = tensor
                total_bytes += numel * tensor.element_size()

            shard_tensor_by_rank[send_rank] = tensor_dict

        t4 = time.monotonic()
        logger.info(
            f"[RDT-IW] Unpacked {len(raw_results)} buffers, "
            f"total tensors={sum(len(d) for d in shard_tensor_by_rank.values())}, "
            f"total_bytes={total_bytes / 1024 / 1024:.1f}MB, "
            f"unpack_time={1000 * (t4 - t3):.1f}ms"
        )

        self._apply_transfer_plan_to_model(
            plan, shard_tensor_by_rank, infer_world_size
        )
        t5 = time.monotonic()
    finally:
        # Always release the actors' IPC handles, even when the pull or the
        # apply step failed; otherwise the shared GPU tensors stay pinned on
        # the TW side for this (pair, rank, version).
        ray.get(
            [
                handle.clear_ipc_handles.remote(pair_name, infer_rank, version)
                for handle in required_handles
            ]
        )
    t6 = time.monotonic()

    logger.info(
        f"[RDT-IW-Timing] prep={1000 * (t1 - t0):.1f}ms | "
        f"rpc_submit={1000 * (t2 - t1):.1f}ms | ray_get={1000 * (t3 - t2):.1f}ms | "
        f"unpack={1000 * (t4 - t3):.1f}ms | apply_model={1000 * (t5 - t4):.1f}ms | "
        f"cleanup={1000 * (t6 - t5):.1f}ms | total={1000 * (t6 - t0):.1f}ms"
    )
def teardown_weight_update_group(self) -> None:
    """Forget all per-pair state: TW handles, plans, world sizes, rank info.

    The three per-pair dicts are emptied in place (not rebound) and the
    cached rank info is reset to ``None``.
    """
    per_pair_state = (
        self._tw_handles,
        self._transfer_plans,
        self._infer_world_sizes,
    )
    for registry in per_pair_state:
        registry.clear()
    self._rank_info = None
@ray.remote
class WeightTransportActor:
    """Actor for weight tensor transport via RDT.

    Key features:
    - max_concurrency set to IW count (multiple IW may pull concurrently)
    - Condition key: {pair_name}/{infer_rank}/{version} for one-to-one TW-IW sync
    - IW calls clear_ipc_handles() after ray.get() to release shared GPU memory

    Tensors are stored under ``{pair_name}/{infer_rank}/{version}/{param_name}``.
    Lookups always match on the prefix *plus the trailing slash*, so that
    version ``1`` never matches keys of version ``10`` (and rank ``1`` never
    matches rank ``12``).
    """

    def __init__(self):
        # Version-based tensor storage: {pair_name}/{infer_rank}/{version}/{param_name}
        self._tensors: dict[str, torch.Tensor] = {}
        self._tensors_lock = threading.Lock()  # Protect _tensors dict access

        # Synchronization: wait for IPC handles ready
        self._tensor_ready_lock = threading.Lock()
        self._tensor_ready: dict[str, Condition] = {}
        self._tensor_ready_flags: dict[str, bool] = defaultdict(bool)

        # Activate cuda:0 for IPC (CUDA_VISIBLE_DEVICES already set to single GPU by TW)
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            logger.info("WeightTransportActor initialized on cuda:0")

    def store_ipc_handles(
        self,
        pair_name: str,
        infer_rank: int,
        version: int,
        ipc_handles: dict[str, Any],
    ) -> None:
        """Receive all IPC handles, store tensors and notify IW.

        Args:
            pair_name: TW-IW pair identifier
            infer_rank: IW's global rank
            version: Weight version number
            ipc_handles: dict of {param_name: ipc_payload}; each payload
                carries a ``rebuild_fn`` and the ``tensor_meta`` arguments
                it is called with to rebuild the shared tensor
        """
        prefix = f"{pair_name}/{infer_rank}/{version}"

        # Build tensors outside lock to reduce lock holding time
        new_tensors: dict[str, torch.Tensor] = {}
        for param_name, ipc_payload in ipc_handles.items():
            rebuild_fn = ipc_payload["rebuild_fn"]
            tensor_meta = ipc_payload["tensor_meta"]
            new_tensors[f"{prefix}/{param_name}"] = rebuild_fn(*tensor_meta)

        # Store tensors under lock (brief operation)
        with self._tensors_lock:
            self._tensors.update(new_tensors)

        # Notify waiting IWs
        with self._tensor_ready_lock:
            self._tensor_ready_flags[prefix] = True
            if prefix in self._tensor_ready:
                self._tensor_ready[prefix].notify_all()

        logger.info(f"Stored {len(ipc_handles)} IPC handles for {prefix}")

    def _wait_for_ready(
        self, pair_name: str, infer_rank: int, version: int, timeout: float = 30.0
    ) -> bool:
        """Block until IPC handles for the key are stored, or ``timeout`` elapses.

        Returns:
            bool: True when the handles are ready, False on timeout.
        """
        prefix = f"{pair_name}/{infer_rank}/{version}"

        with self._tensor_ready_lock:
            condition = self._tensor_ready.get(prefix)
            if condition is None:
                condition = Condition(self._tensor_ready_lock)
                self._tensor_ready[prefix] = condition

            # wait_for re-checks the flag after every wakeup, so a spurious
            # wakeup is never reported as "ready"; it also returns at once
            # when the flag is already set.
            return condition.wait_for(
                lambda: self._tensor_ready_flags.get(prefix, False),
                timeout=timeout,
            )

    @ray.method(tensor_transport="NIXL")
    def get_weights_tensor_nixl(
        self,
        pair_name: str,
        infer_rank: int,
        version: int,
    ) -> dict[str, Any]:
        """Tensor transport for GPU (NIXL backend).

        IW calls this method, blocks until TW stores IPC handles.
        Returns merged buffer + metadata for efficient single RDMA transfer.

        Returns:
            dict: ``{"buffer": flat_tensor, "metadata": {...}}`` where the
            metadata lists, in matching order, each tensor's ``names``,
            element ``offsets`` into the buffer, ``shapes`` and ``dtypes``
            (as ``str(dtype)``).

        Raises:
            RuntimeError: If the handles are not ready in time or no tensor
                is stored for the requested key.
        """
        import time

        t0 = time.monotonic()
        prefix = f"{pair_name}/{infer_rank}/{version}"
        # Trailing slash keeps ".../1" from matching ".../10/...".
        key_prefix = f"{prefix}/"

        t1 = time.monotonic()
        if not self._wait_for_ready(pair_name, infer_rank, version):
            raise RuntimeError(f"IPC handles not ready for {prefix} after 30s")

        t2 = time.monotonic()
        with self._tensors_lock:
            tensor_items = [
                (k, v) for k, v in self._tensors.items() if k.startswith(key_prefix)
            ]
        # Sort by key to ensure consistent order
        tensor_items.sort(key=lambda x: x[0])

        t3 = time.monotonic()
        if not tensor_items:
            raise RuntimeError(f"Tensors not found for {prefix}")

        # Merge all tensors into single contiguous buffer
        # This reduces NIXL registration overhead from N times to 1 time
        tensors = [t.clone().detach() for _, t in tensor_items]
        # Strip the exact storage prefix rather than splitting on "/", so the
        # reported name is always what was stored after the prefix.
        param_names = [k[len(key_prefix):] for k, _ in tensor_items]

        t4 = time.monotonic()

        # Flatten and concatenate into single buffer
        merged_buffer = torch.cat([t.flatten() for t in tensors])

        # Build metadata for IW to split buffer back
        offsets = []
        shapes = []
        dtypes = []
        current_offset = 0
        for t in tensors:
            offsets.append(current_offset)
            shapes.append(tuple(t.shape))
            dtypes.append(str(t.dtype))
            current_offset += t.numel()

        t5 = time.monotonic()
        total_bytes = merged_buffer.numel() * merged_buffer.element_size()
        logger.info(
            f"[Actor-Timing] wait_ready={1000 * (t2 - t1):.1f}ms | "
            f"lock_lookup={1000 * (t3 - t2):.1f}ms | "
            f"clone={1000 * (t4 - t3):.1f}ms | "
            f"merge={1000 * (t5 - t4):.1f}ms | "
            f"total={1000 * (t5 - t0):.1f}ms | "
            f"num_tensors={len(tensors)} | "
            f"total_bytes={total_bytes / 1024 / 1024:.1f}MB"
        )

        return {
            "buffer": merged_buffer,
            "metadata": {
                "names": param_names,
                "offsets": offsets,
                "shapes": shapes,
                "dtypes": dtypes,
            },
        }

    @ray.method(tensor_transport="NIXL")
    def warmup_nixl(self) -> dict[str, torch.Tensor]:
        """Warmup NIXL agent by returning a minimal tensor.

        IW calls this during init to trigger NIXL agent initialization
        on both IW (driver) and TW Actor sides before actual weight transfer.
        This moves ~9s initialization overhead from update_weights to connect phase.
        """
        # Create a tiny tensor to trigger NIXL registration
        warmup_tensor = torch.zeros(1, dtype=torch.float32, device="cuda:0")
        logger.info("NIXL warmup tensor created")
        return {"warmup": warmup_tensor}

    # TODO: Implement YR backend for NPU (ray-ascend)
    # @ray.method(tensor_transport="YR")
    # def get_weights_tensor_yr(
    #     self,
    #     pair_name: str,
    #     infer_rank: int,
    #     version: int,
    # ) -> dict[str, torch.Tensor]:
    #     """Tensor transport for NPU (YR backend)."""
    #     ...

    def clear_ipc_handles(self, pair_name: str, infer_rank: int, version: int) -> None:
        """Clean up IPC handles for specific infer_rank and version."""
        prefix = f"{pair_name}/{infer_rank}/{version}"
        # Trailing slash: clearing version 1 must not delete version 10's tensors.
        key_prefix = f"{prefix}/"

        # Remove tensors under lock
        with self._tensors_lock:
            stale_keys = [k for k in self._tensors if k.startswith(key_prefix)]
            for key in stale_keys:
                del self._tensors[key]

        # Clear readiness state. Popping (instead of storing False) keeps the
        # flag table from growing by one entry per version; readers use
        # .get(prefix, False), so a missing key reads the same as False.
        with self._tensor_ready_lock:
            self._tensor_ready_flags.pop(prefix, None)
            self._tensor_ready.pop(prefix, None)
def _run_rdt_weight_transfer_test(n_gpus: int, test_type: str, output: str):
    """Run RDT distributed test via torchrun subprocess.

    Launches ``run_rdt_weight_transfer.py`` under torchrun on a free port,
    prints the combined stdout/stderr, and fails the calling test unless the
    process exits with code 0 and the result file contains ``Passed``.

    Args:
        n_gpus: Value passed to ``--nproc_per_node``.
        test_type: Value passed to the script's ``--test_type``.
        output: Path of the result file the script is asked to write.
    """
    port = find_free_ports(1)[0]
    env = os.environ.copy()
    # Avoid a trailing separator when PYTHONPATH is unset: an empty entry
    # would add the current directory to the workers' import path.
    inherited = env.get("PYTHONPATH", "")
    env["PYTHONPATH"] = (
        _PROJECT_ROOT + os.pathsep + inherited if inherited else _PROJECT_ROOT
    )
    # Resolve the script against the project root so the test does not depend
    # on the directory pytest was invoked from.
    script = os.path.join(
        _PROJECT_ROOT,
        "tests",
        "experimental",
        "weight_update",
        "torchrun",
        "run_rdt_weight_transfer.py",
    )
    proc = subprocess.Popen(
        [
            "torchrun",
            f"--nproc_per_node={n_gpus}",
            "--nnodes=1",
            "--master-addr=localhost",
            f"--master_port={port}",
            script,
            f"--test_type={test_type}",
            f"--output={output}",
        ],
        text=True,
        stderr=subprocess.STDOUT,
        stdout=subprocess.PIPE,
        env=env,
    )
    try:
        stdout, _ = proc.communicate(timeout=300)
        print(stdout)
    except BaseException:
        # Includes TimeoutExpired and KeyboardInterrupt: never leave
        # torchrun workers behind.
        kill_process_tree(proc.pid)
        raise
    if proc.returncode != 0:
        pytest.fail(f"torchrun exited with code {proc.returncode}")

    if not os.path.exists(output):
        pytest.fail(f"torchrun succeeded but wrote no result file at {output}")
    with open(output) as f:
        result = f.read().strip()
    assert result == "Passed", f"Test failed: {result}"
def _make_truncated_moe_model(tmp_path, num_layers: int = 4) -> str:
    """Create a truncated copy of the MoE test model under *tmp_path*.

    The copy gets a ``config.json`` with ``num_hidden_layers`` overridden,
    copies of the tokenizer/generation files that exist in the source,
    symlinks to the source weight files, and copies of the weight index files.

    Args:
        tmp_path: Directory in which ``truncated_moe`` is created.
        num_layers: Value written to ``num_hidden_layers``.

    Returns:
        Path of the created model directory.

    Raises:
        FileNotFoundError: If the source model path has no ``config.json``
            (for example when it is a hub id rather than a local directory).
    """
    import glob
    import json
    import shutil

    src = _get_test_moe_model_path()
    src_config = os.path.join(src, "config.json")
    if not os.path.isfile(src_config):
        # Fail with an actionable message instead of a bare open() error.
        raise FileNotFoundError(
            f"MoE test model is not available as a local directory: {src!r} "
            f"(missing {src_config})"
        )

    dst = str(tmp_path / "truncated_moe")
    os.makedirs(dst, exist_ok=True)

    with open(src_config) as f:
        config = json.load(f)
    config["num_hidden_layers"] = num_layers
    with open(os.path.join(dst, "config.json"), "w") as f:
        json.dump(config, f)

    for fname in (
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "generation_config.json",
    ):
        src_file = os.path.join(src, fname)
        if os.path.isfile(src_file):
            shutil.copy2(src_file, dst)

    weight_files = glob.glob(os.path.join(src, "*.safetensors")) + glob.glob(
        os.path.join(src, "*.bin")
    )
    for weight_file in weight_files:
        link = os.path.join(dst, os.path.basename(weight_file))
        # The directory is created with exist_ok=True, so it may be reused;
        # replace a stale link instead of failing with FileExistsError.
        if os.path.lexists(link):
            os.remove(link)
        os.symlink(weight_file, link)

    index_files = glob.glob(os.path.join(src, "*.safetensors.index.json")) + glob.glob(
        os.path.join(src, "*.bin.index.json")
    )
    for index_file in index_files:
        shutil.copy2(index_file, dst)

    return dst
fileroot = tmp_path / f"{name}_fileroot" + fileroot.mkdir(exist_ok=True) + nr_root = tmp_path / f"{name}_name_resolve" + nr_root.mkdir(exist_ok=True) + + return LocalScheduler( + gpu_devices=gpu_devices, + log_dir=str(tmp_path / f"{name}_logs"), + experiment_name=f"test-rdt-{name}", + trial_name="t0", + fileroot=str(fileroot), + nfs_record_root=str(nr_root), + ) + + +# Representative parameters spanning different fusion/sharding cases. +_VALIDATE_PARAM_NAMES = [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.up_proj.weight", + "model.layers.27.self_attn.q_proj.weight", + "model.norm.weight", +] + +# Qwen3-30B-A3B (truncated to 4 layers) MoE validation params. +# This model has no shared experts — pure MoE with 128 routed experts. +_VALIDATE_PARAM_NAMES_MOE = [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.1.mlp.experts.0.gate_proj.weight", + "model.layers.1.mlp.experts.0.up_proj.weight", + "model.layers.1.mlp.experts.0.down_proj.weight", + "model.layers.3.self_attn.q_proj.weight", + "model.norm.weight", +] + + +def _validate_weight_update_correctness( + train_worker_urls: list[str], + inf_worker_url: str, + param_dir, +) -> None: + """Fetch params from both sides via HTTP and compare bitwise.""" + import httpx + + n_train = len(train_worker_urls) + print( + f"\n[weight-validation] Fetching parameters from {n_train} training " + f"worker(s) and 1 inference worker ..." 
@pytest.mark.multi_gpu
@pytest.mark.slow
@pytest.mark.sglang
@pytest.mark.parametrize("n_gpus", [2, 4, 8], ids=["2gpu", "4gpu", "8gpu"])
def test_rdt_fsdp_e2e_weight_update(n_gpus, tmp_path_factory):
    """Full round trip: FSDPEngine → weight-update gateway → SGLang (RDT mode).

    Starts the SGLang inference workers and randomizes their weights, starts
    the FSDP training workers, connects both through the weight-update
    gateway in ``rdt`` mode, pushes version 1, then checks that generation
    still works and that the sampled parameters match bitwise.

    Teardown is registered on an ``ExitStack`` so that every controller and
    the scheduler are released even if an earlier teardown step raises.

    Requires Ray cluster with NIXL tensor transport support.
    """
    import contextlib

    import httpx

    if current_platform.device_count() < n_gpus:
        pytest.skip(f"This test requires {n_gpus} GPUs")

    # Require Ray cluster
    import ray

    if not ray.is_initialized():
        ray.init(address="auto", ignore_reinit_error=True)

    from areal.api import FinetuneSpec
    from areal.api.cli_args import (
        InferenceEngineConfig,
        OptimizerConfig,
        SchedulingSpec,
        TrainEngineConfig,
    )
    from areal.experimental.inference_service.controller.controller import (
        RolloutControllerV2,
    )
    from areal.experimental.training_service.controller.controller import (
        GatewayTrainController,
    )
    from areal.experimental.weight_update.controller import (
        WeightUpdateController,
        WeightUpdateControllerConfig,
    )

    n_half = n_gpus // 2
    tmp = tmp_path_factory.mktemp("rdt_e2e")
    model_path = _get_test_model_path()

    scheduler = _make_local_scheduler(tmp, "rdt_e2e", gpu_devices=list(range(n_gpus)))

    # IW: SGLang with RDT backend
    inf_config = InferenceEngineConfig(
        tokenizer_path=model_path,
        backend=f"sglang:d{n_half}",
        scheduling_spec=(
            SchedulingSpec(
                gpu=1,
                cmd="python -m areal.experimental.inference_service.guard",
                env_vars={"AREAL_WEIGHT_UPDATE_BACKEND": "rdt"},
            ),
        ),
        consumer_batch_size=8,
        max_head_offpolicyness=1024,
        setup_timeout=300.0,
        admin_api_key="test-admin",
    )
    inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler)

    # TW: FSDP
    train_config = TrainEngineConfig(
        backend=f"fsdp:d{n_half}",
        experiment_name="test-rdt-e2e",
        trial_name="t0",
        path=model_path,
        optimizer=OptimizerConfig(),
        _version="v2",
        setup_timeout=300.0,
        scheduling_spec=(
            SchedulingSpec(
                gpu=1,
                cmd="python -m areal.experimental.training_service.guard",
                env_vars=dict(
                    NCCL_CUMEM_ENABLE="0",
                    NCCL_NVLS_ENABLE="0",
                    AREAL_WEIGHT_UPDATE_BACKEND="rdt",
                ),
            ),
        ),
    )
    train_ctrl = GatewayTrainController(
        train_engine="areal.engine.fsdp_engine.FSDPEngine",
        config=train_config,
        scheduler=scheduler,
    )

    with contextlib.ExitStack() as cleanup:
        # Callbacks run in reverse registration order, i.e. gateway (once
        # created) -> training -> inference -> scheduler, and each one runs
        # even if a previous one raised.
        cleanup.callback(scheduler.delete_workers, None)
        cleanup.callback(inf_ctrl.destroy)
        cleanup.callback(train_ctrl.destroy)

        # -- 1. SGLang IW (RDT backend) --
        inf_ctrl.initialize(
            role="rollout",
            server_args={"model_path": model_path, "mem_fraction_static": 0.7},
            wait=True,
        )
        inf_worker_urls = list(inf_ctrl._inf_addrs)

        # Randomize IW weights
        for url in inf_worker_urls:
            resp = httpx.post(f"{url}/rdt/debug/randomize_parameters", timeout=120.0)
            assert resp.status_code == 200, f"randomize_parameters failed: {resp.text}"

        # -- 2. FSDP TW --
        ft_spec = FinetuneSpec(
            total_train_epochs=1, dataset_size=100, train_batch_size=2
        )
        train_ctrl.initialize(role="actor", ft_spec=ft_spec, wait=True)
        train_worker_urls = list(train_ctrl._worker_addrs)

        # -- 3. Weight update gateway --
        wu_ctrl = WeightUpdateController(
            config=WeightUpdateControllerConfig(
                host="127.0.0.1",
                request_timeout=300.0,
            )
        )
        cleanup.callback(wu_ctrl.destroy)
        wu_ctrl.initialize()

        # -- 4. RDT weight update lifecycle --
        assert wu_ctrl.health_check(), "Weight update gateway health check failed"

        wu_ctrl.connect(
            pair_name="test_rdt_e2e",
            train_worker_urls=train_worker_urls,
            inference_worker_urls=inf_worker_urls,
            mode="rdt",
        )

        result = wu_ctrl.update_weights(version=1)
        assert result.status == "ok"
        assert result.version == 1

        wu_ctrl.disconnect()

        # -- 5. Verify IW generation works --
        gen_resp = httpx.post(
            f"{inf_worker_urls[0]}/generate",
            json={
                "text": "Hello",
                "sampling_params": {"max_new_tokens": 5, "temperature": 0},
            },
            timeout=30.0,
        )
        assert gen_resp.status_code == 200, (
            f"Generation failed after weight update: {gen_resp.text}"
        )

        # -- 6. Validate TW ↔ IW parameter equality --
        _validate_weight_update_correctness(
            train_worker_urls=train_worker_urls,
            inf_worker_url=inf_worker_urls[0],
            param_dir=tmp,
        )
def _run_megatron_rdt_e2e(
    *,
    n_gpus: int,
    backend: str,
    pair_name: str,
    tag: str,
    tmp_path_factory,
    model_path: str | None = None,
    validate_param_names: list[str] | None = None,
    init_from_scratch: bool = False,
):
    """Run MegatronEngine E2E weight update test with RDT backend.

    Half of the GPUs (``n_gpus // 2``) host SGLang inference workers; the
    training side is described by *backend*. After one ``update_weights``
    call the inference server must still generate and the sampled parameters
    must match the training side bitwise.

    Teardown is registered on an ``ExitStack`` so that every controller and
    the scheduler are released even if an earlier teardown step raises.

    Args:
        n_gpus: Total number of GPU devices handed to the local scheduler.
        backend: Training backend spec passed to ``TrainEngineConfig``.
        pair_name: Pair name used when connecting through the gateway.
        tag: Label used for temp directories, experiment name and file names.
        tmp_path_factory: pytest ``tmp_path_factory`` fixture.
        model_path: Model to load; defaults to ``_get_test_model_path()``.
        validate_param_names: Parameter names to compare; forwarded to the
            validation helper, which applies its own default when ``None``.
        init_from_scratch: Forwarded to ``TrainEngineConfig``.
    """
    import contextlib

    import httpx

    from areal.api import FinetuneSpec
    from areal.api.cli_args import (
        InferenceEngineConfig,
        OptimizerConfig,
        SchedulingSpec,
        TrainEngineConfig,
    )
    from areal.experimental.inference_service.controller.controller import (
        RolloutControllerV2,
    )
    from areal.experimental.training_service.controller.controller import (
        GatewayTrainController,
    )
    from areal.experimental.weight_update.controller import (
        WeightUpdateController,
        WeightUpdateControllerConfig,
    )

    n_infer = n_gpus // 2
    tmp = tmp_path_factory.mktemp(tag)
    model_path = model_path or _get_test_model_path()
    scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus)))

    # IW: SGLang with RDT backend
    inf_config = InferenceEngineConfig(
        tokenizer_path=model_path,
        backend=f"sglang:d{n_infer}",
        scheduling_spec=(
            SchedulingSpec(
                gpu=1,
                cmd="python -m areal.experimental.inference_service.guard",
                env_vars={"AREAL_WEIGHT_UPDATE_BACKEND": "rdt"},
            ),
        ),
        consumer_batch_size=8,
        max_head_offpolicyness=1024,
        setup_timeout=300.0,
        admin_api_key="test-admin",
    )
    inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler)

    # TW: Megatron with RDT backend
    train_config = TrainEngineConfig(
        backend=backend,
        experiment_name=f"test-rdt-{tag}",
        trial_name="t0",
        path=model_path,
        init_from_scratch=init_from_scratch,
        optimizer=OptimizerConfig(),
        _version="v2",
        setup_timeout=300.0,
        scheduling_spec=(
            SchedulingSpec(
                gpu=1,
                cmd="python -m areal.experimental.training_service.guard",
                env_vars=dict(
                    NCCL_CUMEM_ENABLE="0",
                    NCCL_NVLS_ENABLE="0",
                    AREAL_WEIGHT_UPDATE_BACKEND="rdt",
                ),
            ),
        ),
    )
    train_ctrl = GatewayTrainController(
        train_engine="areal.engine.megatron_engine.MegatronLMEngine",
        config=train_config,
        scheduler=scheduler,
    )

    with contextlib.ExitStack() as cleanup:
        # Callbacks run in reverse registration order, i.e. gateway (once
        # created) -> training -> inference -> scheduler, and each one runs
        # even if a previous one raised.
        cleanup.callback(scheduler.delete_workers, None)
        cleanup.callback(inf_ctrl.destroy)
        cleanup.callback(train_ctrl.destroy)

        inf_ctrl.initialize(
            role="rollout",
            server_args={"model_path": model_path, "mem_fraction_static": 0.7},
            wait=True,
        )
        inf_worker_urls = list(inf_ctrl._inf_addrs)

        # Randomize IW weights (RDT endpoint)
        for url in inf_worker_urls:
            resp = httpx.post(f"{url}/rdt/debug/randomize_parameters", timeout=120.0)
            assert resp.status_code == 200, f"randomize_parameters failed: {resp.text}"

        train_ctrl.initialize(
            role="actor",
            ft_spec=FinetuneSpec(
                total_train_epochs=1, dataset_size=100, train_batch_size=2
            ),
            wait=True,
        )
        train_worker_urls = list(train_ctrl._worker_addrs)

        wu_ctrl = WeightUpdateController(
            config=WeightUpdateControllerConfig(host="127.0.0.1", request_timeout=300.0)
        )
        cleanup.callback(wu_ctrl.destroy)
        wu_ctrl.initialize()
        assert wu_ctrl.health_check(), "Weight update gateway health check failed"

        wu_ctrl.connect(
            pair_name=pair_name,
            train_worker_urls=train_worker_urls,
            inference_worker_urls=inf_worker_urls,
            mode="rdt",
        )
        result = wu_ctrl.update_weights(version=1)
        assert result.status == "ok"
        assert result.version == 1
        wu_ctrl.disconnect()

        gen_resp = httpx.post(
            f"{inf_worker_urls[0]}/generate",
            json={
                "text": "Hello",
                "sampling_params": {"max_new_tokens": 5, "temperature": 0},
            },
            timeout=30.0,
        )
        assert gen_resp.status_code == 200, (
            f"Generation failed after weight update: {gen_resp.text}"
        )

        _validate_weight_update_correctness_megatron_rdt(
            train_worker_urls=train_worker_urls,
            inf_worker_url=inf_worker_urls[0],
            param_dir=tmp,
            tag=tag,
            param_names=validate_param_names,
        )
@pytest.mark.multi_gpu
@pytest.mark.slow
@pytest.mark.sglang
@pytest.mark.parametrize(
    "n_gpus,tp_size",
    [(4, 2), (8, 2), (8, 4)],
    ids=["4gpu-dp1tp2", "8gpu-dp2tp2", "8gpu-dp1tp4"],
)
def test_rdt_megatron_dp_tp_e2e_weight_update(n_gpus, tp_size, tmp_path_factory):
    """E2E weight update: MegatronEngine (DP+TP) → gateway → SGLang (RDT mode).

    The GPUs not used for inference (``n_gpus - n_gpus // 2``) are split into
    ``tp_size``-wide tensor-parallel groups; the resulting data-parallel
    degree selects the ``megatron:d{dp}t{tp}`` backend. The case is skipped
    when there are too few devices or the data-parallel degree would be 0.

    Requires Ray cluster with NIXL tensor transport support.
    """
    available = current_platform.device_count()
    if available < n_gpus:
        pytest.skip(f"This test requires {n_gpus} GPUs")

    import ray

    if not ray.is_initialized():
        ray.init(address="auto", ignore_reinit_error=True)

    train_gpus = n_gpus - n_gpus // 2
    dp_size = train_gpus // tp_size
    if dp_size < 1:
        pytest.skip(f"Not enough GPUs for dp={dp_size} tp={tp_size}")

    layout = f"dp{dp_size}tp{tp_size}"
    _run_megatron_rdt_e2e(
        n_gpus=n_gpus,
        backend=f"megatron:d{dp_size}t{tp_size}",
        pair_name=f"test_rdt_megatron_{layout}",
        tag=f"megatron_{layout}",
        tmp_path_factory=tmp_path_factory,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
@pytest.mark.sglang
@pytest.mark.parametrize(
    "n_gpus,dp_size,pp_size",
    [(8, 2, 2)],
    ids=["8gpu-dp2pp2"],
)
def test_rdt_megatron_dp_pp_e2e_weight_update(
    n_gpus, dp_size, pp_size, tmp_path_factory
):
    """E2E weight update: MegatronEngine (DP+PP) → gateway → SGLang (RDT mode).

    Runs the shared Megatron RDT scenario with the
    ``megatron:d{dp_size}p{pp_size}`` backend; skipped when fewer than
    ``n_gpus`` devices are available.

    Requires Ray cluster with NIXL tensor transport support.
    """
    have = current_platform.device_count()
    if have < n_gpus:
        pytest.skip(f"This test requires {n_gpus} GPUs")

    import ray

    if not ray.is_initialized():
        ray.init(address="auto", ignore_reinit_error=True)

    layout = f"dp{dp_size}pp{pp_size}"
    _run_megatron_rdt_e2e(
        n_gpus=n_gpus,
        backend=f"megatron:d{dp_size}p{pp_size}",
        pair_name=f"test_rdt_megatron_{layout}",
        tag=f"megatron_{layout}",
        tmp_path_factory=tmp_path_factory,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
@pytest.mark.sglang
@pytest.mark.parametrize(
    "n_gpus,cp_size",
    [(4, 2), (8, 4)],
    ids=["4gpu-dp1cp2", "8gpu-dp1cp4"],
)
def test_rdt_megatron_cp_e2e_weight_update(n_gpus, cp_size, tmp_path_factory):
    """E2E weight update: MegatronEngine (pure CP) → gateway → SGLang (RDT mode).

    Runs the shared Megatron RDT scenario with the ``megatron:d1c{cp_size}``
    backend; skipped when fewer than ``n_gpus`` devices are available.

    Requires Ray cluster with NIXL tensor transport support.
    """
    have = current_platform.device_count()
    if have < n_gpus:
        pytest.skip(f"This test requires {n_gpus} GPUs")

    import ray

    if not ray.is_initialized():
        ray.init(address="auto", ignore_reinit_error=True)

    layout = f"cp{cp_size}"
    _run_megatron_rdt_e2e(
        n_gpus=n_gpus,
        backend=f"megatron:d1c{cp_size}",
        pair_name=f"test_rdt_megatron_{layout}",
        tag=f"megatron_{layout}",
        tmp_path_factory=tmp_path_factory,
    )
@pytest.mark.multi_gpu
@pytest.mark.slow
@pytest.mark.sglang
@pytest.mark.parametrize(
    "n_gpus,ep_size",
    [(4, 2), (8, 4)],
    ids=["4gpu-dp2ep2", "8gpu-dp4ep4"],
)
def test_rdt_megatron_ep_e2e_weight_update(n_gpus, ep_size, tmp_path_factory):
    """E2E weight update: MegatronEngine (EP) → gateway → SGLang (RDT mode).

    Builds a 4-layer truncated copy of the MoE test model, then runs the
    shared Megatron RDT scenario with ``init_from_scratch=True``, the MoE
    validation parameter list, and the ``megatron:d{n_train}e{ep_size}``
    backend, where ``n_train`` is the number of GPUs left after reserving
    ``n_gpus // 2`` for inference.

    Requires Ray cluster with NIXL tensor transport support.
    """
    have = current_platform.device_count()
    if have < n_gpus:
        pytest.skip(f"This test requires {n_gpus} GPUs")

    import ray

    if not ray.is_initialized():
        ray.init(address="auto", ignore_reinit_error=True)

    n_train = n_gpus - n_gpus // 2
    layout = f"ep{ep_size}"
    model_dir = tmp_path_factory.mktemp(f"megatron_{layout}_model")
    truncated_model = _make_truncated_moe_model(model_dir, num_layers=4)
    _run_megatron_rdt_e2e(
        n_gpus=n_gpus,
        backend=f"megatron:d{n_train}e{ep_size}",
        pair_name=f"test_rdt_megatron_{layout}",
        tag=f"megatron_{layout}",
        tmp_path_factory=tmp_path_factory,
        model_path=truncated_model,
        validate_param_names=_VALIDATE_PARAM_NAMES_MOE,
        init_from_scratch=True,
    )
+ """ + if current_platform.device_count() < n_gpus: + pytest.skip(f"This test requires {n_gpus} GPUs") + + import ray + + if not ray.is_initialized(): + ray.init(address="auto", ignore_reinit_error=True) + + tmp = tmp_path_factory.mktemp(f"megatron_dp{dp_size}ep{ep_size}_model") + _run_megatron_rdt_e2e( + n_gpus=n_gpus, + backend=f"megatron:d{dp_size}e{ep_size}", + pair_name=f"test_rdt_megatron_dp{dp_size}ep{ep_size}", + tag=f"megatron_dp{dp_size}ep{ep_size}", + tmp_path_factory=tmp_path_factory, + model_path=_make_truncated_moe_model(tmp, num_layers=4), + validate_param_names=_VALIDATE_PARAM_NAMES_MOE, + init_from_scratch=True, + ) diff --git a/tests/experimental/weight_update/test_sglang_integration.py b/tests/experimental/weight_update/test_sglang_integration.py index 0cf1b19e92..8f3999c3c3 100644 --- a/tests/experimental/weight_update/test_sglang_integration.py +++ b/tests/experimental/weight_update/test_sglang_integration.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -"""Integration tests for AReaL's custom SGLang server with /awex/* endpoints. +"""Integration tests for AReaL's custom SGLang server with /awex/* and /rdt/* endpoints. Requires GPU. Marked @pytest.mark.slow and @pytest.mark.sglang to exclude from default CI. Run manually: - uv run pytest tests/experimental/weight_update/test_sglang_server_integration.py -v -s + uv run pytest tests/experimental/weight_update/test_sglang_integration.py -v -s """ from __future__ import annotations @@ -40,7 +40,7 @@ def _get_test_model_path() -> str: @pytest.fixture(scope="module") def sglang_server(): - """Launch AReaL's custom SGLang server and yield (base_url, process). + """Launch AReaL's custom SGLang server with AWEX backend. The server is launched as a subprocess using the custom entry point that registers /awex/* weight update endpoints. 
@@ -48,6 +48,9 @@ def sglang_server(): port = find_free_ports(1)[0] model_path = _get_test_model_path() + env = os.environ.copy() + env["AREAL_WEIGHT_UPDATE_BACKEND"] = "awex" + process = subprocess.Popen( [ "python", @@ -68,6 +71,7 @@ def sglang_server(): ], stdout=sys.stdout, stderr=sys.stdout, + env=env, ) base_url = f"http://127.0.0.1:{port}" @@ -107,15 +111,81 @@ def sglang_server(): process.wait(timeout=10) -class TestSGLangServerHealth: - def test_server_healthy(self, sglang_server): - """Verify /health returns 200 after startup.""" - base_url, _ = sglang_server - resp = httpx.get(f"{base_url}/health", timeout=10.0) - assert resp.status_code == 200 +@pytest.fixture(scope="module") +def sglang_rdt_server(): + """Launch AReaL's custom SGLang server with RDT backend. + + Sets AREAL_WEIGHT_UPDATE_BACKEND=rdt to register /rdt/* endpoints. + """ + port = find_free_ports(1)[0] + model_path = _get_test_model_path() + + env = os.environ.copy() + env["AREAL_WEIGHT_UPDATE_BACKEND"] = "rdt" + + process = subprocess.Popen( + [ + "python", + "-m", + "areal.experimental.inference_service.sglang.launch_server", + "--model-path", + model_path, + "--port", + str(port), + "--host", + "127.0.0.1", + "--tp-size", + "1", + "--mem-fraction-static", + "0.7", + "--log-level", + "warning", + ], + stdout=sys.stdout, + stderr=sys.stdout, + env=env, + ) + + base_url = f"http://127.0.0.1:{port}" + + deadline = time.monotonic() + SERVER_STARTUP_TIMEOUT + healthy = False + while time.monotonic() < deadline: + try: + resp = httpx.get(f"{base_url}/health", timeout=5.0) + if resp.status_code == 200: + healthy = True + break + except (httpx.ConnectError, httpx.ReadTimeout): + pass + + if process.poll() is not None: + stdout = process.stdout.read().decode() if process.stdout else "" + stderr = process.stderr.read().decode() if process.stderr else "" + pytest.fail( + f"Server process exited prematurely (code {process.returncode}).\n" + f"stdout: {stdout[-2000:]}\nstderr: {stderr[-2000:]}" + ) + 
time.sleep(2.0) + + if not healthy: + process.kill() + process.wait(timeout=10) + pytest.fail(f"Server failed to become healthy within {SERVER_STARTUP_TIMEOUT}s") + + yield base_url, process + + os.kill(process.pid, signal.SIGTERM) + try: + process.wait(timeout=30) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=10) class TestAwexEndpointsRegistered: + """Test AWEX weight update endpoints on SGLang server.""" + def test_report_parallelism_endpoint_exists(self, sglang_server): """GET /awex/report_parallelism should return parallelism info.""" base_url, _ = sglang_server @@ -141,3 +211,33 @@ def test_report_parallelism_returns_valid_parallelism(self, sglang_server): data = resp.json() assert isinstance(data["world_size"], int) assert data["world_size"] >= 1 + + +class TestRdtEndpointsRegistered: + """Test RDT weight update endpoints on SGLang server.""" + + def test_report_parallelism_endpoint_exists(self, sglang_rdt_server): + """GET /rdt/report_parallelism should return parallelism info.""" + base_url, _ = sglang_rdt_server + resp = httpx.get(f"{base_url}/rdt/report_parallelism", timeout=30.0) + assert resp.status_code == 200 + data = resp.json() + assert "world_size" in data + + def test_report_weight_meta_endpoint_exists(self, sglang_rdt_server): + """POST /rdt/report_weight_meta should return weight metadata.""" + base_url, _ = sglang_rdt_server + resp = httpx.post(f"{base_url}/rdt/report_weight_meta", timeout=60.0) + assert resp.status_code == 200 + data = resp.json() + assert data.get("status") == "ok" + assert "meta" in data + + def test_report_parallelism_returns_valid_parallelism(self, sglang_rdt_server): + """world_size must be a positive integer (>= 1).""" + base_url, _ = sglang_rdt_server + resp = httpx.get(f"{base_url}/rdt/report_parallelism", timeout=30.0) + assert resp.status_code == 200 + data = resp.json() + assert isinstance(data["world_size"], int) + assert data["world_size"] >= 1 diff --git 
a/tests/experimental/weight_update/test_wu_controller.py b/tests/experimental/weight_update/test_wu_controller.py index 302d9356b7..06db4b1179 100644 --- a/tests/experimental/weight_update/test_wu_controller.py +++ b/tests/experimental/weight_update/test_wu_controller.py @@ -114,6 +114,27 @@ def test_connect_disk_mode_sends_disk_fields(self, ctrl): timeout=10.0, ) + def test_connect_rdt_mode_sends_rdt_mode(self, ctrl): + ctrl._session.post.return_value = _mock_response(200, {"pair_name": "pair0"}) + train_urls = ["http://train1:8000"] + infer_urls = ["http://infer1:8000"] + + ctrl.connect("pair0", train_urls, infer_urls, mode="rdt") + + ctrl._session.post.assert_called_once_with( + f"{GATEWAY_URL}/connect", + json={ + "pair_name": "pair0", + "train_worker_urls": train_urls, + "inference_worker_urls": infer_urls, + "mode": "rdt", + "save_path": "", + "use_lora": False, + "lora_name": "", + }, + timeout=10.0, + ) + class TestUpdateWeights: def test_update_weights_returns_result(self, ctrl): @@ -163,7 +184,8 @@ def test_disconnect_noop_when_not_connected(self, ctrl): class TestLifecycle: - def test_full_lifecycle(self, ctrl): + @pytest.mark.parametrize("mode", ["awex", "disk", "rdt"]) + def test_full_lifecycle(self, ctrl, mode): connect_resp = _mock_response(200, {"pair_name": "pair0"}) update_resp = _mock_response( 200, {"status": "ok", "version": 1, "duration_ms": 50.0, "error": None} @@ -171,7 +193,16 @@ def test_full_lifecycle(self, ctrl): disconnect_resp = _mock_response(200, {"status": "ok", "pair_name": "pair0"}) ctrl._session.post.side_effect = [connect_resp, update_resp, disconnect_resp] - ctrl.connect("pair0", ["http://t:8000"], ["http://i:8000"]) + if mode == "disk": + ctrl.connect( + "pair0", + ["http://t:8000"], + ["http://i:8000"], + mode=mode, + save_path="/tmp/w", + ) + else: + ctrl.connect("pair0", ["http://t:8000"], ["http://i:8000"], mode=mode) assert ctrl._pair_name == "pair0" result = ctrl.update_weights(version=1) diff --git 
a/tests/experimental/weight_update/torchrun/run_rdt_weight_transfer.py b/tests/experimental/weight_update/torchrun/run_rdt_weight_transfer.py new file mode 100644 index 0000000000..5b4424c13a --- /dev/null +++ b/tests/experimental/weight_update/torchrun/run_rdt_weight_transfer.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import os +import sys + +import torch +import torch.distributed as dist + +_PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..", "..") +) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from tests.experimental.weight_update.torchrun.dist_utils import ( # noqa: E402 + print_rank0, + write_result, +) + +from areal.infra.platforms import current_platform # noqa: E402 + +# Skip YR tests - only test NIXL (CUDA GPU) +# YR requires ray_ascend which may not be available in test environment +assert current_platform.device_type == "cuda", "RDT tests require CUDA GPU (NIXL)" + + +def run_rdt_weight_transfer_lifecycle(output=None): + """Test: Full RDT weight transfer lifecycle with real WeightTransportActor + CUDA IPC. + + This test validates the complete RDT flow: + 1. TW creates WeightTransportActor with GPU binding + 2. TW creates tensors, uses CUDA IPC (share_memory_ + reduce_tensor) + 3. TW calls actor.store_ipc_handles.remote() + 4. IW receives TW actor handle, calls get_weights_tensor_nixl.remote() + 5. IW verifies transferred weights + 6. 
IW calls clear_ipc_handles.remote() to cleanup + """ + rank = dist.get_rank() + world_size = dist.get_world_size() + + print_rank0( + "=== RDT Weight Transfer Lifecycle Test (Real WeightTransportActor) ===" + ) + + infer_world_size = world_size // 2 + is_inference = rank < infer_world_size + + # All processes use cuda:0 + device = torch.device("cuda:0") + + print_rank0( + f" Inference ranks: 0..{infer_world_size - 1}, " + f"Training ranks: {infer_world_size}..{world_size - 1}" + ) + + try: + import ray + from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + from torch.multiprocessing.reductions import reduce_tensor + + if not ray.is_initialized(): + ray.init(address="auto", ignore_reinit_error=True) + + from areal.experimental.weight_update.rdt import ( + deserialize_actor_handle_bytes, + serialize_actor_handle_bytes, + ) + from areal.experimental.weight_update.rdt.weight_transport_actor import ( + WeightTransportActor, + ) + + param_shapes = [(512, 256), (256,), (1024, 512), (2048, 1024)] + pair_name = "lifecycle_pair" + version = 1 + + # Phase 1: TW creates WeightTransportActor and distributes handle + tw_handles = {} # IW will store these + + if not is_inference: + # Training side: create WeightTransportActor + tw_rank = rank - infer_world_size + infer_rank = tw_rank # Corresponding IW rank + + current_node_id = ray.get_runtime_context().get_node_id() + current_visible_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + + print( + f"[TW rank {rank}] CUDA_VISIBLE_DEVICES: '{current_visible_gpus}'", + flush=True, + ) + + tw_actor = WeightTransportActor.options( + name=f"tw-actor-{tw_rank}", + num_gpus=0.0001, + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=current_node_id, soft=False + ), + runtime_env={ + "env_vars": {"CUDA_VISIBLE_DEVICES": current_visible_gpus} + }, + ).remote() + + # Broadcast handle to all inference ranks + encoded = serialize_actor_handle_bytes(tw_actor) + + for iw_rank in range(infer_world_size): + 
length_tensor = torch.tensor([len(encoded)], dtype=torch.long) + dist.send(length_tensor, dst=iw_rank) + handle_tensor = torch.tensor( + [ord(c) for c in encoded], dtype=torch.long + ) + dist.send(handle_tensor, dst=iw_rank) + + print_rank0(f" TW rank {rank}: Distributed handle to all IW ranks") + + if is_inference: + # Inference side: receive handles from all TW ranks + for tw_idx in range(infer_world_size): + tw_global_rank = tw_idx + infer_world_size + + length_tensor = torch.zeros(1, dtype=torch.long) + dist.recv(length_tensor, src=tw_global_rank) + handle_length = int(length_tensor.item()) + + handle_tensor = torch.zeros(handle_length, dtype=torch.long) + dist.recv(handle_tensor, src=tw_global_rank) + encoded = "".join([chr(int(c.item())) for c in handle_tensor]) + + tw_handle = deserialize_actor_handle_bytes(encoded) + tw_handles[tw_idx] = tw_handle + + print_rank0(f" IW rank {rank}: Received {len(tw_handles)} TW handles") + + dist.barrier() + + # Phase 2: TW creates tensors and stores IPC handles + if not is_inference: + tw_rank = rank - infer_world_size + infer_rank = tw_rank + + # Create weights on cuda:0 + torch.manual_seed(100 + tw_rank) + params = { + f"model.layers.{i}.weight": torch.randn(shape, device=device) + for i, shape in enumerate(param_shapes) + } + params["model.norm.weight"] = torch.randn(param_shapes[1][0], device=device) + + # Create IPC handles via share_memory_() + reduce_tensor() + ipc_handles = {} + for name, tensor in params.items(): + tensor.share_memory_() + rebuild_fn, tensor_meta = reduce_tensor(tensor) + ipc_handles[name] = { + "rebuild_fn": rebuild_fn, + "tensor_meta": tensor_meta, + } + + # Store IPC handles in actor + ray.get( + tw_actor.store_ipc_handles.remote( + pair_name, infer_rank, version, ipc_handles + ) + ) + + print_rank0( + f" TW rank {rank}: Stored IPC handles for infer_rank {infer_rank}" + ) + + dist.barrier() + + # Phase 3: IW pulls weights from TW via Ray RPC (NIXL transport) + if is_inference: + tw_idx = rank + 
infer_rank = tw_idx + + if tw_idx in tw_handles: + tw_handle = tw_handles[tw_idx] + + # IW pulls weights via tensor_transport + received_params = ray.get( + tw_handle.get_weights_tensor_nixl.remote( + pair_name, infer_rank, version + ) + ) + + print_rank0(f" IW rank {rank}: Pulled weights via Ray RPC") + + # Phase 4: Verify transferred weights + torch.manual_seed(100 + tw_idx) + expected_params = { + f"model.layers.{i}.weight": torch.randn(shape, device=device) + for i, shape in enumerate(param_shapes) + } + expected_params["model.norm.weight"] = torch.randn( + param_shapes[1][0], device=device + ) + + verify_success = True + for name in expected_params: + expected = expected_params[name] + actual = received_params[name] + try: + torch.testing.assert_close( + actual, expected, rtol=1e-5, atol=1e-5 + ) + except AssertionError: + max_diff = (actual - expected).abs().max().item() + print_rank0(f" MISMATCH {name}: max_diff={max_diff}") + verify_success = False + + print_rank0( + f" IW rank {rank}: Weight verification " + f"{'PASSED' if verify_success else 'FAILED'}" + ) + + # Phase 5: Cleanup IPC handles + ray.get( + tw_handle.clear_ipc_handles.remote(pair_name, infer_rank, version) + ) + print_rank0(f" IW rank {rank}: Cleaned up IPC handles") + + print_rank0(" RDT weight transfer lifecycle: PASSED") + success = True + + except Exception as e: + print_rank0(f" FAILED: {e}") + import traceback + + traceback.print_exc() + success = False + + dist.barrier() + if rank == 0 and output: + write_result(output, success) + return success + + +TEST_REGISTRY = { + "rdt_weight_transfer_lifecycle": run_rdt_weight_transfer_lifecycle, +} + + +def main(): + parser = argparse.ArgumentParser(description="RDT Weight Transfer Tests") + parser.add_argument( + "--test_type", + type=str, + required=True, + choices=list(TEST_REGISTRY.keys()), + ) + parser.add_argument("--output", type=str, default=None) + args = parser.parse_args() + + # Modify CUDA_VISIBLE_DEVICES BEFORE CUDA context 
initialization + # Each process gets its own GPU (IW and TW on different GPUs) + rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0"))) + + all_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + gpu_list = all_gpus.split(",") + + # Each process has its own GPU (IW and TW on different GPUs) + gpu_index = rank + my_gpu = gpu_list[gpu_index] if gpu_index < len(gpu_list) else gpu_list[0] + os.environ["CUDA_VISIBLE_DEVICES"] = my_gpu + + dist.init_process_group(backend="gloo") + torch.cuda.set_device(0) # cuda:0 = my_gpu + + rank = dist.get_rank() + + print_rank0("=" * 60) + print_rank0(f"Running: {args.test_type}") + print_rank0("=" * 60) + + try: + test_fn = TEST_REGISTRY[args.test_type] + success = test_fn(args.output) + + dist.barrier() + if success: + print_rank0(f"\n{args.test_type}: PASSED") + else: + print_rank0(f"\n{args.test_type}: FAILED") + if rank == 0 and args.output: + write_result(args.output, False) + except Exception as e: + print(f"Rank {rank} failed: {e}") + import traceback + + traceback.print_exc() + if rank == 0 and args.output: + write_result(args.output, False) + raise + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()