Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import json
import os
import warnings
from dataclasses import MISSING as dataclass_missing
from dataclasses import asdict, dataclass, field, fields
from enum import Enum
Expand Down Expand Up @@ -2877,7 +2878,32 @@ def __post_init__(self):


@dataclass
class TeacherConfig(PPOActorConfig):
class TeacherConfig:
engine_type: str = field(
default="rollout",
metadata={
"help": "Teacher engine type. 'rollout' uses inference engine scoring; "
"'train' uses the legacy train-engine teacher path.",
"choices": ["rollout", "train"],
},
)
rollout: InferenceEngineConfig | None = field(default=None)
train: PPOActorConfig | None = field(
default=None,
metadata={
"help": "Legacy train-engine teacher config. Required when engine_type='train'."
},
)
path: str = field(
default="",
metadata={
"help": "Teacher model path. If set, overrides shared rollout backend model path."
},
)
offload: bool = field(
default=False,
metadata={"help": "Whether to offload teacher rollout model between steps"},
)
rl_loss_weight: float = field(
default=1.0,
metadata={"help": "RL loss weight"},
Expand All @@ -2888,6 +2914,22 @@ class TeacherConfig(PPOActorConfig):
metadata={"help": "Distillation loss weight"},
)

def __post_init__(self):
if self.rollout is not None and self.train is not None:
warnings.warn(
"Both teacher.rollout and teacher.train are configured; "
f"teacher.engine_type={self.engine_type!r} selects which one is used.",
stacklevel=2,
)
if self.engine_type == "rollout" and self.rollout is None:
raise ValueError(
"teacher.rollout must be provided when teacher.engine_type='rollout'."
)
if self.engine_type == "train" and self.train is None:
raise ValueError(
"teacher.train must be provided when teacher.engine_type='train'."
)


@dataclass
class PPOConfig(BaseExperimentConfig):
Expand Down
9 changes: 9 additions & 0 deletions areal/api/engine_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,15 @@ def get_version(self) -> int:
"""
raise NotImplementedError()

def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]:
"""Compute token log-probabilities for teacher distillation.

Implementations support this as an inference-side scoring API.
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement compute_logp()."
)

def submit(
self,
data: dict[str, Any],
Expand Down
38 changes: 38 additions & 0 deletions areal/engine/sglang_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import pybase64
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api import (
Expand Down Expand Up @@ -126,6 +127,40 @@ def parse_generation_response(
routed_experts=routed_experts,
)

def build_score_request(
self, input_ids: list[int], target_len: int, with_lora: bool, version: int
) -> HttpRequest:
payload: dict[str, Any] = {
"input_ids": input_ids,
"sampling_params": {
"max_new_tokens": 1,
"temperature": 0.0,
},
"return_logprob": True,
"logprob_start_len": max(0, len(input_ids) - target_len - 1),
"top_logprobs_num": 0,
"stream": False,
}
if with_lora:
raise NotImplementedError(
"LoRA scoring request is not supported in SGLang teacher compute_logp yet."
)
return HttpRequest(endpoint="/generate", payload=payload)

def parse_score_response(
self, response: dict[str, Any], target_len: int
) -> list[float]:
meta_info = response.get("meta_info")
if meta_info is None:
raise ValueError("SGLang response missing meta_info for score request")
# SGLang returns [logprob, token_id, ...]
all_logprobs = [float(x[0]) for x in meta_info.get("input_token_logprobs", [])]
if len(all_logprobs) < target_len:
raise ValueError(
f"SGLang returned insufficient input_token_logprobs: {len(all_logprobs)} < {target_len}"
)
return all_logprobs[-target_len:]

def build_disk_weight_update_requests(
self, meta: WeightUpdateMeta
) -> WeightUpdateRequests:
Expand Down Expand Up @@ -502,6 +537,9 @@ def prepare_batch(
dynamic_bs=dynamic_bs,
)

def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]:
return self._engine.compute_logp(data)

def pause(self):
return self._engine.pause()

Expand Down
44 changes: 44 additions & 0 deletions areal/engine/vllm_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from concurrent.futures import Future
from typing import Any

import torch
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api import (
Expand Down Expand Up @@ -126,6 +127,46 @@ def parse_generation_response(
stop_reason=stop_reason,
)

def build_score_request(
self, input_ids: list[int], target_len: int, with_lora: bool, version: int
) -> HttpRequest:
payload: dict[str, Any] = {
"prompt": input_ids,
"max_tokens": 1,
"temperature": 0.0,
"logprobs": 1,
"prompt_logprobs": 1,
"echo": True,
}
if with_lora:
raise NotImplementedError(
"LoRA scoring request is not supported in vLLM teacher compute_logp yet."
)
return HttpRequest(endpoint="/v1/completions", payload=payload)

def parse_score_response(
self, response: dict[str, Any], target_len: int
) -> list[float]:
choices = response.get("choices")
if not choices:
raise ValueError("vLLM response missing choices for score request")
prompt_logprobs = choices[0].get("prompt_logprobs")
if prompt_logprobs is None:
raise ValueError("vLLM response missing prompt_logprobs for score request")
if len(prompt_logprobs) < target_len + 1:
raise ValueError(
f"prompt_logprobs too short: got {len(prompt_logprobs)}, need {target_len + 1}"
)
sliced = prompt_logprobs[-target_len:]
token_logps: list[float] = []
for item in sliced:
if not item:
token_logps.append(0.0)
continue
top = next(iter(item.values()))
token_logps.append(float(top["logprob"] if isinstance(top, dict) else top))
return token_logps

def build_disk_weight_update_requests(
self, meta: WeightUpdateMeta
) -> WeightUpdateRequests:
Expand Down Expand Up @@ -465,6 +506,9 @@ def prepare_batch(
dynamic_bs=dynamic_bs,
)

def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]:
return self._engine.compute_logp(data)

def pause(self):
return self._engine.pause()

Expand Down
40 changes: 40 additions & 0 deletions areal/infra/controller/rollout_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,46 @@ def task_input_generator():
trajectories = [r.trajectory if r is not None else None for r in results]
return [t for t in trajectories if t is not None]

def compute_logp(self, data: list[dict[str, Any]]) -> list[Any]:
"""Compute token log-probabilities for trajectories via remote workers."""
if len(data) == 0:
return []

async def _compute():
indexed_chunks: list[list[int]] = []
tasks = []
n_workers = len(self.workers)
if n_workers == 0:
raise RuntimeError("No workers available for compute_logp.")

for rank, worker in enumerate(self.workers):
idxs = list(range(rank, len(data), n_workers))
if not idxs:
continue
chunk = [data[i] for i in idxs]
indexed_chunks.append(idxs)
tasks.append(
self.scheduler.async_call_engine(
worker_id=worker.id,
method="compute_logp",
engine_name=self._engine_name(rank),
data=chunk,
http_timeout=self.config.request_timeout,
)
)
rpc_results = await asyncio.gather(*tasks)
merged: list[Any] = [None] * len(data)
for idxs, chunk_result in zip(indexed_chunks, rpc_results):
if len(chunk_result) != len(idxs):
raise RuntimeError(
f"compute_logp result length mismatch: got {len(chunk_result)}, expected {len(idxs)}"
)
for out_idx, value in zip(idxs, chunk_result):
merged[out_idx] = value
return merged

return run_async_task(_compute)

async def agenerate(self, req: ModelRequest) -> ModelResponse:
"""Asynchronously generate a response for the given request.

Expand Down
62 changes: 62 additions & 0 deletions areal/infra/remote_inf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
import ray
import requests
import torch
import torch.distributed as dist
import uvloop
from torchdata.stateful_dataloader import StatefulDataLoader
Expand Down Expand Up @@ -174,6 +175,18 @@ def parse_generation_response(
"""
...

def build_score_request(
self, input_ids: list[int], target_len: int, with_lora: bool, version: int
) -> HttpRequest:
"""Build HTTP request for token log-prob scoring."""
...

def parse_score_response(
self, response: dict[str, Any], target_len: int
) -> list[float]:
"""Parse token log-prob scoring response."""
...

def build_disk_weight_update_requests(
self, meta: WeightUpdateMeta
) -> WeightUpdateRequests:
Expand Down Expand Up @@ -502,6 +515,55 @@ def get_version(self):
with self.lock:
return self._version

def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]:
results: list[torch.Tensor] = []
timeout = self.config.request_timeout
version = self.get_version()
for traj in data:
input_ids = traj["input_ids"]
loss_mask = traj["loss_mask"]
if input_ids.dim() != 2 or loss_mask.dim() != 2:
raise ValueError("input_ids and loss_mask must be 2D tensors")
bs = input_ids.shape[0]
out = torch.zeros_like(loss_mask, dtype=torch.float32)
for i in range(bs):
token_ids = input_ids[i].tolist()
target_len = int(loss_mask[i].sum().item())
if target_len <= 0:
continue
if "attention_mask" in traj:
attn_mask = traj["attention_mask"][i]
active_idx = torch.nonzero(attn_mask, as_tuple=False).squeeze(-1)
token_ids = input_ids[i, active_idx].tolist()
else:
token_ids = input_ids[i].tolist()
server_addr = self.choose_server()
http_req = self.backend.build_score_request(
input_ids=token_ids,
target_len=target_len,
with_lora=self.config.use_lora,
version=version,
)
response = requests.request(
http_req.method,
f"http://{server_addr}{http_req.endpoint}",
json=http_req.payload,
timeout=timeout,
)
response.raise_for_status()
payload = response.json()
token_logps = self.backend.parse_score_response(payload, target_len)
if len(token_logps) != target_len:
raise ValueError(
f"Expected {target_len} token logprobs, got {len(token_logps)}"
)
write_idx = torch.nonzero(loss_mask[i], as_tuple=False).squeeze(-1)
out[i, write_idx] = torch.tensor(
token_logps, device=out.device, dtype=out.dtype
)
results.append(out)
return results

def set_proxy_gateway_addr(self, addr: str) -> None:
"""Set the proxy gateway address.

Expand Down
Loading
Loading