-
Notifications
You must be signed in to change notification settings - Fork 508
feat:enable v2 training pipeline with controller parity #1363
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
sitabulaixizawaluduo
merged 4 commits into
main
from
feat/training-controller-v2-parity
May 29, 2026
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
2d66f6a
feat: enable v2 training pipeline with controller parity
sitabulaixizawaluduo 2541bbd
fix: update wu controller connect method
sitabulaixizawaluduo eca4c31
chore: unblock CI for grpo and grpo_lora with admin key + lora name
sitabulaixizawaluduo d7a834c
chore: unblock CI for v2 parity
sitabulaixizawaluduo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| import threading | ||
| import time | ||
| import traceback | ||
| from threading import Lock | ||
| from typing import TYPE_CHECKING, Any | ||
| from uuid import uuid4 | ||
|
|
||
|
|
@@ -30,11 +31,6 @@ | |
| class GatewayTrainController: | ||
| _GUARD_SUFFIX = "-guard" | ||
|
|
||
| # TODO(agent): Controller v2 is not yet a drop-in replacement for | ||
| # TrainController on PPO/GRPO paths. Add parity for connect_engine, | ||
| # prepare_batch/rollout_batch, and update_weights (plus the matching | ||
| # gateway/data-proxy/worker endpoints), or keep RL controllers on v1. | ||
|
|
||
| def __init__( | ||
| self, | ||
| train_engine: type[TrainEngine] | str, | ||
|
|
@@ -52,11 +48,20 @@ def __init__( | |
| self._router_addr: str = "" | ||
| self._model_addr: str = "" | ||
| self._worker_addrs: list[str] = [] | ||
| self._guard_addrs: list[str] = [] | ||
| self._forked_services: list[tuple[str, str, int]] = [] | ||
| self._service_roles: list[str] = [] | ||
| self._role: str = "" | ||
| self._parallel_strategy = self.train_alloc.parallel | ||
| self._own_process_group = False | ||
| self.rollout: Any | None = None | ||
| self._weight_update_ctrl: Any | None = None | ||
|
|
||
| # Version management | ||
| self._version_lock = Lock() | ||
| self._version = 0 | ||
|
|
||
| # Shared HTTP client (lazy, per-event-loop) | ||
| self._async_client: Any | None = None | ||
| self._async_client_loop: asyncio.AbstractEventLoop | None = None | ||
|
|
||
|
|
@@ -205,6 +210,15 @@ async def _async_initialize( | |
| guard_addr_0 = f"http://{format_hostport(guard_workers[0].ip, int(guard_workers[0].worker_ports[0]))}" | ||
| master_addr = guard_workers[0].ip | ||
|
|
||
| # Persist guard addresses so connect_engine() can allocate | ||
| # ports later (e.g. for the weight-update NCCL group). | ||
| def _guard_addr(worker: Worker) -> str: | ||
| return ( | ||
| f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" | ||
| ) | ||
|
|
||
| self._guard_addrs = [_guard_addr(w) for w in guard_workers] | ||
|
|
||
| client = await self._get_async_client() | ||
| resp = await client.post( | ||
| f"{guard_addr_0}/alloc_ports", json={"count": 1}, timeout=30.0 | ||
|
|
@@ -215,10 +229,6 @@ async def _async_initialize( | |
| # ============================================================== | ||
| # Step 1.5: Set NCCL env on each guard so forked workers inherit it | ||
| # ============================================================== | ||
| def _guard_addr(worker: Worker) -> str: | ||
| return ( | ||
| f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}" | ||
| ) | ||
|
|
||
| await self._async_set_guards_env( | ||
| guard_workers, | ||
|
|
@@ -767,6 +777,9 @@ def eval(self) -> GatewayTrainController: | |
| def set_version(self, version: int) -> None: | ||
| from areal.infra.rpc.serialization import serialize_value | ||
|
|
||
| with self._version_lock: | ||
| self._version = version | ||
|
|
||
| self._gateway_post( | ||
| "/set_version", | ||
| { | ||
|
|
@@ -776,7 +789,8 @@ def set_version(self, version: int) -> None: | |
| ) | ||
|
|
||
| def get_version(self) -> int: | ||
| return int(self._gateway_get_result("/get_version")) | ||
| with self._version_lock: | ||
| return self._version | ||
|
|
||
| def save(self, meta: Any) -> None: | ||
| from areal.infra.rpc.serialization import serialize_value | ||
|
|
@@ -832,13 +846,22 @@ def get_device_stats(self) -> Any: | |
| return self._gateway_post_result("/get_device_stats", payload) | ||
|
|
||
| def config_perf_tracer(self, config: Any, role: str) -> None: | ||
| from areal.infra.rpc.serialization import serialize_value | ||
| self._ensure_initialized() | ||
|
|
||
| payload = { | ||
| "args": serialize_value([]), | ||
| "kwargs": serialize_value({"config": config, "role": role}), | ||
| } | ||
| self._gateway_post("/config_perf_tracer", payload) | ||
| async def _call() -> None: | ||
| tasks = [ | ||
| self._call_worker_engine_endpoint( | ||
| addr, | ||
| "/config_perf_tracer", | ||
| args=[], | ||
| kwargs={"config": config, "rank": rank, "role": role}, | ||
| timeout=self.config.request_timeout, | ||
| ) | ||
| for rank, addr in enumerate(self._worker_addrs) | ||
| ] | ||
| await asyncio.gather(*tasks) | ||
|
|
||
| run_async_task(_call) | ||
|
|
||
| def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None: | ||
| from areal.infra.rpc.serialization import serialize_value | ||
|
|
@@ -850,10 +873,31 @@ def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None | |
| self._gateway_post("/save_perf_tracer", payload) | ||
|
|
||
| def clear_batches(self, *targets: Any) -> None: | ||
| from areal.infra.rpc.rtensor import RTensor, flatten_shard_ids | ||
| from areal.infra.rpc.serialization import serialize_value | ||
|
|
||
| # Step 1: HTTP DELETE to storage nodes to evict _storage entries | ||
| # (mirrors TrainController._async_clear_batches) | ||
| shards_by_node = RTensor.collect_shards(targets) | ||
| if shards_by_node: | ||
|
|
||
| async def _clear_storage(): | ||
| await asyncio.gather( | ||
| *[ | ||
| RTensor.clear_node(addr, sids) | ||
| for addr, sids in shards_by_node.items() | ||
| ], | ||
| return_exceptions=True, | ||
| ) | ||
|
|
||
| run_async_task(_clear_storage) | ||
|
|
||
| # Step 2: Drain _fetch_buffer on workers via engine.clear_batches(shard_ids) | ||
| shard_ids = flatten_shard_ids(targets) | ||
| if not shard_ids: | ||
| return | ||
| payload = { | ||
| "args": serialize_value(list(targets)), | ||
| "args": serialize_value([shard_ids]), | ||
| "kwargs": serialize_value({}), | ||
| } | ||
| self._gateway_post("/clear_batches", payload) | ||
|
|
@@ -883,6 +927,135 @@ def data_parallel_rank(self) -> int: | |
| def cpu_group(self): | ||
| return None | ||
|
|
||
| @property | ||
| def train_worker_urls(self) -> list[str]: | ||
| return list(self._worker_addrs) | ||
|
|
||
| # -- RL parity methods (connect_engine / update_weights / batch) -------- | ||
|
|
||
| def connect_engine(self, rollout: Any, meta: Any) -> None: | ||
| self._ensure_initialized() | ||
| import requests | ||
|
|
||
| from areal.experimental.inference_service.controller.controller import ( | ||
| RolloutControllerV2, | ||
| ) | ||
| from areal.experimental.weight_update.controller.config import ( | ||
| WeightUpdateControllerConfig, | ||
| ) | ||
| from areal.experimental.weight_update.controller.controller import ( | ||
| WeightUpdateController, | ||
| ) | ||
|
|
||
| if not isinstance(rollout, RolloutControllerV2): | ||
| raise TypeError( | ||
| f"GatewayTrainController requires RolloutControllerV2, " | ||
| f"got {type(rollout).__name__}. " | ||
| f"Ensure _version='v2' is set on InferenceEngineConfig." | ||
| ) | ||
|
|
||
| self.rollout = rollout | ||
|
|
||
| if meta.type != "awex": | ||
| raise ValueError( | ||
| f"GatewayTrainController only supports 'awex' weight updates, got '{meta.type}'" | ||
| ) | ||
|
|
||
| ctrl = WeightUpdateController( | ||
| WeightUpdateControllerConfig( | ||
| admin_api_key=self.config.admin_api_key, | ||
| log_level=self.config.log_level, | ||
| ) | ||
| ) | ||
| ctrl.initialize() | ||
|
|
||
| inference_urls: list[str] = rollout.inference_worker_urls | ||
|
|
||
| nccl_master_addr = "" | ||
| nccl_master_port = 0 | ||
| if self._guard_addrs: | ||
| resp = requests.post( | ||
| f"{self._guard_addrs[0]}/alloc_ports", | ||
| json={"count": 1}, | ||
| timeout=30, | ||
| ) | ||
| resp.raise_for_status() | ||
| port_data = resp.json() | ||
| nccl_master_addr = port_data["host"] | ||
| nccl_master_port = port_data["ports"][0] | ||
|
|
||
| pair_name = f"{self._role}-rollout" | ||
| ctrl.connect( | ||
| pair_name=pair_name, | ||
| train_worker_urls=self._worker_addrs, | ||
| inference_worker_urls=inference_urls, | ||
| nccl_master_addr=nccl_master_addr, | ||
| nccl_master_port=nccl_master_port, | ||
| ) | ||
| self._weight_update_ctrl = ctrl | ||
| logger.info( | ||
| "WeightUpdateController connected (pair=%s, train=%d, inf=%d)", | ||
| pair_name, | ||
| len(self._worker_addrs), | ||
| len(inference_urls), | ||
| ) | ||
|
|
||
| def update_weights(self, meta: Any) -> None: | ||
| if self._weight_update_ctrl is None or self.rollout is None: | ||
| raise RuntimeError( | ||
| "connect_engine() must be called before update_weights()" | ||
| ) | ||
| self.rollout.pause_generation() | ||
| assert meta.version is not None and meta.version > 0, ( | ||
| f"meta.version must be a positive integer, got {meta.version}" | ||
| ) | ||
| result = self._weight_update_ctrl.update_weights(version=meta.version) | ||
| self.rollout.continue_generation() | ||
|
Comment on lines
+1008
to
+1013
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This block has multiple issues:
if meta.version is None or meta.version <= 0:
raise ValueError(f"meta.version must be a positive integer, got {meta.version}")
result = self._weight_update_ctrl.update_weights(version=meta.version) |
||
| logger.info( | ||
| "Weight update v%d completed (%s, %.0fms)", | ||
| meta.version, | ||
| result.status, | ||
| result.duration_ms, | ||
| ) | ||
|
|
||
| def prepare_batch( | ||
| self, | ||
| dataloader: Any, | ||
| workflow: Any, | ||
| workflow_kwargs: dict[str, Any], | ||
| should_accept_fn: str | None = None, | ||
| group_size: int = 1, | ||
| dynamic_bs: bool = False, | ||
| ) -> list[dict[str, Any]]: | ||
| if self.rollout is None: | ||
| raise RuntimeError("connect_engine() must be called before prepare_batch()") | ||
| return self.rollout.prepare_batch( | ||
| dataloader=dataloader, | ||
| workflow=workflow, | ||
| workflow_kwargs=workflow_kwargs, | ||
| should_accept_fn=should_accept_fn, | ||
| group_size=group_size, | ||
| dynamic_bs=dynamic_bs, | ||
| ) | ||
|
|
||
| def rollout_batch( | ||
| self, | ||
| data: list[dict[str, Any]], | ||
| workflow: Any, | ||
| workflow_kwargs: dict[str, Any], | ||
| should_accept_fn: str | None = None, | ||
| group_size: int = 1, | ||
| ) -> list[dict[str, Any]]: | ||
| if self.rollout is None: | ||
| raise RuntimeError("connect_engine() must be called before rollout_batch()") | ||
| return self.rollout.rollout_batch( | ||
| data=data, | ||
| workflow=workflow, | ||
| workflow_kwargs=workflow_kwargs, | ||
| should_accept_fn=should_accept_fn, | ||
| group_size=group_size, | ||
| ) | ||
|
|
||
| def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): | ||
| self._parallel_strategy = parallel_strategy | ||
| import torch.distributed as dist | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The attribute
inference_worker_urlsis not defined inRolloutControllerV2, which will cause anAttributeErrorat runtime. You should use the internal_inf_addrsattribute or add a public property toRolloutControllerV2to expose these URLs.