Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions areal/engine/sglang_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,13 @@ def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=None)

@classmethod
def as_controller(
cls, config: InferenceEngineConfig, scheduler: Scheduler
) -> RolloutController:
def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler):
if config._version == "v2":
from areal.experimental.inference_service.controller.controller import (
RolloutControllerV2,
)

return RolloutControllerV2(config=config, scheduler=scheduler)
return RolloutController(cls, config=config, scheduler=scheduler)

def clear_batches(self, shard_ids: list[str]) -> None:
Expand Down
10 changes: 7 additions & 3 deletions areal/engine/vllm_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,13 @@ def export_stats(self) -> dict[str, float]:
return stats_tracker.export_all(reduce_group=None)

@classmethod
def as_controller(
cls, config: InferenceEngineConfig, scheduler: Scheduler
) -> RolloutController:
def as_controller(cls, config: InferenceEngineConfig, scheduler: Scheduler):
if config._version == "v2":
from areal.experimental.inference_service.controller.controller import (
RolloutControllerV2,
)

return RolloutControllerV2(config=config, scheduler=scheduler)
return RolloutController(cls, config=config, scheduler=scheduler)

def clear_batches(self, shard_ids: list[str]) -> None:
Expand Down
207 changes: 190 additions & 17 deletions areal/experimental/training_service/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import threading
import time
import traceback
from threading import Lock
from typing import TYPE_CHECKING, Any
from uuid import uuid4

Expand All @@ -30,11 +31,6 @@
class GatewayTrainController:
_GUARD_SUFFIX = "-guard"

# TODO(agent): Controller v2 is not yet a drop-in replacement for
# TrainController on PPO/GRPO paths. Add parity for connect_engine,
# prepare_batch/rollout_batch, and update_weights (plus the matching
# gateway/data-proxy/worker endpoints), or keep RL controllers on v1.

def __init__(
self,
train_engine: type[TrainEngine] | str,
Expand All @@ -52,11 +48,20 @@ def __init__(
self._router_addr: str = ""
self._model_addr: str = ""
self._worker_addrs: list[str] = []
self._guard_addrs: list[str] = []
self._forked_services: list[tuple[str, str, int]] = []
self._service_roles: list[str] = []
self._role: str = ""
self._parallel_strategy = self.train_alloc.parallel
self._own_process_group = False
self.rollout: Any | None = None
self._weight_update_ctrl: Any | None = None

# Version management
self._version_lock = Lock()
self._version = 0

# Shared HTTP client (lazy, per-event-loop)
self._async_client: Any | None = None
self._async_client_loop: asyncio.AbstractEventLoop | None = None

Expand Down Expand Up @@ -205,6 +210,15 @@ async def _async_initialize(
guard_addr_0 = f"http://{format_hostport(guard_workers[0].ip, int(guard_workers[0].worker_ports[0]))}"
master_addr = guard_workers[0].ip

# Persist guard addresses so connect_engine() can allocate
# ports later (e.g. for the weight-update NCCL group).
def _guard_addr(worker: Worker) -> str:
return (
f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}"
)

self._guard_addrs = [_guard_addr(w) for w in guard_workers]

client = await self._get_async_client()
resp = await client.post(
f"{guard_addr_0}/alloc_ports", json={"count": 1}, timeout=30.0
Expand All @@ -215,10 +229,6 @@ async def _async_initialize(
# ==============================================================
# Step 1.5: Set NCCL env on each guard so forked workers inherit it
# ==============================================================
def _guard_addr(worker: Worker) -> str:
return (
f"http://{format_hostport(worker.ip, int(worker.worker_ports[0]))}"
)

await self._async_set_guards_env(
guard_workers,
Expand Down Expand Up @@ -767,6 +777,9 @@ def eval(self) -> GatewayTrainController:
def set_version(self, version: int) -> None:
from areal.infra.rpc.serialization import serialize_value

with self._version_lock:
self._version = version

self._gateway_post(
"/set_version",
{
Expand All @@ -776,7 +789,8 @@ def set_version(self, version: int) -> None:
)

def get_version(self) -> int:
return int(self._gateway_get_result("/get_version"))
with self._version_lock:
return self._version

def save(self, meta: Any) -> None:
from areal.infra.rpc.serialization import serialize_value
Expand Down Expand Up @@ -832,13 +846,22 @@ def get_device_stats(self) -> Any:
return self._gateway_post_result("/get_device_stats", payload)

def config_perf_tracer(self, config: Any, role: str) -> None:
from areal.infra.rpc.serialization import serialize_value
self._ensure_initialized()

payload = {
"args": serialize_value([]),
"kwargs": serialize_value({"config": config, "role": role}),
}
self._gateway_post("/config_perf_tracer", payload)
async def _call() -> None:
tasks = [
self._call_worker_engine_endpoint(
addr,
"/config_perf_tracer",
args=[],
kwargs={"config": config, "rank": rank, "role": role},
timeout=self.config.request_timeout,
)
for rank, addr in enumerate(self._worker_addrs)
]
await asyncio.gather(*tasks)

run_async_task(_call)

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
from areal.infra.rpc.serialization import serialize_value
Expand All @@ -850,10 +873,31 @@ def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None
self._gateway_post("/save_perf_tracer", payload)

def clear_batches(self, *targets: Any) -> None:
from areal.infra.rpc.rtensor import RTensor, flatten_shard_ids
from areal.infra.rpc.serialization import serialize_value

# Step 1: HTTP DELETE to storage nodes to evict _storage entries
# (mirrors TrainController._async_clear_batches)
shards_by_node = RTensor.collect_shards(targets)
if shards_by_node:

async def _clear_storage():
await asyncio.gather(
*[
RTensor.clear_node(addr, sids)
for addr, sids in shards_by_node.items()
],
return_exceptions=True,
)

run_async_task(_clear_storage)

# Step 2: Drain _fetch_buffer on workers via engine.clear_batches(shard_ids)
shard_ids = flatten_shard_ids(targets)
if not shard_ids:
return
payload = {
"args": serialize_value(list(targets)),
"args": serialize_value([shard_ids]),
"kwargs": serialize_value({}),
}
self._gateway_post("/clear_batches", payload)
Expand Down Expand Up @@ -883,6 +927,135 @@ def data_parallel_rank(self) -> int:
def cpu_group(self):
return None

@property
def train_worker_urls(self) -> list[str]:
    """Return a fresh list holding the current training worker addresses.

    A new list is built on every access, so mutating the returned list
    does not modify ``self._worker_addrs``.
    """
    return [*self._worker_addrs]

# -- RL parity methods (connect_engine / update_weights / batch) --------

def connect_engine(self, rollout: Any, meta: Any) -> None:
self._ensure_initialized()
import requests

from areal.experimental.inference_service.controller.controller import (
RolloutControllerV2,
)
from areal.experimental.weight_update.controller.config import (
WeightUpdateControllerConfig,
)
from areal.experimental.weight_update.controller.controller import (
WeightUpdateController,
)

if not isinstance(rollout, RolloutControllerV2):
raise TypeError(
f"GatewayTrainController requires RolloutControllerV2, "
f"got {type(rollout).__name__}. "
f"Ensure _version='v2' is set on InferenceEngineConfig."
)

self.rollout = rollout

if meta.type != "awex":
raise ValueError(
f"GatewayTrainController only supports 'awex' weight updates, got '{meta.type}'"
)

ctrl = WeightUpdateController(
WeightUpdateControllerConfig(
admin_api_key=self.config.admin_api_key,
log_level=self.config.log_level,
)
)
ctrl.initialize()

inference_urls: list[str] = rollout.inference_worker_urls
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The attribute inference_worker_urls is not defined in RolloutControllerV2, which will cause an AttributeError at runtime. You should use the internal _inf_addrs attribute or add a public property to RolloutControllerV2 to expose these URLs.

Suggested change
inference_urls: list[str] = rollout.inference_worker_urls
inference_urls: list[str] = rollout._inf_addrs


nccl_master_addr = ""
nccl_master_port = 0
if self._guard_addrs:
resp = requests.post(
f"{self._guard_addrs[0]}/alloc_ports",
json={"count": 1},
timeout=30,
)
resp.raise_for_status()
port_data = resp.json()
nccl_master_addr = port_data["host"]
nccl_master_port = port_data["ports"][0]

pair_name = f"{self._role}-rollout"
ctrl.connect(
pair_name=pair_name,
train_worker_urls=self._worker_addrs,
inference_worker_urls=inference_urls,
nccl_master_addr=nccl_master_addr,
nccl_master_port=nccl_master_port,
)
self._weight_update_ctrl = ctrl
logger.info(
"WeightUpdateController connected (pair=%s, train=%d, inf=%d)",
pair_name,
len(self._worker_addrs),
len(inference_urls),
)

def update_weights(self, meta: Any) -> None:
if self._weight_update_ctrl is None or self.rollout is None:
raise RuntimeError(
"connect_engine() must be called before update_weights()"
)
self.rollout.pause_generation()
assert meta.version is not None and meta.version > 0, (
f"meta.version must be a positive integer, got {meta.version}"
)
result = self._weight_update_ctrl.update_weights(version=meta.version)
self.rollout.continue_generation()
Comment on lines +1008 to +1013
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block has multiple issues:

  1. The calls to pause_generation() and continue_generation() are asynchronous but called synchronously, returning coroutines without executing them.
  2. These calls are redundant and potentially harmful here. PPOTrainer.train already manages the rollout pause/resume state. In RolloutControllerV2, pause() correctly stops generation. Resuming it here via continue_generation() would break the trainer's expectation that inference remains paused during the subsequent save and evaluation steps.
  3. The assert should be replaced with a proper runtime check as assertions can be disabled in production.
        if meta.version is None or meta.version <= 0:
            raise ValueError(f"meta.version must be a positive integer, got {meta.version}")
        result = self._weight_update_ctrl.update_weights(version=meta.version)

logger.info(
"Weight update v%d completed (%s, %.0fms)",
meta.version,
result.status,
result.duration_ms,
)

def prepare_batch(
self,
dataloader: Any,
workflow: Any,
workflow_kwargs: dict[str, Any],
should_accept_fn: str | None = None,
group_size: int = 1,
dynamic_bs: bool = False,
) -> list[dict[str, Any]]:
if self.rollout is None:
raise RuntimeError("connect_engine() must be called before prepare_batch()")
return self.rollout.prepare_batch(
dataloader=dataloader,
workflow=workflow,
workflow_kwargs=workflow_kwargs,
should_accept_fn=should_accept_fn,
group_size=group_size,
dynamic_bs=dynamic_bs,
)

def rollout_batch(
self,
data: list[dict[str, Any]],
workflow: Any,
workflow_kwargs: dict[str, Any],
should_accept_fn: str | None = None,
group_size: int = 1,
) -> list[dict[str, Any]]:
if self.rollout is None:
raise RuntimeError("connect_engine() must be called before rollout_batch()")
return self.rollout.rollout_batch(
data=data,
workflow=workflow,
workflow_kwargs=workflow_kwargs,
should_accept_fn=should_accept_fn,
group_size=group_size,
)

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
self._parallel_strategy = parallel_strategy
import torch.distributed as dist
Expand Down
25 changes: 15 additions & 10 deletions areal/experimental/weight_update/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,25 @@ def connect(
use_lora: bool = False,
lora_name: str = "",
colocate: bool = False,
nccl_master_addr: str = "",
nccl_master_port: int = 0,
) -> None:
self._pair_name = pair_name
payload: dict[str, Any] = {
"pair_name": pair_name,
"train_worker_urls": train_worker_urls,
"inference_worker_urls": inference_worker_urls,
"mode": mode,
"save_path": save_path,
"use_lora": use_lora,
"lora_name": lora_name,
"colocate": colocate,
"nccl_master_addr": nccl_master_addr,
"nccl_master_port": nccl_master_port,
}
resp = self._http.post(
f"{self._gateway_url}/connect",
json={
"pair_name": pair_name,
"train_worker_urls": train_worker_urls,
"inference_worker_urls": inference_worker_urls,
"mode": mode,
"save_path": save_path,
"use_lora": use_lora,
"lora_name": lora_name,
"colocate": colocate,
},
json=payload,
timeout=self.config.request_timeout,
)
resp.raise_for_status()
Expand Down
Loading
Loading