From 4e6ba4ea70b8a6d4c10726a248253bda24bb1d79 Mon Sep 17 00:00:00 2001 From: zackcxb Date: Fri, 8 May 2026 14:03:33 +0000 Subject: [PATCH 01/22] init commit for external agent framework+gateway --- uni_agent/trainer/__init__.py | 2 + uni_agent/trainer/framework/__init__.py | 12 + uni_agent/trainer/framework/entry.py | 55 ++ uni_agent/trainer/framework/framework.py | 581 ++++++++++++++ uni_agent/trainer/framework/helpers.py | 71 ++ .../framework/multi_modal_postprocess.py | 94 +++ uni_agent/trainer/framework/types.py | 63 ++ uni_agent/trainer/gateway/__init__.py | 9 + uni_agent/trainer/gateway/gateway.py | 718 ++++++++++++++++++ uni_agent/trainer/gateway/manager.py | 57 ++ uni_agent/trainer/gateway/runtime.py | 97 +++ uni_agent/trainer/gateway/types.py | 43 ++ 12 files changed, 1802 insertions(+) create mode 100644 uni_agent/trainer/__init__.py create mode 100644 uni_agent/trainer/framework/__init__.py create mode 100644 uni_agent/trainer/framework/entry.py create mode 100644 uni_agent/trainer/framework/framework.py create mode 100644 uni_agent/trainer/framework/helpers.py create mode 100644 uni_agent/trainer/framework/multi_modal_postprocess.py create mode 100644 uni_agent/trainer/framework/types.py create mode 100644 uni_agent/trainer/gateway/__init__.py create mode 100644 uni_agent/trainer/gateway/gateway.py create mode 100644 uni_agent/trainer/gateway/manager.py create mode 100644 uni_agent/trainer/gateway/runtime.py create mode 100644 uni_agent/trainer/gateway/types.py diff --git a/uni_agent/trainer/__init__.py b/uni_agent/trainer/__init__.py new file mode 100644 index 0000000..d8946c8 --- /dev/null +++ b/uni_agent/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Agent framework and gateway packages.""" + diff --git a/uni_agent/trainer/framework/__init__.py b/uni_agent/trainer/framework/__init__.py new file mode 100644 index 0000000..425d2f9 --- /dev/null +++ b/uni_agent/trainer/framework/__init__.py @@ -0,0 +1,12 @@ +from .framework import AgentFramework, OpenAICompatibleAgentFramework +from .helpers import normalize_trajectory_rewards, validate_trajectory +from .types import SessionHandle, Trajectory + +__all__ = [ + "AgentFramework", + "OpenAICompatibleAgentFramework", + "SessionHandle", + "Trajectory", + "normalize_trajectory_rewards", + "validate_trajectory", +] diff --git a/uni_agent/trainer/framework/entry.py b/uni_agent/trainer/framework/entry.py new file mode 100644 index 0000000..883d941 --- /dev/null +++ b/uni_agent/trainer/framework/entry.py @@ -0,0 +1,55 @@ +"""Factory entry for session runtime construction and framework FQN dispatch. + +entry owns gateway-universal wiring so framework subclasses only handle their +own agent runner, reward bridge, and framework-specific config fields. +Phase A: recipe adapter calls this. Phase B: main_ppo_sync.py calls it directly. +""" + +from __future__ import annotations + +from omegaconf import OmegaConf + +from verl.agent.framework.framework import AgentFramework, OpenAICompatibleAgentFramework +from verl.agent.gateway.runtime import GatewayServingRuntime +from verl.utils.import_utils import load_class_from_fqn + +_DEFAULT_FRAMEWORK_CLASS = f"{OpenAICompatibleAgentFramework.__module__}.OpenAICompatibleAgentFramework" +_DEFAULT_GATEWAY_COUNT = 0 +_DEFAULT_TOOL_PARSER = "hermes" + + +async def build_agent_framework( + *, + config, + llm_client, + tokenizer, + processor=None, + replay_buffer, +) -> AgentFramework: + """Build GatewayServingRuntime, then delegate subclass-specific wiring.""" + # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* + af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} + + gateway_actor_kwargs = { + "tokenizer": tokenizer, + "processor": processor, + "tool_parser_name": config.actor_rollout_ref.rollout.get("multi_turn", {}).get("format") + or _DEFAULT_TOOL_PARSER, + } + if "host" in af_cfg and af_cfg["host"] is not None: + gateway_actor_kwargs["host"] = af_cfg["host"] + + session_runtime = GatewayServingRuntime( + llm_client=llm_client, + gateway_count=int(af_cfg.get("gateway_count", _DEFAULT_GATEWAY_COUNT)), + gateway_actor_kwargs=gateway_actor_kwargs, + ) + + framework_cls = load_class_from_fqn(str(af_cfg.get("framework_class_fqn", _DEFAULT_FRAMEWORK_CLASS))) + return await framework_cls.from_config( + config=config, + session_runtime=session_runtime, + tokenizer=tokenizer, + processor=processor, + replay_buffer=replay_buffer, + ) diff --git a/uni_agent/trainer/framework/framework.py b/uni_agent/trainer/framework/framework.py new file mode 100644 index 0000000..2ebcdbc --- /dev/null +++ b/uni_agent/trainer/framework/framework.py @@ -0,0 +1,581 @@ +from __future__ import annotations + +import asyncio +import inspect +import logging +from abc import ABC, abstractmethod +from dataclasses import replace +from functools import partial +from uuid import uuid4 + +from omegaconf import OmegaConf +import torch +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData, NonTensorStack + +from verl.tools.utils.tool_registry import initialize_tools_from_config +from verl.trainer.ppo.reward import get_custom_reward_fn +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.transferqueue_utils import tq +from verl.utils import tensordict_utils as tu +from verl.utils.model import compute_position_id_with_mask + +from .multi_modal_postprocess import compute_multi_modal_inputs, compute_position_ids +from .types import RewardFn, SessionRewardContext, SessionRuntime, Trajectory + +logger = logging.getLogger(__name__) + + +class AgentFramework(ABC): + """Abstract base for framework implementations. + + Phase A: entry.py owns session runtime construction and passes it in. + Subclasses receive shared entry resources plus the raw config for + subclass-specific field parsing. + + Phase B: trainer inlines entry; this from_config contract remains. + """ + + @classmethod + @abstractmethod + async def from_config( + cls, + *, + config, + session_runtime, + tokenizer=None, + processor=None, + replay_buffer, + ) -> "AgentFramework": + ... + + @abstractmethod + async def generate_sequences(self, prompts: TensorDict) -> None: + """Run agent sessions and write finalized trajectories to TransferQueue.""" + ... + + +def _to_long_tensor(values) -> torch.Tensor: + return torch.tensor(list(values), dtype=torch.long) + + +def _to_float_tensor(values) -> torch.Tensor: + return torch.tensor(list(values), dtype=torch.float32) + + +def _short_failure_reason(error: BaseException) -> str: + message = str(error) + if not message: + message = error.__class__.__name__ + return message[:512] + + +_TQ_NESTED_SEQUENCE_FIELDS = { + "prompts", + "responses", + "response_mask", + "loss_mask", + "input_ids", + "attention_mask", + "position_ids", + "rollout_log_probs", + "rm_scores", + "teacher_logprobs", + "teacher_ids", +} + + +def _list_of_tq_fields_to_tensordict(fields: list[dict[str, object]]) -> TensorDict: + td = tu.list_of_dict_to_tensordict(fields) + for key in _TQ_NESTED_SEQUENCE_FIELDS: + if key not in fields[0]: + continue + values = [field[key] for field in fields] + if not all(isinstance(value, torch.Tensor) for value in values): + continue + ragged_idx = 2 if key == "position_ids" and values[0].dim() == 2 else None + td[key] = tu.nested_tensor_from_tensor_list(values, ragged_idx=ragged_idx) + return td + + +def _build_reward_fn(config, tokenizer): + # Phase A keeps reward configuration in the existing VERL path and only + # bridges framework trajectories to the raw custom_reward_function + # signature. Phase B should reuse the main reward manager path directly + # instead of growing a parallel agent-framework reward config surface. + custom_reward_fn = get_custom_reward_fn(config) + if custom_reward_fn is None: + return None + + async def reward_fn(ctx): + data_source = ctx.sample_fields.get("data_source") + reward_model = ctx.sample_fields.get("reward_model") + if isinstance(reward_model, dict): + ground_truth = reward_model.get("ground_truth") + elif reward_model is None: + ground_truth = None + else: + ground_truth = getattr(reward_model, "ground_truth", None) + extra_info = ctx.sample_fields.get("extra_info") + scores = [] + for trajectory in ctx.trajectories: + response_text = tokenizer.decode(trajectory.response_ids, skip_special_tokens=True) + score = custom_reward_fn(data_source, response_text, ground_truth, extra_info) + if inspect.isawaitable(score): + score = await score + scores.append(score) + return scores + + return reward_fn + + +class OpenAICompatibleAgentFramework(AgentFramework): + """Reference AgentFramework implementation for OpenAI-compatible agent loops. + + Each sample in the batch is run as an independent session: the agent + communicates with the Gateway via standard ``/v1/chat/completions`` + requests, and the Gateway collects token-level trajectories. After + finalization, ``reward_fn`` scores the session's trajectories and the + framework writes them to the TransferQueue schema consumed by sync training. + """ + + def __init__( + self, + session_runtime: SessionRuntime, + agent_runner, + reward_fn: RewardFn | None, + *, + processor=None, + replay_buffer=None, + rollout_config=None, + completion_timeout: float | None = 30.0, + wait_for_completion_after_agent_run: bool = False, + ): + self.session_runtime = session_runtime + self.agent_runner = agent_runner + self.reward_fn = reward_fn + self._processor = processor + # TODO(phase-b): once trainer constructs framework directly, these become + # constructor-required and no transitional dual-path is needed. + self._replay_buffer = replay_buffer + self._rollout_config = rollout_config + self.completion_timeout = completion_timeout + self.wait_for_completion_after_agent_run = wait_for_completion_after_agent_run + + @classmethod + async def from_config( + cls, + *, + config, + session_runtime, + tokenizer=None, + processor=None, + replay_buffer, + ) -> "OpenAICompatibleAgentFramework": + if tokenizer is None: + raise ValueError("OpenAICompatibleAgentFramework requires tokenizer for reward bridge") + + # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* + af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} + agent_runner_fqn = af_cfg.get("agent_runner_fqn") + if not agent_runner_fqn: + raise ValueError("actor_rollout_ref.rollout.custom.agent_framework.agent_runner_fqn is required") + + agent_runner = load_class_from_fqn(str(agent_runner_fqn), description="agent runner") + runner_kwargs = dict( + OmegaConf.to_container(OmegaConf.create(af_cfg.get("agent_runner_kwargs", {})), resolve=True) or {} + ) + tool_config_path = af_cfg.get("tool_config_path") + if tool_config_path: + tool_config = initialize_tools_from_config(tool_config_path) + if not tool_config: + raise ValueError(f"tool config did not initialize any tools: {tool_config_path}") + runner_kwargs["tool_config"] = tool_config + if runner_kwargs: + agent_runner = partial(agent_runner, **runner_kwargs) + + # TODO(phase-x): when reward_loop_worker_handles is available from + # trainer, accept reward_fn as an entry-injected resource and skip + # bridge construction. Bridge remains available for simple recipes + # that supply reward.custom_reward_function directly. + reward_fn = _build_reward_fn(config, tokenizer) + + completion_timeout = af_cfg.get("completion_timeout_seconds") + return cls( + session_runtime=session_runtime, + agent_runner=agent_runner, + reward_fn=reward_fn, + processor=processor, + replay_buffer=replay_buffer, + rollout_config=config.actor_rollout_ref.rollout, + completion_timeout=completion_timeout, + wait_for_completion_after_agent_run=completion_timeout is not None, + ) + + async def generate_sequences(self, prompts: TensorDict) -> None: + """Run rollout-manager generation and write outputs into TransferQueue.""" + if self._replay_buffer is None: + raise RuntimeError("OpenAICompatibleAgentFramework requires replay_buffer for generate_sequences") + if self._rollout_config is None: + raise RuntimeError("OpenAICompatibleAgentFramework requires rollout_config for generate_sequences") + + global_steps = tu.get(prompts, "global_steps") + if global_steps is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['global_steps']") + + partition_id = "val" if "validate" in prompts.keys() else "train" + num_sessions = self._num_sessions_for_partition(partition_id) + + uids = tu.get(prompts, "uid") + if uids is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['uid'] for replay_buffer") + uid_values = uids.tolist() if hasattr(uids, "tolist") else list(uids) + self._replay_buffer.add( + partition_id, + {str(uid): {"global_steps": global_steps, "status": "running"} for uid in uid_values}, + ) + + stats = await self._generate_to_tq( + prompts, + global_steps=global_steps, + partition_id=partition_id, + num_sessions=num_sessions, + ) + logger.info( + "generate_sequences summary: num_input_prompts=%s num_success_sessions=%s " + "num_failed_sessions=%s num_success_outputs=%s num_failed_uids=%s failure_reasons=%s", + stats["num_input_prompts"], + stats["num_success_sessions"], + stats["num_failed_sessions"], + stats["num_success_outputs"], + stats["num_failed_uids"], + stats["failure_reasons"][:3], + ) + if stats["num_success_outputs"] == 0: + raise RuntimeError( + f"All rollouts failed at global_steps={global_steps}. " + f"failures={stats['num_failed_uids']}/{stats['num_input_prompts']}" + ) + return None + + def _num_sessions_for_partition(self, partition_id: str) -> int: + if partition_id == "val": + val_kwargs = self._rollout_config.get("val_kwargs", {}) + return int(val_kwargs.get("n")) + return int(self._rollout_config.get("n")) + + async def _generate_to_tq( + self, + prompts: TensorDict, + *, + global_steps: int, + partition_id: str, + num_sessions: int = 1, + ) -> dict: + """Run agent sessions and write finalized trajectories to TransferQueue. + + This is the TransferQueue-oriented sibling of ``generate_sequences``. + It preserves the same session lifecycle, but writes each finalized + trajectory with the key/tag/field schema consumed by + ``verl.trainer.main_ppo_sync`` instead of returning a batch. + """ + assert len(prompts) > 0, "generate_sequences requires a non-empty batch" + if num_sessions <= 0: + raise ValueError(f"num_sessions must be positive, got {num_sessions}") + + raw_prompts = tu.get(prompts, "raw_prompt") + if raw_prompts is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['raw_prompt']") + + tasks = [ + self._run_prompt_to_replay_buffer( + prompts=prompts, + raw_prompt=raw_prompts[sample_index], + sample_index=sample_index, + global_steps=global_steps, + partition_id=partition_id, + num_sessions=num_sessions, + ) + for sample_index in range(len(prompts)) + ] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + failure_reasons: list[str] = [] + stats = { + "num_input_prompts": len(prompts), + "num_success_sessions": 0, + "num_failed_sessions": 0, + "num_success_outputs": 0, + "num_failed_uids": 0, + "failure_reasons": failure_reasons, + } + for outcome in outcomes: + if isinstance(outcome, Exception): + stats["num_failed_sessions"] += num_sessions + stats["num_failed_uids"] += 1 + failure_reasons.append(_short_failure_reason(outcome)) + continue + stats["num_success_sessions"] += outcome["num_success_sessions"] + stats["num_failed_sessions"] += outcome["num_failed_sessions"] + stats["num_success_outputs"] += outcome["num_success_outputs"] + stats["num_failed_uids"] += outcome["num_failed_uids"] + failure_reasons.extend(outcome["failure_reasons"]) + return stats + + async def _run_prompt_to_replay_buffer( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + global_steps: int, + partition_id: str, + num_sessions: int, + ) -> dict: + sample_fields = self._extract_sample_fields(prompts=prompts, sample_index=sample_index) + uid = sample_fields.get("uid") + if uid is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['uid'] for TransferQueue output") + uid = str(uid) + + tasks = [ + self._run_session( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=self._build_session_id( + prompts=prompts, + sample_index=sample_index, + session_index=session_index, + ), + runner_kwargs=( + {"tools_kwargs": sample_fields["tools_kwargs"]} if "tools_kwargs" in sample_fields else {} + ), + ) + for session_index in range(num_sessions) + ] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + success_sessions = 0 + failed_sessions = 0 + success_outputs = 0 + failure_reasons: list[str] = [] + for session_index, outcome in enumerate(outcomes): + if isinstance(outcome, Exception): + failed_sessions += 1 + failure_reasons.append(_short_failure_reason(outcome)) + continue + + trajectories, session_sample_fields = outcome + if not trajectories: + failed_sessions += 1 + failure_reasons.append(f"empty trajectories for uid={uid} session_id={session_index}") + continue + + success_sessions += 1 + await self._write_session_trajectories_to_tq( + uid=uid, + session_id=session_index, + trajectories=trajectories, + sample_fields=session_sample_fields, + global_steps=global_steps, + partition_id=partition_id, + ) + success_outputs += len(trajectories) + + if success_sessions > 0: + await tq.async_kv_put(key=uid, partition_id=partition_id, tag={"status": "finished"}) + failed_uids = 0 + else: + await tq.async_kv_put(key=uid, partition_id=partition_id, tag={"status": "failure"}) + failed_uids = 1 + + return { + "num_success_sessions": success_sessions, + "num_failed_sessions": failed_sessions, + "num_success_outputs": success_outputs, + "num_failed_uids": failed_uids, + "failure_reasons": failure_reasons, + } + + async def _run_session( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + session_id: str | None = None, + runner_kwargs: dict[str, object] | None = None, + ) -> tuple[list[Trajectory], dict[str, object]]: + session_id = session_id or self._build_session_id(prompts=prompts, sample_index=sample_index) + sample_fields = self._extract_sample_fields(prompts=prompts, sample_index=sample_index) + session = await self.session_runtime.create_session(session_id) + try: + await self.agent_runner( + raw_prompt=raw_prompt, + session=session, + sample_index=sample_index, + **(runner_kwargs or {}), + ) + if self.wait_for_completion_after_agent_run: + await self.session_runtime.wait_for_completion(session_id, timeout=self.completion_timeout) + session_trajectories = await self.session_runtime.finalize_session(session_id) + except Exception: + await self.session_runtime.abort_session(session_id) + raise + + # Score the session's trajectories immediately after finalization, + # consistent with VERL's per-sample reward path. + if self.reward_fn is None: + return session_trajectories, sample_fields + + normalized_scores = await self._score_trajectories(session_trajectories, sample_fields) + return ( + [ + replace(traj, reward_score=score) + for traj, score in zip(session_trajectories, normalized_scores, strict=True) + ], + sample_fields, + ) + + async def _score_trajectories( + self, + session_trajectories: list[Trajectory], + sample_fields: dict[str, object], + ) -> list[float]: + assert self.reward_fn is not None + ctx = SessionRewardContext(trajectories=session_trajectories, sample_fields=sample_fields) + scores = self.reward_fn(ctx) + if inspect.isawaitable(scores): + scores = await scores + if len(scores) != len(session_trajectories): + raise ValueError( + f"reward_fn returned {len(scores)} scores for {len(session_trajectories)} trajectories" + ) + normalized_scores: list[float] = [] + for _, score in zip(session_trajectories, scores, strict=True): + if score is None: + raise ValueError( + "reward_fn must return a score for every trajectory; " + f"got None for uid={sample_fields.get('uid')}" + ) + normalized_scores.append(float(score)) + return normalized_scores + + def _extract_sample_fields(self, *, prompts: TensorDict, sample_index: int) -> dict[str, object]: + sample_fields = {} + for key, value in prompts.items(): + if isinstance(value, torch.Tensor): + sample_fields[key] = value if value.ndim == 0 else value[sample_index] + elif isinstance(value, NonTensorStack): + sample_fields[key] = tu.get(prompts, key)[sample_index] + else: + assert isinstance(value, NonTensorData) + sample_fields[key] = value.data + return sample_fields + + async def _write_session_trajectories_to_tq( + self, + *, + uid: str, + session_id: int, + trajectories: list[Trajectory], + sample_fields: dict[str, object], + global_steps: int, + partition_id: str, + ) -> None: + keys = [] + fields = [] + tags = [] + for index, trajectory in enumerate(trajectories): + field, tag = self._trajectory_to_tq_field_and_tag( + trajectory=trajectory, + sample_fields=sample_fields, + session_id=session_id, + global_steps=global_steps, + ) + keys.append(f"{uid}_{session_id}_{index}") + fields.append(field) + tags.append(tag) + + await tq.async_kv_batch_put( + keys=keys, + fields=_list_of_tq_fields_to_tensordict(fields), + tags=tags, + partition_id=partition_id, + ) + + def _trajectory_to_tq_field_and_tag( + self, + *, + trajectory: Trajectory, + sample_fields: dict[str, object], + session_id: int, + global_steps: int, + ) -> tuple[dict[str, object], dict[str, object]]: + prompts = _to_long_tensor(trajectory.prompt_ids) + responses = _to_long_tensor(trajectory.response_ids) + response_mask = _to_long_tensor(trajectory.response_mask) + input_ids = torch.cat([prompts, responses], dim=0) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + multi_modal_inputs = compute_multi_modal_inputs( + self._processor, + input_ids.unsqueeze(0), + trajectory.multi_modal_data, + ) + if self._processor is None: + position_ids = compute_position_id_with_mask(attention_mask.unsqueeze(0)).squeeze(0) + else: + position_ids = compute_position_ids( + self._processor, + input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + multi_modal_inputs, + ).squeeze(0) + + field: dict[str, object] = { + "prompts": prompts, + "responses": responses, + "response_mask": response_mask, + "loss_mask": response_mask, + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "multi_modal_inputs": multi_modal_inputs, + } + if trajectory.response_logprobs is not None: + field["rollout_log_probs"] = _to_float_tensor(trajectory.response_logprobs) + if trajectory.routed_experts is not None: + field["routed_experts"] = ( + torch.from_numpy(trajectory.routed_experts.copy()) + if hasattr(trajectory.routed_experts, "copy") and not isinstance(trajectory.routed_experts, torch.Tensor) + else trajectory.routed_experts + ) + if trajectory.reward_score is not None: + rm_scores = torch.zeros_like(responses, dtype=torch.float32) + if responses.numel() > 0: + rm_scores[-1] = float(trajectory.reward_score) + field["rm_scores"] = rm_scores + + field.update(trajectory.extra_fields) + field.pop("multi_modal_data", None) + for key in ("uid", "raw_prompt", "data_source", "reward_model", "extra_info", "tools_kwargs", "agent_name"): + if key in sample_fields: + field[key] = sample_fields[key] + field["session_id"] = session_id + field["global_steps"] = global_steps + field["num_turns"] = torch.tensor(int(trajectory.num_turns), dtype=torch.long) + + prompt_len = prompts.size(0) + response_len = responses.size(0) + tag = { + "global_steps": global_steps, + "status": "success", + "prompt_len": prompt_len, + "response_len": response_len, + "seq_len": prompt_len + response_len, + } + return field, tag + + def _build_session_id(self, prompts: TensorDict, sample_index: int, session_index: int = 0) -> str: + return f"session-{sample_index}-{session_index}-{uuid4().hex}" diff --git a/uni_agent/trainer/framework/helpers.py b/uni_agent/trainer/framework/helpers.py new file mode 100644 index 0000000..29e0160 --- /dev/null +++ b/uni_agent/trainer/framework/helpers.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import replace +from typing import Any + +import numpy as np +import torch + +from .types import Trajectory + + +def _resolve_trajectory_value(value: Any, index: int, count: int) -> Any: + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray, list)): + # Keep tuple-like values broadcastable only when they are not trajectory-aligned containers. + return value + if isinstance(value, list): + if len(value) != count: + raise ValueError(f"reward_info sequence length must match trajectories: {len(value)} != {count}") + return value[index] + return value + + +def normalize_trajectory_rewards( + trajectories: Sequence[Trajectory], + reward_info: Mapping[str, Any] | None = None, +) -> list[Trajectory]: + normalized: list[Trajectory] = [] + count = len(trajectories) + + for index, trajectory in enumerate(trajectories): + merged_reward_info = dict(trajectory.reward_info) + if reward_info is not None: + for key, value in reward_info.items(): + merged_reward_info[key] = _resolve_trajectory_value(value, index=index, count=count) + + if trajectory.reward_score is None: + raise ValueError( + f"Trajectory at index {index} has no reward_score. " + "reward_fn must return a score for every trajectory." + ) + + normalized.append(replace(trajectory, reward_info=merged_reward_info)) + + return normalized + + +def validate_trajectory(trajectory: Trajectory) -> Trajectory: + if len(trajectory.response_ids) != len(trajectory.response_mask): + raise ValueError("response_mask length must match response_ids length") + + if trajectory.response_logprobs is not None and len(trajectory.response_logprobs) != len(trajectory.response_ids): + raise ValueError("response_logprobs length must match response_ids length") + + if trajectory.num_turns < 0: + raise ValueError("num_turns must be non-negative") + if trajectory.routed_experts is not None: + if isinstance(trajectory.routed_experts, np.ndarray): + routed_experts = trajectory.routed_experts + elif isinstance(trajectory.routed_experts, torch.Tensor): + routed_experts = trajectory.routed_experts + else: + raise TypeError(f"Unsupported routed_experts type: {type(trajectory.routed_experts)}") + + if routed_experts.ndim != 3: + raise ValueError("routed_experts must have shape [total_tokens, num_layers, topk]") + expected_length = len(trajectory.prompt_ids) + len(trajectory.response_ids) + if routed_experts.shape[0] != expected_length: + raise ValueError("routed_experts token dimension must match prompt_ids + response_ids") + + return trajectory diff --git a/uni_agent/trainer/framework/multi_modal_postprocess.py b/uni_agent/trainer/framework/multi_modal_postprocess.py new file mode 100644 index 0000000..67694cf --- /dev/null +++ b/uni_agent/trainer/framework/multi_modal_postprocess.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from verl.utils.model import compute_position_id_with_mask + +# Behavior mirrors legacy `_compute_multi_modal_inputs` / `_compute_position_ids` +# in `verl/experimental/agent_loop/agent_loop.py`. + + +def _split_videos_and_metadata(videos: list[Any] | None) -> tuple[list[Any] | None, list[Any] | None]: + if not videos: + return videos, None + + first_video = videos[0] + if isinstance(first_video, tuple) and len(first_video) == 2: + split_videos, video_metadata = zip(*videos, strict=False) + return list(split_videos), list(video_metadata) + + return list(videos), None + + +def _to_plain_tensor_dict(processor_output) -> dict[str, torch.Tensor]: + if hasattr(processor_output, "convert_to_tensors"): + processor_output = processor_output.convert_to_tensors("pt") + return dict(processor_output) + + +def compute_multi_modal_inputs( + processor, + input_ids: torch.Tensor, + multi_modal_data: dict[str, Any] | None, +) -> dict[str, torch.Tensor]: + """Return processor-produced multimodal tensors for a single sample.""" + if processor is None or not multi_modal_data: + return {} + + images = multi_modal_data.get("images") + videos, video_metadata = _split_videos_and_metadata(multi_modal_data.get("videos")) + current_text = processor.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True) + multi_modal_inputs = _to_plain_tensor_dict( + processor( + text=[current_text], + images=images, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + ) + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + if image_grid_thw is not None: + images_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]) + multi_modal_inputs["images_seqlens"] = images_seqlens + return multi_modal_inputs + + +def compute_position_ids( + processor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + multi_modal_inputs: dict[str, torch.Tensor], +) -> torch.Tensor: + """Return text-only or multimodal-aware position ids for a single sample.""" + if processor is None: + return compute_position_id_with_mask(attention_mask) + + multi_modal_kwargs = { + "image_grid_thw": multi_modal_inputs.get("image_grid_thw"), + "video_grid_thw": multi_modal_inputs.get("video_grid_thw"), + } + if multi_modal_inputs.get("mm_token_type_ids") is not None: + mm_token_type_ids = torch.zeros_like(input_ids) + mm_token_type_ids[0][input_ids[0] == processor.image_token_id] = 1 + mm_token_type_ids[0][input_ids[0] == processor.video_token_id] = 2 + multi_modal_kwargs["mm_token_type_ids"] = mm_token_type_ids + + vision_position_ids, _ = processor.get_rope_index( + input_ids=input_ids, + attention_mask=attention_mask, + **multi_modal_kwargs, + ) + vision_position_ids = vision_position_ids.transpose(0, 1) + + valid_mask = attention_mask[0].bool() + text_position_ids = torch.ones((1, input_ids.shape[1]), dtype=torch.long, device=input_ids.device) + text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item(), device=input_ids.device) + text_position_ids = text_position_ids.unsqueeze(0) + return torch.cat((text_position_ids, vision_position_ids), dim=1) diff --git a/uni_agent/trainer/framework/types.py b/uni_agent/trainer/framework/types.py new file mode 100644 index 0000000..9c5545b --- /dev/null +++ b/uni_agent/trainer/framework/types.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any, Protocol + +import numpy as np +import torch + + +@dataclass +class SessionHandle: + session_id: str + base_url: str | None = None + + +@dataclass +class Trajectory: + prompt_ids: list[int] + response_ids: list[int] + response_mask: list[int] + response_logprobs: list[float] | None = None + reward_info: dict[str, Any] = field(default_factory=dict) + reward_score: float | None = None + num_turns: int = 0 + routed_experts: torch.Tensor | np.ndarray | None = None + multi_modal_data: dict[str, Any] | None = None + extra_fields: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class SessionRewardContext: + """Context passed to ``reward_fn`` after a session is finalized. + + A single session may produce multiple trajectories (e.g. when the agent + switches conversation context mid-session). ``reward_fn`` receives all of + them together so the implementor can choose the session-to-trajectory + scoring policy, but it must return one score per trajectory. + + ``sample_fields`` carries per-sample dataset fields (``data_source``, + ``reward_model.ground_truth``, ``extra_info``, ...) — the same dict that + ``AgentLoopWorker._compute_score`` forwards as ``kwargs`` to the reward + worker. + """ + + trajectories: list[Trajectory] + sample_fields: dict[str, Any] = field(default_factory=dict) + +RewardFn = Callable[[SessionRewardContext], Awaitable[list[float]] | list[float]] + + +class SessionRuntime(Protocol): + """Protocol for gateway-backed session lifecycle. + + Used by OpenAICompatibleAgentFramework to decouple the framework from the + concrete AsyncLLMServerManager / GatewayManager implementation, making it + testable without a Ray cluster. + """ + + async def create_session(self, session_id: str, **kwargs) -> SessionHandle: ... + async def finalize_session(self, session_id: str) -> list[Trajectory]: ... + async def abort_session(self, session_id: str) -> None: ... + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: ... diff --git a/uni_agent/trainer/gateway/__init__.py b/uni_agent/trainer/gateway/__init__.py new file mode 100644 index 0000000..8fb072a --- /dev/null +++ b/uni_agent/trainer/gateway/__init__.py @@ -0,0 +1,9 @@ +from .gateway import GatewayActor +from .manager import GatewayManager +from .runtime import GatewayServingRuntime + +__all__ = [ + "GatewayActor", + "GatewayManager", + "GatewayServingRuntime", +] diff --git a/uni_agent/trainer/gateway/gateway.py b/uni_agent/trainer/gateway/gateway.py new file mode 100644 index 0000000..46009d5 --- /dev/null +++ b/uni_agent/trainer/gateway/gateway.py @@ -0,0 +1,718 @@ +from __future__ import annotations + +import asyncio +import json +import time +from dataclasses import replace +from typing import Any +from uuid import uuid4 + +import ray +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from verl.agent.framework.types import SessionHandle, Trajectory +from verl.agent.gateway.types import GatewaySessionState, SessionPhase, TrajectoryBuffer +from verl.experimental.agent_loop.tool_parser import ToolParser +from verl.utils.chat_template import apply_chat_template as _apply_chat_template, initialize_system_prompt +from verl.utils.tokenizer import normalize_token_ids +from verl.workers.rollout.utils import run_uvicorn + + +class MalformedRequestError(ValueError): + pass + + +_DEFAULT_ALLOWED_REQUEST_SAMPLING_PARAM_KEYS = frozenset({ + "temperature", + "top_p", + "top_k", + "max_tokens", +}) + + +# TODO: double-check if all these validations/normalization are necessary +# Make sure they don't alter messages in unexpected ways. +def _normalize_message_content(content: Any) -> Any: + """Normalize message content: coerce None to empty string, validate type.""" + if isinstance(content, list | dict | str): + return content + if content is None: + return "" + raise MalformedRequestError(f"Unsupported content type: {type(content).__name__}") + + +def _validate_tool_calls(tool_calls: Any) -> None: + """Validate tool_calls structure. Does not modify content.""" + if not isinstance(tool_calls, list): + raise MalformedRequestError("tool_calls must be a list") + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + raise MalformedRequestError("tool_calls entries must be objects") + function = tool_call.get("function") + if not isinstance(function, dict): + raise MalformedRequestError("tool_call.function must be an object") + + +def _normalize_message(message: Any) -> dict[str, Any]: + """Normalize a single message: validate structure, coerce types, filter to known fields. + + Constructs a new dict with only role/content/tool_calls/tool_call_id. + This ensures prefix comparison is not affected by extraneous fields. + """ + if not isinstance(message, dict): + raise MalformedRequestError("messages entries must be objects") + + role = message.get("role") + if not isinstance(role, str) or not role: + raise MalformedRequestError("message.role must be a non-empty string") + + normalized: dict[str, Any] = { + "role": role, + "content": _normalize_message_content(message.get("content", "")), + } + if "name" in message: + name = message["name"] + if not isinstance(name, str): + raise MalformedRequestError("message.name must be a string") + normalized["name"] = name + if "tool_calls" in message: + _validate_tool_calls(message["tool_calls"]) + normalized["tool_calls"] = list(message["tool_calls"]) + if "tool_call_id" in message: + normalized["tool_call_id"] = str(message["tool_call_id"]) + return normalized + + +def _validate_tools(tools: Any) -> list[dict[str, Any]] | None: + """Validate tools structure. Does not modify content.""" + if tools is None: + return None + if not isinstance(tools, list): + raise MalformedRequestError("tools must be a list") + for tool in tools: + if not isinstance(tool, dict): + raise MalformedRequestError("tools entries must be objects") + return tools + + +def _normalize_request_context(payload: dict[str, Any]) -> dict[str, Any]: + """Normalize and validate the request payload, extracting messages and tools. + """ + messages = payload.get("messages") + if not isinstance(messages, list) or not messages: + raise MalformedRequestError("messages must be non-empty") + return { + "messages": [_normalize_message(message) for message in messages], + "tools": _validate_tools(payload.get("tools")), + } + + +def _build_sampling_params( + payload: dict[str, Any], + *, + base_sampling_params: dict[str, Any], + allowed_request_sampling_param_keys: frozenset[str], +) -> dict[str, Any]: + sampling_params = dict(base_sampling_params) + for key in allowed_request_sampling_param_keys: + if key in payload: + sampling_params[key] = payload[key] + return sampling_params + + +def _canonicalize_tool_arguments_for_comparison(arguments: Any) -> tuple[str, Any]: + if isinstance(arguments, dict | list): + return ("json", arguments) + if isinstance(arguments, str): + try: + return ("json", json.loads(arguments)) + except json.JSONDecodeError: + return ("raw", arguments) + return ("raw", arguments) + + +def _canonicalize_message_for_prefix_comparison(message: dict[str, Any]) -> dict[str, Any]: + normalized = dict(message) + tool_calls = normalized.get("tool_calls") + if not isinstance(tool_calls, list): + return normalized + + normalized_tool_calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + normalized_tool_call = dict(tool_call) + function = normalized_tool_call.get("function") + if isinstance(function, dict) and "arguments" in function: + normalized_function = dict(function) + normalized_function["arguments"] = _canonicalize_tool_arguments_for_comparison(function["arguments"]) + normalized_tool_call["function"] = normalized_function + normalized_tool_calls.append(normalized_tool_call) + normalized["tool_calls"] = normalized_tool_calls + return normalized + + +def _is_message_prefix(prefix: list[dict[str, Any]], messages: list[dict[str, Any]]) -> bool: + if len(prefix) > len(messages): + return False + return [ + _canonicalize_message_for_prefix_comparison(message) + for message in prefix + ] == [ + _canonicalize_message_for_prefix_comparison(message) + for message in messages[: len(prefix)] + ] + + +def _is_request_context_prefix( + *, + session: GatewaySessionState, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, +) -> bool: + if session.request_tools != tools: + return False + # TODO: dict equality is not token-level equivalent — two tool schemas with + # different key order compare equal in Python but may tokenize differently. + # This could cause a false prefix match on the tools path. Low practical + # risk (same agent rarely reorders keys within a session), but worth noting. + #TODO: need to improve the prefix check logic, e.g.,how to handle tool lists and multimodal data + return _is_message_prefix(session.message_history, messages) + + +def _copy_trajectory_buffer(buffer: TrajectoryBuffer | None) -> TrajectoryBuffer | None: + if buffer is None: + return None + return TrajectoryBuffer( + prompt_ids=list(buffer.prompt_ids), + response_ids=list(buffer.response_ids), + response_mask=list(buffer.response_mask), + response_logprobs=list(buffer.response_logprobs), + ) + + +def _count_chat_turns(message_history: list[dict[str, Any]]) -> int: + """Count chat turns consistent with ToolAgentLoop semantics. + + ToolAgentLoop computes: user_turns + assistant_turns + 1 (the +1 accounts + for the initial prompt). We count user + assistant role messages and add 1. + System and tool messages are excluded. + """ + return sum(1 for m in message_history if m.get("role") in ("user", "assistant")) + 1 + + +def _materialize_response_logprobs(buffer: TrajectoryBuffer) -> list[float] | None: + if not buffer.response_logprobs: + return None + return list(buffer.response_logprobs) + + +def _build_multi_modal_trajectory_data( + image_data: list[Any] | None, + video_data: list[Any] | None, +) -> dict[str, Any] | None: + multi_modal_data: dict[str, Any] = {} + if image_data: + multi_modal_data["images"] = list(image_data) + if video_data: + multi_modal_data["videos"] = list(video_data) + return multi_modal_data or None + + +class _GatewayActor: + def __init__( + self, + tokenizer, + backend, + host: str | None = None, + *, + processor=None, + vision_info_extractor=None, + vision_info_extractor_kwargs: dict[str, Any] | None = None, + tool_parser_name: str | None = None, + apply_chat_template_kwargs: dict[str, Any] | None = None, + base_sampling_params: dict[str, Any] | None = None, + allowed_request_sampling_param_keys: set[str] | frozenset[str] | None = None, + ): + # Same pattern as vllm_async_server.py / async_sglang_server.py: + # use the node's routable IP for both bind and URL by default. + self._server_address = host if host is not None else ray.util.get_node_ip_address() + self._tokenizer = tokenizer + self._processor = processor + self._backend = backend + self._vision_info_extractor = vision_info_extractor or self._default_vision_info_extractor + self._vision_info_extractor_kwargs = dict(vision_info_extractor_kwargs or {}) + self._apply_chat_template_kwargs = apply_chat_template_kwargs or {} + self._base_sampling_params = dict(base_sampling_params or {}) + allowed_keys = ( + _DEFAULT_ALLOWED_REQUEST_SAMPLING_PARAM_KEYS + if allowed_request_sampling_param_keys is None + else frozenset(allowed_request_sampling_param_keys) + ) + self._allowed_request_sampling_param_keys = allowed_keys + self._system_prompt = initialize_system_prompt( + tokenizer, + **self._apply_chat_template_kwargs, + ) + self._tool_parser = ( + ToolParser.get_tool_parser(tool_parser_name, tokenizer) if tool_parser_name else None + ) + self._sessions: dict[str, GatewaySessionState] = {} + self._app = FastAPI() + self._server_port: int | None = None + self._server_task: asyncio.Task | None = None + self._server_base_url: str | None = None + self._register_routes() + + def _register_routes(self) -> None: + @self._app.post("/sessions/{session_id}/v1/chat/completions") + async def _chat_completions(session_id: str, request: Request): + payload = await request.json() + return await self._handle_chat_completions(session_id=session_id, payload=payload) + + @self._app.post("/sessions/{session_id}/complete") + async def _complete(session_id: str, request: Request): + payload = await request.json() + reward_info = payload.get("reward_info") + try: + await self.complete_session(session_id=session_id, reward_info=reward_info) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + return JSONResponse({"status": "ok"}) + + def _require_started(self) -> None: + if self._server_base_url is None: + raise RuntimeError("GatewayActor.start() must be called before session creation") + + def _get_session(self, session_id: str) -> GatewaySessionState: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"Unknown session_id: {session_id}") + return session + + def _set_phase(self, session: GatewaySessionState, phase: SessionPhase) -> None: + session.phase = phase + self._touch_session(session) + + def _touch_session(self, session: GatewaySessionState) -> None: + session.updated_at = time.time() + + def _materialize_active_trajectory(self, session: GatewaySessionState) -> None: + active = session.active_trajectory + if active is None: + return + + self._touch_session(session) + session.trajectories.append( + self._build_materialized_trajectory( + session=session, + active=active, + ) + ) + session.active_trajectory = None + + def _build_materialized_trajectory( + self, + *, + session: GatewaySessionState, + active: TrajectoryBuffer, + ) -> Trajectory: + return Trajectory( + prompt_ids=list(active.prompt_ids), + response_ids=list(active.response_ids), + response_mask=list(active.response_mask), + response_logprobs=_materialize_response_logprobs(active), + reward_info={}, + num_turns=_count_chat_turns(session.message_history), + multi_modal_data=_build_multi_modal_trajectory_data(session.image_data, session.video_data), + ) + + async def _default_vision_info_extractor( + self, + messages: list[dict[str, Any]], + *, + image_patch_size: int, + ) -> tuple[list[Any] | None, list[Any] | None]: + # Keep the dataset dependency lazy so custom extractors do not pay for + # RLHFDataset imports unless they actually use the default path. + from verl.utils.dataset.rl_dataset import RLHFDataset + + return await RLHFDataset.process_vision_info( + messages, + image_patch_size=image_patch_size, + config=self._vision_info_extractor_kwargs.get("config"), + ) + + async def _extract_multi_modal_data( + self, + messages: list[dict[str, Any]], + ) -> tuple[list[Any] | None, list[Any] | None]: + if self._processor is None: + return None, None + + has_multi_modal_blocks = False + for message in messages: + content = message.get("content") + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, dict) and part.get("type") in {"image", "image_url", "video", "video_url"}: + has_multi_modal_blocks = True + break + if has_multi_modal_blocks: + break + + if not has_multi_modal_blocks: + return None, None + + return await self._vision_info_extractor( + messages, + image_patch_size=self._processor.image_processor.patch_size, + **self._vision_info_extractor_kwargs, + ) + + def _encode_full( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + ) -> list[int]: + """Encode a full conversation for a new trajectory (includes system prompt + generation prompt).""" + if self._processor is not None: + raw_prompt = _apply_chat_template( + self._processor, + messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + **self._apply_chat_template_kwargs, + ) + videos = video_data + video_metadata = None + if videos is not None: + videos, video_metadata = zip(*videos, strict=False) + videos, video_metadata = list(videos), list(video_metadata) + model_inputs = self._processor( + text=[raw_prompt], + images=image_data, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + return normalize_token_ids(model_inputs["input_ids"]) + + return normalize_token_ids( + _apply_chat_template( + self._tokenizer, messages, tools=tools, add_generation_prompt=True, + **self._apply_chat_template_kwargs, + ) + ) + # TODO: check if delta tokenization is better than remove_system_prompt + def _encode_incremental( + self, + messages: list[dict[str, Any]], + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + ) -> list[int]: + """Encode incremental messages (tool results, user follow-ups) for a continuation turn. + + Uses the remove_system_prompt pattern from ToolAgentLoop: encode the new messages + alone (which prepends a system prompt), then strip the known system_prompt prefix. + No tools parameter — tool schema is already in the initial prompt_ids. + """ + if self._processor is not None: + raw_prompt = _apply_chat_template( + self._processor, + messages, + add_generation_prompt=True, + tokenize=False, + **self._apply_chat_template_kwargs, + ) + videos = video_data + video_metadata = None + if videos is not None: + videos, video_metadata = zip(*videos, strict=False) + videos, video_metadata = list(videos), list(video_metadata) + model_inputs = self._processor( + text=[raw_prompt], + images=image_data, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + ids = normalize_token_ids(model_inputs["input_ids"]) + else: + ids = normalize_token_ids( + _apply_chat_template( + self._tokenizer, messages, add_generation_prompt=True, + **self._apply_chat_template_kwargs, + ) + ) + return ids[len(self._system_prompt):] + + async def _decode_response( + self, response_ids: list[int], *, tools: list[dict[str, Any]] | None = None, + stop_reason: str | None = None, + ) -> tuple[dict[str, Any], str]: + """Decode model output tokens into an OpenAI-compatible assistant message. + + Returns: + message: OpenAI-compatible assistant message. + finish_reason: "tool_calls" when tool calls are present, else stop_reason or "stop". + """ + if self._tool_parser is not None and tools: + content, function_calls = await self._tool_parser.extract_tool_calls(response_ids) + if function_calls: + tool_calls = [ + { + "id": f"call_{uuid4().hex[:8]}", + "type": "function", + "function": {"name": fc.name, "arguments": fc.arguments}, + } + for fc in function_calls + ] + message = { + "role": "assistant", + # Use "" instead of None so that prefix comparison with + # _normalize_message_content (which also coerces None → "") + # stays consistent. Both must agree on the None policy. + "content": content or "", + "tool_calls": tool_calls, + } + return message, "tool_calls" + response_text = self._tokenizer.decode(response_ids, skip_special_tokens=True) + return {"role": "assistant", "content": response_text}, stop_reason or "stop" + + async def _handle_chat_completions(self, session_id: str, payload: dict[str, Any]) -> JSONResponse: + session = self._sessions.get(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Unknown session_id: {session_id}") + + try: + request_context = _normalize_request_context(payload) + except MalformedRequestError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + async with session.generation_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException(status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}") + + async with session.request_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException( + status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}" + ) + + self._touch_session(session) + messages = request_context["messages"] + tools = request_context["tools"] + materialized_trajectory = None + image_data = None + video_data = None + + if session.active_trajectory is None: + image_data, video_data = await self._extract_multi_modal_data(messages) + active_trajectory = TrajectoryBuffer( + prompt_ids=self._encode_full( + messages, tools=tools, image_data=image_data, video_data=video_data + ) + ) + elif _is_request_context_prefix(session=session, messages=messages, tools=tools): + active_trajectory = _copy_trajectory_buffer(session.active_trajectory) + image_data = list(session.image_data) if session.image_data is not None else None + video_data = list(session.video_data) if session.video_data is not None else None + incremental_messages = messages[len(session.message_history) :] + if incremental_messages: + new_image_data, new_video_data = await self._extract_multi_modal_data(incremental_messages) + if new_image_data: + if image_data is None: + image_data = [] + image_data.extend(new_image_data) + if new_video_data: + if video_data is None: + video_data = [] + video_data.extend(new_video_data) + incremental_ids = self._encode_incremental( + incremental_messages, + image_data=new_image_data, + video_data=new_video_data, + ) + active_trajectory.response_ids.extend(incremental_ids) + active_trajectory.response_mask.extend([0] * len(incremental_ids)) + if active_trajectory.response_logprobs: + active_trajectory.response_logprobs.extend([0.0] * len(incremental_ids)) + else: + materialized_trajectory = self._build_materialized_trajectory( + session=session, + active=session.active_trajectory, + ) + image_data, video_data = await self._extract_multi_modal_data(messages) + active_trajectory = TrajectoryBuffer( + prompt_ids=self._encode_full( + messages, tools=tools, image_data=image_data, video_data=video_data + ) + ) + + generation_context_ids = active_trajectory.prompt_ids + active_trajectory.response_ids + sampling_params = _build_sampling_params( + payload, + base_sampling_params=self._base_sampling_params, + allowed_request_sampling_param_keys=self._allowed_request_sampling_param_keys, + ) + + output = await self._backend.generate( + request_id=session_id, + prompt_ids=generation_context_ids, + sampling_params=sampling_params, + image_data=image_data, + video_data=video_data, + ) + + response_ids = list(output.token_ids) + active_trajectory.response_ids.extend(response_ids) + active_trajectory.response_mask.extend([1] * len(response_ids)) + if output.log_probs is not None: + active_trajectory.response_logprobs.extend(list(output.log_probs)) + + assistant_msg, finish_reason = await self._decode_response( + response_ids, tools=tools, stop_reason=output.stop_reason, + ) + async with session.request_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException( + status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}" + ) + + if materialized_trajectory is not None: + session.trajectories.append(materialized_trajectory) + session.active_trajectory = active_trajectory + session.image_data = list(image_data) if image_data is not None else None + session.video_data = list(video_data) if video_data is not None else None + session.message_history = messages + [assistant_msg] + session.request_tools = tools + self._touch_session(session) + + return JSONResponse( + { + "id": f"chatcmpl-{uuid4().hex}", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": assistant_msg, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": len(generation_context_ids), + "completion_tokens": len(response_ids), + "total_tokens": len(generation_context_ids) + len(response_ids), + }, + } + ) + + async def start(self) -> None: + if self._server_task is not None: + return + self._server_port, self._server_task = await run_uvicorn(self._app, None, self._server_address) + self._server_base_url = f"http://{self._server_address}:{self._server_port}" + + async def shutdown(self) -> None: + if self._server_task is None: + return + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + self._server_task = None + self._server_port = None + self._server_base_url = None + + async def create_session(self, session_id: str, metadata: dict[str, Any] | None = None) -> SessionHandle: + self._require_started() + if session_id in self._sessions: + raise RuntimeError(f"Session {session_id} already exists") + + handle = SessionHandle( + session_id=session_id, + base_url=f"{self._server_base_url}/sessions/{session_id}/v1", + ) + self._sessions[session_id] = GatewaySessionState(handle=handle, metadata=dict(metadata or {})) + return handle + + async def complete_session(self, session_id: str, reward_info: dict[str, Any] | None = None) -> None: + session = self._get_session(session_id) + async with session.request_lock: + # Accommodate retry attempts + if session.phase not in {SessionPhase.COMPLETED, SessionPhase.ACTIVE}: + raise RuntimeError(f"Session {session_id} is {session.phase.value.lower()}") + + if reward_info is not None: + session.reward_info = dict(reward_info) + + self._set_phase(session, SessionPhase.COMPLETED) + session.completed.set() + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + session = self._sessions.get(session_id) + if session is None: + # Already finalized or aborted by a concurrent caller — nothing to wait for. + return + if session.phase == SessionPhase.COMPLETED: + # Fast path: agent already called /complete, no need to wait. + return + + await asyncio.wait_for(session.completed.wait(), timeout=timeout) + + # Post-await: the session may have been aborted during the wait. + # The local reference is still valid even if the session was removed from _sessions. + if session.phase == SessionPhase.ABORTED: + raise RuntimeError(f"Session {session_id} is aborted") + + async def finalize_session(self, session_id: str) -> list[Trajectory]: + session = self._get_session(session_id) + async with session.request_lock: + if session.phase == SessionPhase.ABORTED: + raise RuntimeError(f"Session {session_id} is aborted") + if session.phase == SessionPhase.FINALIZED: + raise RuntimeError(f"Session {session_id} is finalized") + + self._touch_session(session) + self._materialize_active_trajectory(session) + self._set_phase(session, SessionPhase.FINALIZED) + session.completed.set() + trajectories = [replace(trajectory, reward_info=dict(session.reward_info)) for trajectory in session.trajectories] + self._sessions.pop(session_id, None) + return trajectories + + async def abort_session(self, session_id: str) -> None: + session = self._sessions.get(session_id) + if session is None: + return # Already finalized or aborted — treat as idempotent. + async with session.request_lock: + if session.phase == SessionPhase.ABORTED: + return # Concurrent abort — idempotent. + if session.phase == SessionPhase.FINALIZED: + raise RuntimeError(f"Session {session_id} is finalized") + + self._set_phase(session, SessionPhase.ABORTED) + session.completed.set() + self._sessions.pop(session_id, None) + + async def get_session_state(self, session_id: str) -> dict[str, Any]: + session = self._get_session(session_id) + return { + "session_id": session.handle.session_id, + "metadata": dict(session.metadata), + "phase": session.phase.value, + "created_at": session.created_at, + "updated_at": session.updated_at, + "num_trajectories": len(session.trajectories), + "has_active_trajectory": session.active_trajectory is not None, + } + + +GatewayActor = ray.remote(_GatewayActor) diff --git a/uni_agent/trainer/gateway/manager.py b/uni_agent/trainer/gateway/manager.py new file mode 100644 index 0000000..17f038c --- /dev/null +++ b/uni_agent/trainer/gateway/manager.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import asyncio + + +async def _await_object_ref(object_ref): + return await asyncio.wrap_future(object_ref.future()) + + +class GatewayManager: + """Session-routing component owned by the serving runtime.""" + + def __init__(self, gateways: list): + self.gateways = gateways + self.gateway_count = len(gateways) + self.active_sessions_per_gateway = [0 for _ in gateways] + self._session_to_gateway_index: dict[str, int] = {} + + def _select_gateway_index(self) -> int: + if not self.gateways: + raise RuntimeError("No gateway actors configured") + return min(range(len(self.gateways)), key=lambda index: self.active_sessions_per_gateway[index]) + + def _get_gateway_index(self, session_id: str) -> int: + gateway_index = self._session_to_gateway_index.get(session_id) + if gateway_index is None: + raise KeyError(session_id) + return gateway_index + + def _get_gateway(self, session_id: str): + gateway_index = self._get_gateway_index(session_id) + return self.gateways[gateway_index], gateway_index + + async def create_session(self, session_id: str, **kwargs): + gateway_index = self._select_gateway_index() + gateway = self.gateways[gateway_index] + handle = await _await_object_ref(gateway.create_session.remote(session_id=session_id, **kwargs)) + self._session_to_gateway_index[session_id] = gateway_index + self.active_sessions_per_gateway[gateway_index] += 1 + return handle + + async def finalize_session(self, session_id: str): + gateway, gateway_index = self._get_gateway(session_id) + trajectories = await _await_object_ref(gateway.finalize_session.remote(session_id=session_id)) + self._session_to_gateway_index.pop(session_id, None) + self.active_sessions_per_gateway[gateway_index] -= 1 + return trajectories + + async def abort_session(self, session_id: str) -> None: + gateway, gateway_index = self._get_gateway(session_id) + await _await_object_ref(gateway.abort_session.remote(session_id=session_id)) + self._session_to_gateway_index.pop(session_id, None) + self.active_sessions_per_gateway[gateway_index] -= 1 + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + gateway, _ = self._get_gateway(session_id) + await _await_object_ref(gateway.wait_for_completion.remote(session_id=session_id, timeout=timeout)) diff --git a/uni_agent/trainer/gateway/runtime.py b/uni_agent/trainer/gateway/runtime.py new file mode 100644 index 0000000..058c207 --- /dev/null +++ b/uni_agent/trainer/gateway/runtime.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +import ray + +from verl.workers.rollout.llm_server import LLMServerClient + + +async def _await_ray_ref(object_ref): + return await asyncio.wrap_future(object_ref.future()) + + +class GatewayServingRuntime: + """Standalone serving runtime that owns gateway actors and delegates backend routing.""" + + def __init__( + self, + llm_client: LLMServerClient, + *, + gateway_manager=None, + gateway_count: int = 0, + gateway_actor_kwargs: dict[str, Any] | None = None, + ): + self._llm_client = llm_client + self.owned_gateway_actors: list[ray.actor.ActorHandle] = [] + self.gateway_manager = gateway_manager + + if self.gateway_manager is None and gateway_count > 0: + self._initialize_gateway_runtime( + gateway_count=gateway_count, + gateway_actor_kwargs=gateway_actor_kwargs, + ) + + def _initialize_gateway_runtime( + self, + *, + gateway_count: int, + gateway_actor_kwargs: dict[str, Any] | None = None, + ) -> None: + from verl.agent.gateway.gateway import GatewayActor + from verl.agent.gateway.manager import GatewayManager + + gateway_actor_kwargs = dict(gateway_actor_kwargs or {}) + if "backend" not in gateway_actor_kwargs: + gateway_actor_kwargs["backend"] = self + + self.owned_gateway_actors = [GatewayActor.remote(**gateway_actor_kwargs) for _ in range(gateway_count)] + ray.get([gateway.start.remote() for gateway in self.owned_gateway_actors]) + self.gateway_manager = GatewayManager(self.owned_gateway_actors) + + def _require_session_runtime(self): + if self.gateway_manager is None: + raise RuntimeError("Session runtime is disabled because gateway_count=0") + return self.gateway_manager + + async def create_session(self, session_id: str, **kwargs): + gateway_manager = self._require_session_runtime() + return await gateway_manager.create_session(session_id=session_id, **kwargs) + + async def finalize_session(self, session_id: str): + gateway_manager = self._require_session_runtime() + return await gateway_manager.finalize_session(session_id=session_id) + + async def abort_session(self, session_id: str) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.abort_session(session_id=session_id) + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.wait_for_completion(session_id=session_id, timeout=timeout) + + async def shutdown(self) -> None: + if self.owned_gateway_actors: + await asyncio.gather(*[_await_ray_ref(gateway.shutdown.remote()) for gateway in self.owned_gateway_actors]) + self.owned_gateway_actors = [] + self.gateway_manager = None + + async def generate( + self, + request_id, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + **kwargs: Any, + ) -> Any: + return await self._llm_client.generate( + request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + image_data=image_data, + video_data=video_data, + **kwargs, + ) diff --git a/uni_agent/trainer/gateway/types.py b/uni_agent/trainer/gateway/types.py new file mode 100644 index 0000000..b690ab7 --- /dev/null +++ b/uni_agent/trainer/gateway/types.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from verl.agent.framework.types import SessionHandle, Trajectory + + +class SessionPhase(str, Enum): + ACTIVE = "ACTIVE" + COMPLETED = "COMPLETED" + FINALIZED = "FINALIZED" + ABORTED = "ABORTED" + + +@dataclass +class TrajectoryBuffer: + prompt_ids: list[int] + response_ids: list[int] = field(default_factory=list) + response_mask: list[int] = field(default_factory=list) + response_logprobs: list[float] = field(default_factory=list) + + +@dataclass +class GatewaySessionState: + handle: SessionHandle + metadata: dict[str, Any] = field(default_factory=dict) + request_tools: list[dict[str, Any]] | None = None + message_history: list[dict[str, Any]] = field(default_factory=list) + image_data: list[Any] | None = None + video_data: list[Any] | None = None + active_trajectory: TrajectoryBuffer | None = None + trajectories: list[Trajectory] = field(default_factory=list) + reward_info: dict[str, Any] = field(default_factory=dict) + completed: asyncio.Event = field(default_factory=asyncio.Event) + phase: SessionPhase = SessionPhase.ACTIVE + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + request_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + generation_lock: asyncio.Lock = field(default_factory=asyncio.Lock) From 3b3adace9891f82120bb94dd0e2165f32c9b1a61 Mon Sep 17 00:00:00 2001 From: zackcxb Date: Thu, 21 May 2026 02:29:30 +0000 Subject: [PATCH 02/22] feat(trainer): sync framework+gateway from verl PR #6299 HEAD Overwrite stale init-commit code with latest verl gateway-framework-pr-source (4605deb4). Key changes since init: - Inline reward scoring (remove callable abstraction) - Gateway round-robin placement, finish_reason map, tool parse tolerance - Zero-fill rollout_log_probs/rm_scores for trainer compatibility - Session concurrency cap (max_concurrent_sessions) - Delete helpers.py (inlined into framework.py) - Entry.py now includes AgentFrameworkRolloutAdapter Import paths rewritten: verl.agent.{framework,gateway} -> uni_agent.trainer.{framework,gateway} External verl imports (verl.utils.*, verl.workers.*) kept as-is (accessed via submodule). Co-authored-by: Claude Opus 4.6 --- uni_agent/trainer/framework/__init__.py | 3 - uni_agent/trainer/framework/entry.py | 93 ++++++-- uni_agent/trainer/framework/framework.py | 283 +++++++++++++---------- uni_agent/trainer/framework/helpers.py | 71 ------ uni_agent/trainer/framework/types.py | 23 +- uni_agent/trainer/gateway/gateway.py | 58 ++++- uni_agent/trainer/gateway/manager.py | 4 + uni_agent/trainer/gateway/runtime.py | 28 ++- uni_agent/trainer/gateway/types.py | 2 +- 9 files changed, 307 insertions(+), 258 deletions(-) delete mode 100644 uni_agent/trainer/framework/helpers.py diff --git a/uni_agent/trainer/framework/__init__.py b/uni_agent/trainer/framework/__init__.py index 425d2f9..5b23f9d 100644 --- a/uni_agent/trainer/framework/__init__.py +++ b/uni_agent/trainer/framework/__init__.py @@ -1,5 +1,4 @@ from .framework import AgentFramework, OpenAICompatibleAgentFramework -from .helpers import normalize_trajectory_rewards, validate_trajectory from .types import SessionHandle, Trajectory __all__ = [ @@ -7,6 +6,4 @@ "OpenAICompatibleAgentFramework", "SessionHandle", "Trajectory", - "normalize_trajectory_rewards", - "validate_trajectory", ] diff --git a/uni_agent/trainer/framework/entry.py b/uni_agent/trainer/framework/entry.py index 883d941..7872041 100644 --- a/uni_agent/trainer/framework/entry.py +++ b/uni_agent/trainer/framework/entry.py @@ -1,47 +1,56 @@ -"""Factory entry for session runtime construction and framework FQN dispatch. +"""Factory entry + trainer-facing adapter for the agent framework stack. -entry owns gateway-universal wiring so framework subclasses only handle their -own agent runner, reward bridge, and framework-specific config fields. -Phase A: recipe adapter calls this. Phase B: main_ppo_sync.py calls it directly. +`build_agent_framework` owns gateway-universal wiring so framework subclasses +only handle their own agent runner, reward dispatch, and framework-specific +config fields. + +`AgentFrameworkRolloutAdapter` satisfies the trainer's +`agent_loop_manager_class` extension-point contract; recipes wire it in via +yaml without authoring per-recipe glue: + + actor_rollout_ref.rollout.agent.agent_loop_manager_class: + uni_agent.trainer.framework.entry.AgentFrameworkRolloutAdapter """ from __future__ import annotations from omegaconf import OmegaConf -from verl.agent.framework.framework import AgentFramework, OpenAICompatibleAgentFramework -from verl.agent.gateway.runtime import GatewayServingRuntime +from uni_agent.trainer.framework.framework import AgentFramework +from uni_agent.trainer.gateway.runtime import GatewayServingRuntime +from verl.utils.config import omega_conf_to_dataclass from verl.utils.import_utils import load_class_from_fqn +from verl.utils.ray_utils import auto_await +from verl.workers.config.model import HFModelConfig -_DEFAULT_FRAMEWORK_CLASS = f"{OpenAICompatibleAgentFramework.__module__}.OpenAICompatibleAgentFramework" -_DEFAULT_GATEWAY_COUNT = 0 -_DEFAULT_TOOL_PARSER = "hermes" +_DEFAULT_FRAMEWORK_CLASS = "uni_agent.trainer.framework.framework.OpenAICompatibleAgentFramework" async def build_agent_framework( *, config, llm_client, - tokenizer, - processor=None, replay_buffer, + reward_loop_worker_handles=None, ) -> AgentFramework: """Build GatewayServingRuntime, then delegate subclass-specific wiring.""" # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} + # Match AgentLoopWorker pattern: self-load tokenizer/processor via HFModelConfig. + model_config: HFModelConfig = omega_conf_to_dataclass(config.actor_rollout_ref.model) + gateway_actor_kwargs = { - "tokenizer": tokenizer, - "processor": processor, - "tool_parser_name": config.actor_rollout_ref.rollout.get("multi_turn", {}).get("format") - or _DEFAULT_TOOL_PARSER, + "tokenizer": model_config.tokenizer, + "processor": model_config.processor, } - if "host" in af_cfg and af_cfg["host"] is not None: - gateway_actor_kwargs["host"] = af_cfg["host"] + tool_parser_name = config.actor_rollout_ref.rollout.get("multi_turn", {}).get("format") + if tool_parser_name is not None: + gateway_actor_kwargs["tool_parser_name"] = tool_parser_name session_runtime = GatewayServingRuntime( llm_client=llm_client, - gateway_count=int(af_cfg.get("gateway_count", _DEFAULT_GATEWAY_COUNT)), + gateway_count=int(af_cfg["gateway_count"]), gateway_actor_kwargs=gateway_actor_kwargs, ) @@ -49,7 +58,51 @@ async def build_agent_framework( return await framework_cls.from_config( config=config, session_runtime=session_runtime, - tokenizer=tokenizer, - processor=processor, + processor=model_config.processor, replay_buffer=replay_buffer, + reward_loop_worker_handles=reward_loop_worker_handles, ) + + +class AgentFrameworkRolloutAdapter: + """Trainer-facing adapter satisfying the `agent_loop_manager_class` contract. + + Holds zero recipe-specific logic; every agent-framework recipe wires the + same class in yaml. Phase B will let `main_ppo_sync.py` call + `build_agent_framework` directly and this adapter can retire. + """ + + def __init__(self) -> None: + self.framework = None + + @classmethod + @auto_await + async def create( + cls, + *, + config, + llm_client, + teacher_client=None, + reward_loop_worker_handles=None, + replay_buffer=None, + **_, + ) -> "AgentFrameworkRolloutAdapter": + del teacher_client + assert replay_buffer is not None, "AgentFrameworkRolloutAdapter requires replay_buffer" + + framework = await build_agent_framework( + config=config, + llm_client=llm_client, + replay_buffer=replay_buffer, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + instance = cls() + instance.framework = framework + return instance + + @auto_await + async def generate_sequences(self, prompts) -> None: + if self.framework is None: + raise RuntimeError("framework must be initialized before generate_sequences") + return await self.framework.generate_sequences(prompts) diff --git a/uni_agent/trainer/framework/framework.py b/uni_agent/trainer/framework/framework.py index 2ebcdbc..f672f75 100644 --- a/uni_agent/trainer/framework/framework.py +++ b/uni_agent/trainer/framework/framework.py @@ -1,8 +1,8 @@ from __future__ import annotations import asyncio -import inspect import logging +import random from abc import ABC, abstractmethod from dataclasses import replace from functools import partial @@ -14,14 +14,13 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack from verl.tools.utils.tool_registry import initialize_tools_from_config -from verl.trainer.ppo.reward import get_custom_reward_fn from verl.utils.import_utils import load_class_from_fqn from verl.utils.transferqueue_utils import tq from verl.utils import tensordict_utils as tu from verl.utils.model import compute_position_id_with_mask from .multi_modal_postprocess import compute_multi_modal_inputs, compute_position_ids -from .types import RewardFn, SessionRewardContext, SessionRuntime, Trajectory +from .types import SessionRuntime, Trajectory logger = logging.getLogger(__name__) @@ -43,9 +42,9 @@ async def from_config( *, config, session_runtime, - tokenizer=None, processor=None, replay_buffer, + reward_loop_worker_handles=None, ) -> "AgentFramework": ... @@ -55,14 +54,6 @@ async def generate_sequences(self, prompts: TensorDict) -> None: ... -def _to_long_tensor(values) -> torch.Tensor: - return torch.tensor(list(values), dtype=torch.long) - - -def _to_float_tensor(values) -> torch.Tensor: - return torch.tensor(list(values), dtype=torch.float32) - - def _short_failure_reason(error: BaseException) -> str: message = str(error) if not message: @@ -98,35 +89,40 @@ def _list_of_tq_fields_to_tensordict(fields: list[dict[str, object]]) -> TensorD return td -def _build_reward_fn(config, tokenizer): - # Phase A keeps reward configuration in the existing VERL path and only - # bridges framework trajectories to the raw custom_reward_function - # signature. Phase B should reuse the main reward manager path directly - # instead of growing a parallel agent-framework reward config surface. - custom_reward_fn = get_custom_reward_fn(config) - if custom_reward_fn is None: - return None +def _trajectory_to_reward_dataproto(trajectory, sample_fields): + """Build a single-sample DataProto for RewardLoopWorker.compute_score. - async def reward_fn(ctx): - data_source = ctx.sample_fields.get("data_source") - reward_model = ctx.sample_fields.get("reward_model") - if isinstance(reward_model, dict): - ground_truth = reward_model.get("ground_truth") - elif reward_model is None: - ground_truth = None - else: - ground_truth = getattr(reward_model, "ground_truth", None) - extra_info = ctx.sample_fields.get("extra_info") - scores = [] - for trajectory in ctx.trajectories: - response_text = tokenizer.decode(trajectory.response_ids, skip_special_tokens=True) - score = custom_reward_fn(data_source, response_text, ground_truth, extra_info) - if inspect.isawaitable(score): - score = await score - scores.append(score) - return scores + Field shape matches AgentLoopWorker._compute_score + (verl/experimental/agent_loop/agent_loop.py:753-772). Only fields actually + consumed by NaiveRewardManager.run_single / RewardLoopWorker dispatch are + populated; tool_extra_fields / num_turns are passed via non_tensor_batch + for parity. + """ + import numpy as np + from verl.protocol import DataProto + + prompt_ids = torch.tensor(trajectory.prompt_ids, dtype=torch.long).unsqueeze(0) + response_ids = torch.tensor(trajectory.response_ids, dtype=torch.long).unsqueeze(0) + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + batch = TensorDict( + { + "prompts": prompt_ids, + "responses": response_ids, + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + batch_size=1, + ) - return reward_fn + non_tensor_batch: dict[str, object] = {} + for key in ("raw_prompt", "data_source", "reward_model", "extra_info", "tools_kwargs", "agent_name"): + if key in sample_fields: + non_tensor_batch[key] = np.array([sample_fields[key]], dtype=object) + non_tensor_batch["__num_turns__"] = np.array([trajectory.num_turns]) + + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) class OpenAICompatibleAgentFramework(AgentFramework): @@ -135,25 +131,29 @@ class OpenAICompatibleAgentFramework(AgentFramework): Each sample in the batch is run as an independent session: the agent communicates with the Gateway via standard ``/v1/chat/completions`` requests, and the Gateway collects token-level trajectories. After - finalization, ``reward_fn`` scores the session's trajectories and the - framework writes them to the TransferQueue schema consumed by sync training. + finalization, ``_score_trajectories`` dispatches the session's final + trajectory to a RewardLoopWorker and broadcasts the score back to all + trajectories in the session (matching + ``AgentLoopWorkerTQ._agent_loop_postprocess``); the framework then writes + them to the TransferQueue schema consumed by sync training. """ def __init__( self, session_runtime: SessionRuntime, agent_runner, - reward_fn: RewardFn | None, *, + reward_loop_worker_handles=None, processor=None, replay_buffer=None, rollout_config=None, completion_timeout: float | None = 30.0, wait_for_completion_after_agent_run: bool = False, + max_concurrent_sessions: int = 0, ): self.session_runtime = session_runtime self.agent_runner = agent_runner - self.reward_fn = reward_fn + self.reward_loop_worker_handles = list(reward_loop_worker_handles) if reward_loop_worker_handles else None self._processor = processor # TODO(phase-b): once trainer constructs framework directly, these become # constructor-required and no transitional dual-path is needed. @@ -161,6 +161,9 @@ def __init__( self._rollout_config = rollout_config self.completion_timeout = completion_timeout self.wait_for_completion_after_agent_run = wait_for_completion_after_agent_run + self._max_concurrent_sessions = max_concurrent_sessions + self._semaphore: asyncio.Semaphore | None = None + self._semaphore_loop: asyncio.AbstractEventLoop | None = None @classmethod async def from_config( @@ -168,13 +171,10 @@ async def from_config( *, config, session_runtime, - tokenizer=None, processor=None, replay_buffer, + reward_loop_worker_handles=None, ) -> "OpenAICompatibleAgentFramework": - if tokenizer is None: - raise ValueError("OpenAICompatibleAgentFramework requires tokenizer for reward bridge") - # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} agent_runner_fqn = af_cfg.get("agent_runner_fqn") @@ -194,22 +194,17 @@ async def from_config( if runner_kwargs: agent_runner = partial(agent_runner, **runner_kwargs) - # TODO(phase-x): when reward_loop_worker_handles is available from - # trainer, accept reward_fn as an entry-injected resource and skip - # bridge construction. Bridge remains available for simple recipes - # that supply reward.custom_reward_function directly. - reward_fn = _build_reward_fn(config, tokenizer) - completion_timeout = af_cfg.get("completion_timeout_seconds") return cls( session_runtime=session_runtime, agent_runner=agent_runner, - reward_fn=reward_fn, + reward_loop_worker_handles=reward_loop_worker_handles, processor=processor, replay_buffer=replay_buffer, rollout_config=config.actor_rollout_ref.rollout, completion_timeout=completion_timeout, wait_for_completion_after_agent_run=completion_timeout is not None, + max_concurrent_sessions=int(af_cfg.get("max_concurrent_sessions", 0)), ) async def generate_sequences(self, prompts: TensorDict) -> None: @@ -224,7 +219,11 @@ async def generate_sequences(self, prompts: TensorDict) -> None: raise ValueError("OpenAICompatibleAgentFramework requires prompts['global_steps']") partition_id = "val" if "validate" in prompts.keys() else "train" - num_sessions = self._num_sessions_for_partition(partition_id) + if partition_id == "val": + val_kwargs = self._rollout_config.get("val_kwargs", {}) + num_sessions = int(val_kwargs.get("n")) + else: + num_sessions = int(self._rollout_config.get("n")) uids = tu.get(prompts, "uid") if uids is None: @@ -235,7 +234,7 @@ async def generate_sequences(self, prompts: TensorDict) -> None: {str(uid): {"global_steps": global_steps, "status": "running"} for uid in uid_values}, ) - stats = await self._generate_to_tq( + stats = await self._run_batch_to_tq( prompts, global_steps=global_steps, partition_id=partition_id, @@ -258,13 +257,7 @@ async def generate_sequences(self, prompts: TensorDict) -> None: ) return None - def _num_sessions_for_partition(self, partition_id: str) -> int: - if partition_id == "val": - val_kwargs = self._rollout_config.get("val_kwargs", {}) - return int(val_kwargs.get("n")) - return int(self._rollout_config.get("n")) - - async def _generate_to_tq( + async def _run_batch_to_tq( self, prompts: TensorDict, *, @@ -272,13 +265,7 @@ async def _generate_to_tq( partition_id: str, num_sessions: int = 1, ) -> dict: - """Run agent sessions and write finalized trajectories to TransferQueue. - - This is the TransferQueue-oriented sibling of ``generate_sequences``. - It preserves the same session lifecycle, but writes each finalized - trajectory with the key/tag/field schema consumed by - ``verl.trainer.main_ppo_sync`` instead of returning a batch. - """ + """Run all prompts in a batch and aggregate prompt/session stats.""" assert len(prompts) > 0, "generate_sequences requires a non-empty batch" if num_sessions <= 0: raise ValueError(f"num_sessions must be positive, got {num_sessions}") @@ -287,8 +274,10 @@ async def _generate_to_tq( if raw_prompts is None: raise ValueError("OpenAICompatibleAgentFramework requires prompts['raw_prompt']") + # Batch layer: each sample/prompt owns its own group of rollout.n sessions. + # Prompt tasks are isolated so one prompt failure does not drop the whole batch. tasks = [ - self._run_prompt_to_replay_buffer( + self._run_prompt_sessions_to_tq( prompts=prompts, raw_prompt=raw_prompts[sample_index], sample_index=sample_index, @@ -322,7 +311,7 @@ async def _generate_to_tq( failure_reasons.extend(outcome["failure_reasons"]) return stats - async def _run_prompt_to_replay_buffer( + async def _run_prompt_sessions_to_tq( self, *, prompts: TensorDict, @@ -338,19 +327,19 @@ async def _run_prompt_to_replay_buffer( raise ValueError("OpenAICompatibleAgentFramework requires prompts['uid'] for TransferQueue output") uid = str(uid) + # Prompt layer: rollout.n sessions race independently for the same uid. + # Successful sessions are written to TQ; failed sessions only affect this uid's stats. tasks = [ - self._run_session( + self._run_session_with_concurrency_limit( prompts=prompts, raw_prompt=raw_prompt, sample_index=sample_index, - session_id=self._build_session_id( - prompts=prompts, - sample_index=sample_index, - session_index=session_index, - ), - runner_kwargs=( - {"tools_kwargs": sample_fields["tools_kwargs"]} if "tools_kwargs" in sample_fields else {} - ), + session_id=f"session-{sample_index}-{session_index}-{uuid4().hex}", + runner_kwargs={ + key: sample_fields[key] + for key in ("tools_kwargs", "agent_name") + if key in sample_fields + }, ) for session_index in range(num_sessions) ] @@ -369,13 +358,13 @@ async def _run_prompt_to_replay_buffer( trajectories, session_sample_fields = outcome if not trajectories: failed_sessions += 1 - failure_reasons.append(f"empty trajectories for uid={uid} session_id={session_index}") + failure_reasons.append(f"empty trajectories for uid={uid} session_index={session_index}") continue success_sessions += 1 await self._write_session_trajectories_to_tq( uid=uid, - session_id=session_index, + session_index=session_index, trajectories=trajectories, sample_fields=session_sample_fields, global_steps=global_steps, @@ -398,6 +387,39 @@ async def _run_prompt_to_replay_buffer( "failure_reasons": failure_reasons, } + async def _run_session_with_concurrency_limit( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + session_id: str | None = None, + runner_kwargs: dict[str, object] | None = None, + ) -> tuple[list[Trajectory], dict[str, object]]: + if self._max_concurrent_sessions <= 0: + return await self._run_session( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=session_id, + runner_kwargs=runner_kwargs, + ) + # Lazy-init Semaphore on first use and rebind if the running loop + # changed: asyncio.Semaphore binds to the loop at construction, but + # Ray actors may run sessions on a different loop than __init__. + loop = asyncio.get_running_loop() + if self._semaphore is None or self._semaphore_loop is not loop: + self._semaphore = asyncio.Semaphore(self._max_concurrent_sessions) + self._semaphore_loop = loop + async with self._semaphore: + return await self._run_session( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=session_id, + runner_kwargs=runner_kwargs, + ) + async def _run_session( self, *, @@ -407,7 +429,8 @@ async def _run_session( session_id: str | None = None, runner_kwargs: dict[str, object] | None = None, ) -> tuple[list[Trajectory], dict[str, object]]: - session_id = session_id or self._build_session_id(prompts=prompts, sample_index=sample_index) + """Run one gateway session lifecycle and return finalized trajectories.""" + session_id = session_id or f"session-{sample_index}-0-{uuid4().hex}" sample_fields = self._extract_sample_fields(prompts=prompts, sample_index=sample_index) session = await self.session_runtime.create_session(session_id) try: @@ -426,41 +449,50 @@ async def _run_session( # Score the session's trajectories immediately after finalization, # consistent with VERL's per-sample reward path. - if self.reward_fn is None: + if not self.reward_loop_worker_handles or not session_trajectories: return session_trajectories, sample_fields - normalized_scores = await self._score_trajectories(session_trajectories, sample_fields) - return ( - [ - replace(traj, reward_score=score) - for traj, score in zip(session_trajectories, normalized_scores, strict=True) - ], - sample_fields, - ) + annotations = await self._score_trajectories(session_trajectories, sample_fields) + scored_trajectories = [] + for traj, (score, extra) in zip(session_trajectories, annotations, strict=True): + scored_trajectories.append( + replace( + traj, + reward_score=score, + extra_fields={**traj.extra_fields, "reward_extra_info": extra}, + ) + ) + return scored_trajectories, sample_fields async def _score_trajectories( self, session_trajectories: list[Trajectory], sample_fields: dict[str, object], - ) -> list[float]: - assert self.reward_fn is not None - ctx = SessionRewardContext(trajectories=session_trajectories, sample_fields=sample_fields) - scores = self.reward_fn(ctx) - if inspect.isawaitable(scores): - scores = await scores - if len(scores) != len(session_trajectories): + ) -> list[tuple[float, dict[str, object]]]: + """Score the session's final trajectory and broadcast (score, extra_info) to all. + + Mirrors AgentLoopWorkerTQ._agent_loop_postprocess + (verl/trainer/main_ppo_sync.py:353-396): only the final trajectory (the + session's last interaction segment) is dispatched to RewardLoopWorker; + its score + reward_extra_info are then broadcast to every trajectory in + the session. Subclasses can override this method to implement custom + session-to-trajectory scoring policies. + """ + assert self.reward_loop_worker_handles is not None + assert session_trajectories, "expected non-empty session_trajectories" + + final_trajectory = session_trajectories[-1] + data = _trajectory_to_reward_dataproto(final_trajectory, sample_fields) + worker = random.choice(self.reward_loop_worker_handles) + result = await worker.compute_score.remote(data) + + if "reward_score" not in result: raise ValueError( - f"reward_fn returned {len(scores)} scores for {len(session_trajectories)} trajectories" + f"RewardLoopWorker result missing 'reward_score' key for uid={sample_fields.get('uid')}" ) - normalized_scores: list[float] = [] - for _, score in zip(session_trajectories, scores, strict=True): - if score is None: - raise ValueError( - "reward_fn must return a score for every trajectory; " - f"got None for uid={sample_fields.get('uid')}" - ) - normalized_scores.append(float(score)) - return normalized_scores + score = float(result["reward_score"]) + extra = dict(result.get("reward_extra_info") or {}) + return [(score, extra)] * len(session_trajectories) def _extract_sample_fields(self, *, prompts: TensorDict, sample_index: int) -> dict[str, object]: sample_fields = {} @@ -478,7 +510,7 @@ async def _write_session_trajectories_to_tq( self, *, uid: str, - session_id: int, + session_index: int, trajectories: list[Trajectory], sample_fields: dict[str, object], global_steps: int, @@ -491,10 +523,10 @@ async def _write_session_trajectories_to_tq( field, tag = self._trajectory_to_tq_field_and_tag( trajectory=trajectory, sample_fields=sample_fields, - session_id=session_id, + session_index=session_index, global_steps=global_steps, ) - keys.append(f"{uid}_{session_id}_{index}") + keys.append(f"{uid}_{session_index}_{index}") fields.append(field) tags.append(tag) @@ -510,12 +542,12 @@ def _trajectory_to_tq_field_and_tag( *, trajectory: Trajectory, sample_fields: dict[str, object], - session_id: int, + session_index: int, global_steps: int, ) -> tuple[dict[str, object], dict[str, object]]: - prompts = _to_long_tensor(trajectory.prompt_ids) - responses = _to_long_tensor(trajectory.response_ids) - response_mask = _to_long_tensor(trajectory.response_mask) + prompts = torch.tensor(trajectory.prompt_ids, dtype=torch.long) + responses = torch.tensor(trajectory.response_ids, dtype=torch.long) + response_mask = torch.tensor(trajectory.response_mask, dtype=torch.long) input_ids = torch.cat([prompts, responses], dim=0) attention_mask = torch.ones_like(input_ids, dtype=torch.long) multi_modal_inputs = compute_multi_modal_inputs( @@ -544,25 +576,26 @@ def _trajectory_to_tq_field_and_tag( "multi_modal_inputs": multi_modal_inputs, } if trajectory.response_logprobs is not None: - field["rollout_log_probs"] = _to_float_tensor(trajectory.response_logprobs) + field["rollout_log_probs"] = torch.tensor(trajectory.response_logprobs, dtype=torch.float32) + else: + field["rollout_log_probs"] = torch.zeros_like(responses, dtype=torch.float32) if trajectory.routed_experts is not None: field["routed_experts"] = ( torch.from_numpy(trajectory.routed_experts.copy()) if hasattr(trajectory.routed_experts, "copy") and not isinstance(trajectory.routed_experts, torch.Tensor) else trajectory.routed_experts ) - if trajectory.reward_score is not None: - rm_scores = torch.zeros_like(responses, dtype=torch.float32) - if responses.numel() > 0: - rm_scores[-1] = float(trajectory.reward_score) - field["rm_scores"] = rm_scores + rm_scores = torch.zeros_like(responses, dtype=torch.float32) + if trajectory.reward_score is not None and responses.numel() > 0: + rm_scores[-1] = float(trajectory.reward_score) + field["rm_scores"] = rm_scores field.update(trajectory.extra_fields) field.pop("multi_modal_data", None) for key in ("uid", "raw_prompt", "data_source", "reward_model", "extra_info", "tools_kwargs", "agent_name"): if key in sample_fields: field[key] = sample_fields[key] - field["session_id"] = session_id + field["session_id"] = session_index field["global_steps"] = global_steps field["num_turns"] = torch.tensor(int(trajectory.num_turns), dtype=torch.long) @@ -577,5 +610,3 @@ def _trajectory_to_tq_field_and_tag( } return field, tag - def _build_session_id(self, prompts: TensorDict, sample_index: int, session_index: int = 0) -> str: - return f"session-{sample_index}-{session_index}-{uuid4().hex}" diff --git a/uni_agent/trainer/framework/helpers.py b/uni_agent/trainer/framework/helpers.py deleted file mode 100644 index 29e0160..0000000 --- a/uni_agent/trainer/framework/helpers.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from dataclasses import replace -from typing import Any - -import numpy as np -import torch - -from .types import Trajectory - - -def _resolve_trajectory_value(value: Any, index: int, count: int) -> Any: - if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray, list)): - # Keep tuple-like values broadcastable only when they are not trajectory-aligned containers. - return value - if isinstance(value, list): - if len(value) != count: - raise ValueError(f"reward_info sequence length must match trajectories: {len(value)} != {count}") - return value[index] - return value - - -def normalize_trajectory_rewards( - trajectories: Sequence[Trajectory], - reward_info: Mapping[str, Any] | None = None, -) -> list[Trajectory]: - normalized: list[Trajectory] = [] - count = len(trajectories) - - for index, trajectory in enumerate(trajectories): - merged_reward_info = dict(trajectory.reward_info) - if reward_info is not None: - for key, value in reward_info.items(): - merged_reward_info[key] = _resolve_trajectory_value(value, index=index, count=count) - - if trajectory.reward_score is None: - raise ValueError( - f"Trajectory at index {index} has no reward_score. " - "reward_fn must return a score for every trajectory." - ) - - normalized.append(replace(trajectory, reward_info=merged_reward_info)) - - return normalized - - -def validate_trajectory(trajectory: Trajectory) -> Trajectory: - if len(trajectory.response_ids) != len(trajectory.response_mask): - raise ValueError("response_mask length must match response_ids length") - - if trajectory.response_logprobs is not None and len(trajectory.response_logprobs) != len(trajectory.response_ids): - raise ValueError("response_logprobs length must match response_ids length") - - if trajectory.num_turns < 0: - raise ValueError("num_turns must be non-negative") - if trajectory.routed_experts is not None: - if isinstance(trajectory.routed_experts, np.ndarray): - routed_experts = trajectory.routed_experts - elif isinstance(trajectory.routed_experts, torch.Tensor): - routed_experts = trajectory.routed_experts - else: - raise TypeError(f"Unsupported routed_experts type: {type(trajectory.routed_experts)}") - - if routed_experts.ndim != 3: - raise ValueError("routed_experts must have shape [total_tokens, num_layers, topk]") - expected_length = len(trajectory.prompt_ids) + len(trajectory.response_ids) - if routed_experts.shape[0] != expected_length: - raise ValueError("routed_experts token dimension must match prompt_ids + response_ids") - - return trajectory diff --git a/uni_agent/trainer/framework/types.py b/uni_agent/trainer/framework/types.py index 9c5545b..2abf18a 100644 --- a/uni_agent/trainer/framework/types.py +++ b/uni_agent/trainer/framework/types.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any, Protocol @@ -28,27 +27,6 @@ class Trajectory: extra_fields: dict[str, Any] = field(default_factory=dict) -@dataclass -class SessionRewardContext: - """Context passed to ``reward_fn`` after a session is finalized. - - A single session may produce multiple trajectories (e.g. when the agent - switches conversation context mid-session). ``reward_fn`` receives all of - them together so the implementor can choose the session-to-trajectory - scoring policy, but it must return one score per trajectory. - - ``sample_fields`` carries per-sample dataset fields (``data_source``, - ``reward_model.ground_truth``, ``extra_info``, ...) — the same dict that - ``AgentLoopWorker._compute_score`` forwards as ``kwargs`` to the reward - worker. - """ - - trajectories: list[Trajectory] - sample_fields: dict[str, Any] = field(default_factory=dict) - -RewardFn = Callable[[SessionRewardContext], Awaitable[list[float]] | list[float]] - - class SessionRuntime(Protocol): """Protocol for gateway-backed session lifecycle. @@ -58,6 +36,7 @@ class SessionRuntime(Protocol): """ async def create_session(self, session_id: str, **kwargs) -> SessionHandle: ... + async def complete_session(self, session_id: str) -> None: ... async def finalize_session(self, session_id: str) -> list[Trajectory]: ... async def abort_session(self, session_id: str) -> None: ... async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: ... diff --git a/uni_agent/trainer/gateway/gateway.py b/uni_agent/trainer/gateway/gateway.py index 46009d5..a46a926 100644 --- a/uni_agent/trainer/gateway/gateway.py +++ b/uni_agent/trainer/gateway/gateway.py @@ -11,8 +11,8 @@ from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse -from verl.agent.framework.types import SessionHandle, Trajectory -from verl.agent.gateway.types import GatewaySessionState, SessionPhase, TrajectoryBuffer +from uni_agent.trainer.framework.types import SessionHandle, Trajectory +from uni_agent.trainer.gateway.types import GatewaySessionState, SessionPhase, TrajectoryBuffer from verl.experimental.agent_loop.tool_parser import ToolParser from verl.utils.chat_template import apply_chat_template as _apply_chat_template, initialize_system_prompt from verl.utils.tokenizer import normalize_token_ids @@ -31,6 +31,33 @@ class MalformedRequestError(ValueError): }) +# Map backend stop_reason values to OpenAI-spec finish_reason values. +# OpenAI Chat Completions spec defines finish_reason ∈ +# {"stop", "length", "tool_calls", "content_filter", "function_call"}. +# +# Note on vLLM information loss: the vLLM rollout adapter +# (verl/workers/rollout/vllm_rollout/vllm_async_server.py:538-545) +# collapses vLLM's raw finish_reason "stop" and "length" into a single +# "completed" stop_reason before the gateway sees it. As a result, +# mapping "completed" -> "stop" here cannot recover whether generation +# actually hit max_tokens; recovering that distinction requires the +# vLLM adapter to preserve the raw finish_reason on TokenOutput. +# TODO(phase-c): preserve raw backend finish_reason on TokenOutput +# (e.g. a new TokenOutput.finish_reason field or +# extra_fields["finish_reason"]) so the gateway can distinguish vLLM +# "length" from "stop" instead of mapping both to "stop". +_FINISH_REASON_MAP = { + "completed": "stop", + "stop": "stop", + "matched_stop": "stop", + "eos": "stop", + "length": "length", + "max_tokens": "length", + "aborted": "stop", + "abort": "stop", +} + + # TODO: double-check if all these validations/normalization are necessary # Make sure they don't alter messages in unexpected ways. def _normalize_message_content(content: Any) -> Any: @@ -84,15 +111,12 @@ def _normalize_message(message: Any) -> dict[str, Any]: return normalized -def _validate_tools(tools: Any) -> list[dict[str, Any]] | None: +def _validate_tools(tools: Any) -> list[Any] | None: """Validate tools structure. Does not modify content.""" if tools is None: return None if not isinstance(tools, list): raise MalformedRequestError("tools must be a list") - for tool in tools: - if not isinstance(tool, dict): - raise MalformedRequestError("tools entries must be objects") return tools @@ -223,7 +247,6 @@ def __init__( self, tokenizer, backend, - host: str | None = None, *, processor=None, vision_info_extractor=None, @@ -234,8 +257,8 @@ def __init__( allowed_request_sampling_param_keys: set[str] | frozenset[str] | None = None, ): # Same pattern as vllm_async_server.py / async_sglang_server.py: - # use the node's routable IP for both bind and URL by default. - self._server_address = host if host is not None else ray.util.get_node_ip_address() + # use the node's routable IP for both bind and URL. + self._server_address = ray.util.get_node_ip_address() self._tokenizer = tokenizer self._processor = processor self._backend = backend @@ -460,10 +483,20 @@ async def _decode_response( Returns: message: OpenAI-compatible assistant message. - finish_reason: "tool_calls" when tool calls are present, else stop_reason or "stop". + finish_reason: "tool_calls" when tool calls are present, else the + OpenAI-spec-normalized stop_reason (see _FINISH_REASON_MAP). """ if self._tool_parser is not None and tools: - content, function_calls = await self._tool_parser.extract_tool_calls(response_ids) + parsed_tools = None + try: + from verl.tools.schemas import OpenAIFunctionToolSchema + parsed_tools = [ + OpenAIFunctionToolSchema(**t) if isinstance(t, dict) else t + for t in tools + ] + except Exception: + pass + content, function_calls = await self._tool_parser.extract_tool_calls(response_ids, parsed_tools) if function_calls: tool_calls = [ { @@ -483,7 +516,8 @@ async def _decode_response( } return message, "tool_calls" response_text = self._tokenizer.decode(response_ids, skip_special_tokens=True) - return {"role": "assistant", "content": response_text}, stop_reason or "stop" + finish_reason = _FINISH_REASON_MAP.get(stop_reason, stop_reason) if stop_reason else "stop" + return {"role": "assistant", "content": response_text}, finish_reason async def _handle_chat_completions(self, session_id: str, payload: dict[str, Any]) -> JSONResponse: session = self._sessions.get(session_id) diff --git a/uni_agent/trainer/gateway/manager.py b/uni_agent/trainer/gateway/manager.py index 17f038c..c439c7e 100644 --- a/uni_agent/trainer/gateway/manager.py +++ b/uni_agent/trainer/gateway/manager.py @@ -46,6 +46,10 @@ async def finalize_session(self, session_id: str): self.active_sessions_per_gateway[gateway_index] -= 1 return trajectories + async def complete_session(self, session_id: str) -> None: + gateway, _ = self._get_gateway(session_id) + await _await_object_ref(gateway.complete_session.remote(session_id=session_id)) + async def abort_session(self, session_id: str) -> None: gateway, gateway_index = self._get_gateway(session_id) await _await_object_ref(gateway.abort_session.remote(session_id=session_id)) diff --git a/uni_agent/trainer/gateway/runtime.py b/uni_agent/trainer/gateway/runtime.py index 058c207..3010ce9 100644 --- a/uni_agent/trainer/gateway/runtime.py +++ b/uni_agent/trainer/gateway/runtime.py @@ -39,14 +39,32 @@ def _initialize_gateway_runtime( gateway_count: int, gateway_actor_kwargs: dict[str, Any] | None = None, ) -> None: - from verl.agent.gateway.gateway import GatewayActor - from verl.agent.gateway.manager import GatewayManager + from uni_agent.trainer.gateway.gateway import GatewayActor + from uni_agent.trainer.gateway.manager import GatewayManager gateway_actor_kwargs = dict(gateway_actor_kwargs or {}) if "backend" not in gateway_actor_kwargs: gateway_actor_kwargs["backend"] = self - self.owned_gateway_actors = [GatewayActor.remote(**gateway_actor_kwargs) for _ in range(gateway_count)] + # Round-robin across alive CPU nodes so gateway actors do not all pack onto + # the driver node under Ray's default PACK scheduling. Mirrors + # AgentLoopWorker placement (verl/experimental/agent_loop/agent_loop.py). + node_ids = [ + node["NodeID"] + for node in ray.nodes() + if node["Alive"] and node["Resources"].get("CPU", 0) > 0 + ] + if not node_ids: + raise RuntimeError("No alive CPU nodes available for GatewayActor placement") + + self.owned_gateway_actors = [ + GatewayActor.options( + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=node_ids[i % len(node_ids)], soft=True, + ), + ).remote(**gateway_actor_kwargs) + for i in range(gateway_count) + ] ray.get([gateway.start.remote() for gateway in self.owned_gateway_actors]) self.gateway_manager = GatewayManager(self.owned_gateway_actors) @@ -63,6 +81,10 @@ async def finalize_session(self, session_id: str): gateway_manager = self._require_session_runtime() return await gateway_manager.finalize_session(session_id=session_id) + async def complete_session(self, session_id: str) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.complete_session(session_id=session_id) + async def abort_session(self, session_id: str) -> None: gateway_manager = self._require_session_runtime() await gateway_manager.abort_session(session_id=session_id) diff --git a/uni_agent/trainer/gateway/types.py b/uni_agent/trainer/gateway/types.py index b690ab7..9a0f4c2 100644 --- a/uni_agent/trainer/gateway/types.py +++ b/uni_agent/trainer/gateway/types.py @@ -6,7 +6,7 @@ from enum import Enum from typing import Any -from verl.agent.framework.types import SessionHandle, Trajectory +from uni_agent.trainer.framework.types import SessionHandle, Trajectory class SessionPhase(str, Enum): From ef462658de8231dfd633a7fc472880de73c94eff Mon Sep 17 00:00:00 2001 From: zackcxb Date: Thu, 21 May 2026 07:55:52 +0000 Subject: [PATCH 03/22] complete session interface takes in reward_info --- uni_agent/trainer/framework/types.py | 2 +- uni_agent/trainer/gateway/manager.py | 5 +++-- uni_agent/trainer/gateway/runtime.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/uni_agent/trainer/framework/types.py b/uni_agent/trainer/framework/types.py index 2abf18a..be94c0c 100644 --- a/uni_agent/trainer/framework/types.py +++ b/uni_agent/trainer/framework/types.py @@ -36,7 +36,7 @@ class SessionRuntime(Protocol): """ async def create_session(self, session_id: str, **kwargs) -> SessionHandle: ... - async def complete_session(self, session_id: str) -> None: ... + async def complete_session(self, session_id: str, reward_info: dict[str, Any] | None = None) -> None: ... async def finalize_session(self, session_id: str) -> list[Trajectory]: ... async def abort_session(self, session_id: str) -> None: ... async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: ... diff --git a/uni_agent/trainer/gateway/manager.py b/uni_agent/trainer/gateway/manager.py index c439c7e..84669d2 100644 --- a/uni_agent/trainer/gateway/manager.py +++ b/uni_agent/trainer/gateway/manager.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from typing import Any async def _await_object_ref(object_ref): @@ -46,9 +47,9 @@ async def finalize_session(self, session_id: str): self.active_sessions_per_gateway[gateway_index] -= 1 return trajectories - async def complete_session(self, session_id: str) -> None: + async def complete_session(self, session_id: str, reward_info: dict[str, Any] | None = None) -> None: gateway, _ = self._get_gateway(session_id) - await _await_object_ref(gateway.complete_session.remote(session_id=session_id)) + await _await_object_ref(gateway.complete_session.remote(session_id=session_id, reward_info=reward_info)) async def abort_session(self, session_id: str) -> None: gateway, gateway_index = self._get_gateway(session_id) diff --git a/uni_agent/trainer/gateway/runtime.py b/uni_agent/trainer/gateway/runtime.py index 3010ce9..f4250d2 100644 --- a/uni_agent/trainer/gateway/runtime.py +++ b/uni_agent/trainer/gateway/runtime.py @@ -81,9 +81,9 @@ async def finalize_session(self, session_id: str): gateway_manager = self._require_session_runtime() return await gateway_manager.finalize_session(session_id=session_id) - async def complete_session(self, session_id: str) -> None: + async def complete_session(self, session_id: str, reward_info: dict[str, Any] | None = None) -> None: gateway_manager = self._require_session_runtime() - await gateway_manager.complete_session(session_id=session_id) + await gateway_manager.complete_session(session_id=session_id, reward_info=reward_info) async def abort_session(self, session_id: str) -> None: gateway_manager = self._require_session_runtime() From 80f05a1dda72264ab3a0072f39c2e37f78f4e64b Mon Sep 17 00:00:00 2001 From: zackcxb Date: Sun, 24 May 2026 14:29:47 +0000 Subject: [PATCH 04/22] feat(trainer): migrate framework gateway tests Co-Authored-By: Claude Opus 4.6 --- tests/__init__.py | 1 + tests/uni_agent/__init__.py | 1 + tests/uni_agent/trainer/__init__.py | 1 + .../test_generate_sequences_on_cpu.py | 495 ++++++++ .../test_multi_modal_postprocess_on_cpu.py | 67 ++ .../gateway/test_gateway_actor_on_cpu.py | 1019 +++++++++++++++++ .../gateway/test_gateway_manager_on_cpu.py | 106 ++ .../gateway/test_session_runtime_on_cpu.py | 158 +++ tests/uni_agent/trainer/support.py | 439 +++++++ verl | 2 +- 10 files changed, 2288 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/uni_agent/__init__.py create mode 100644 tests/uni_agent/trainer/__init__.py create mode 100644 tests/uni_agent/trainer/framework/test_generate_sequences_on_cpu.py create mode 100644 tests/uni_agent/trainer/framework/test_multi_modal_postprocess_on_cpu.py create mode 100644 tests/uni_agent/trainer/gateway/test_gateway_actor_on_cpu.py create mode 100644 tests/uni_agent/trainer/gateway/test_gateway_manager_on_cpu.py create mode 100644 tests/uni_agent/trainer/gateway/test_session_runtime_on_cpu.py create mode 100644 tests/uni_agent/trainer/support.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/uni_agent/__init__.py b/tests/uni_agent/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/uni_agent/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/uni_agent/trainer/__init__.py b/tests/uni_agent/trainer/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/uni_agent/trainer/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/uni_agent/trainer/framework/test_generate_sequences_on_cpu.py b/tests/uni_agent/trainer/framework/test_generate_sequences_on_cpu.py new file mode 100644 index 0000000..e95b425 --- /dev/null +++ b/tests/uni_agent/trainer/framework/test_generate_sequences_on_cpu.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import pytest + +from uni_agent.trainer.framework.types import SessionHandle, Trajectory +from verl.utils import tensordict_utils as tu + + +class _FakeTransferQueue: + def __init__(self): + self.puts = [] + self.batch_puts = [] + + async def async_kv_put(self, *, key, partition_id, tag): + self.puts.append({"key": key, "partition_id": partition_id, "tag": dict(tag)}) + + async def async_kv_batch_put(self, *, keys, fields, tags, partition_id): + self.batch_puts.append( + { + "keys": list(keys), + "fields": fields, + "tags": [dict(tag) for tag in tags], + "partition_id": partition_id, + } + ) + + +class _FakeReplayBuffer: + def __init__(self): + self.adds = [] + + def add(self, partition_id, items): + self.adds.append({"partition_id": partition_id, "items": dict(items)}) + + +class _FakeSessionRuntime: + """Fake runtime that matches session IDs by prefix (``session-{sample}-{session}``) + to support the real uuid-suffixed IDs produced by the framework.""" + + def __init__(self, finalized_by_session_prefix: dict[str, list[Trajectory]]): + self._finalized_by_prefix = finalized_by_session_prefix + self.created_sessions = [] + self.finalized_sessions = [] + self.aborted_sessions = [] + + def _lookup(self, session_id: str) -> list[Trajectory]: + for prefix, trajectories in self._finalized_by_prefix.items(): + if session_id.startswith(prefix): + return trajectories + raise KeyError(f"No prefix match for session_id={session_id}") + + async def create_session(self, session_id: str, **kwargs): + self.created_sessions.append(session_id) + return SessionHandle(session_id=session_id, base_url=f"http://fake/{session_id}/v1") + + async def finalize_session(self, session_id: str): + self.finalized_sessions.append(session_id) + return self._lookup(session_id) + + async def abort_session(self, session_id: str) -> None: + self.aborted_sessions.append(session_id) + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + return None + + +def _build_prompts(count: int = 2, *, global_steps: int = 7, validate: bool = False): + non_tensor_dict = {"global_steps": global_steps} + if validate: + non_tensor_dict["validate"] = True + return tu.get_tensordict( + tensor_dict={ + "raw_prompt": [[{"role": "user", "content": f"sample {i}"}] for i in range(count)], + "uid": [f"uid-{i}" for i in range(count)], + "data_source": ["deepeyes"] * count, + "reward_model": [{"ground_truth": f"answer-{i}"} for i in range(count)], + "extra_info": [{"index": i} for i in range(count)], + "tools_kwargs": [{"tool": i} for i in range(count)], + "agent_name": ["deepeyes"] * count, + }, + non_tensor_dict=non_tensor_dict, + ) + + +def _trajectory( + *, + prompt_ids: list[int] | None = None, + response_ids: list[int] | None = None, + response_logprobs: list[float] | None = None, + num_turns: int = 2, +): + prompt_ids = prompt_ids or [10, 11] + response_ids = response_ids or [20, 21] + return Trajectory( + prompt_ids=prompt_ids, + response_ids=response_ids, + response_mask=[1] * len(response_ids), + response_logprobs=response_logprobs, + reward_score=None, + num_turns=num_turns, + multi_modal_data={"images": ["raw-image-should-not-be-written"]}, + ) + + +def _install_fake_score(monkeypatch, *, score_from_sample_fields=None, default_score=1.0): + """Replace OpenAICompatibleAgentFramework._score_trajectories with a fake. + + Mirrors the production "score-last + broadcast" behavior: returns the same + (score, extra_info) for every trajectory in the session. The score is + derived from sample_fields if a callable is provided; otherwise default_score. + """ + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + async def fake_score(self, trajectories, sample_fields): + if score_from_sample_fields is not None: + score = float(score_from_sample_fields(sample_fields)) + else: + score = float(default_score) + return [(score, {})] * len(trajectories) + + monkeypatch.setattr(OpenAICompatibleAgentFramework, "_score_trajectories", fake_score) + + +@pytest.mark.asyncio +async def test_generate_sequences_writes_tq_schema_for_each_session(monkeypatch): + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + + runtime = _FakeSessionRuntime( + { + "session-0-0": [_trajectory(response_logprobs=[-0.1, -0.2])], + "session-0-1": [_trajectory(response_logprobs=[-0.3, -0.4])], + "session-1-0": [_trajectory(response_logprobs=[-0.5, -0.6])], + "session-1-1": [_trajectory(response_logprobs=[-0.7, -0.8])], + } + ) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + assert raw_prompt == [{"role": "user", "content": f"sample {sample_index}"}] + assert tools_kwargs == {"tool": sample_index} + + # Score derived from sample_fields["extra_info"]["index"] + 0.25 (same as legacy lambda) + _install_fake_score( + monkeypatch, + score_from_sample_fields=lambda sf: sf["extra_info"]["index"] + 0.25, + ) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 2, "val_kwargs": {"n": 2}}, + ) + + + await framework.generate_sequences(_build_prompts(global_steps=7)) + + assert replay_buffer.adds == [ + { + "partition_id": "train", + "items": { + "uid-0": {"global_steps": 7, "status": "running"}, + "uid-1": {"global_steps": 7, "status": "running"}, + }, + } + ] + assert fake_tq.batch_puts[0]["keys"] == ["uid-0_0_0"] + assert fake_tq.batch_puts[1]["keys"] == ["uid-0_1_0"] + assert fake_tq.batch_puts[2]["keys"] == ["uid-1_0_0"] + assert fake_tq.batch_puts[3]["keys"] == ["uid-1_1_0"] + assert fake_tq.puts == [ + {"key": "uid-0", "partition_id": "train", "tag": {"status": "finished"}}, + {"key": "uid-1", "partition_id": "train", "tag": {"status": "finished"}}, + ] + + first = fake_tq.batch_puts[0] + fields = first["fields"] + assert first["partition_id"] == "train" + assert first["tags"] == [ + {"global_steps": 7, "status": "success", "prompt_len": 2, "response_len": 2, "seq_len": 4} + ] + assert fields["input_ids"].is_nested + assert fields["response_mask"].is_nested + assert fields["position_ids"].is_nested + assert fields["prompts"][0].tolist() == [10, 11] + assert fields["responses"][0].tolist() == [20, 21] + assert fields["response_mask"][0].tolist() == [1, 1] + assert fields["loss_mask"][0].tolist() == [1, 1] + assert fields["input_ids"][0].tolist() == [10, 11, 20, 21] + assert fields["attention_mask"][0].tolist() == [1, 1, 1, 1] + assert fields["position_ids"][0].tolist() == [0, 1, 2, 3] + assert fields["rollout_log_probs"][0].tolist() == pytest.approx([-0.1, -0.2]) + assert fields["rm_scores"][0].tolist() == [0.0, 0.25] + assert tu.get(fields, "multi_modal_inputs") == [{}] + assert tu.get(fields, "uid") == ["uid-0"] + assert tu.get(fields, "raw_prompt") == [[{"role": "user", "content": "sample 0"}]] + assert tu.get(fields, "data_source") == ["deepeyes"] + assert tu.get(fields, "reward_model") == [{"ground_truth": "answer-0"}] + assert tu.get(fields, "extra_info") == [{"index": 0}] + assert tu.get(fields, "tools_kwargs") == [{"tool": 0}] + assert tu.get(fields, "agent_name") == ["deepeyes"] + assert tu.get(fields, "session_id") == [0] + assert tu.get(fields, "global_steps") == [7] + assert fields["num_turns"].tolist() == [2] + assert "multi_modal_data" not in fields.keys() + + +@pytest.mark.asyncio +async def test_generate_sequences_keeps_successful_sessions_when_one_session_fails(monkeypatch): + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime( + { + "session-0-0": [_trajectory()], + "session-0-1": [_trajectory()], + } + ) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + if session.session_id.startswith("session-0-1-"): + raise RuntimeError("gateway failed once") + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 2, "val_kwargs": {"n": 2}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=8)) + + assert replay_buffer.adds == [ + {"partition_id": "train", "items": {"uid-0": {"global_steps": 8, "status": "running"}}} + ] + assert fake_tq.batch_puts[0]["keys"] == ["uid-0_0_0"] + assert fake_tq.puts == [{"key": "uid-0", "partition_id": "train", "tag": {"status": "finished"}}] + assert len(runtime.aborted_sessions) == 1 + assert runtime.aborted_sessions[0].startswith("session-0-1-") + + +@pytest.mark.asyncio +async def test_generate_sequences_marks_prompt_failure_when_all_sessions_fail(monkeypatch): + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime({"session-0-0": [], "session-0-1": []}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + raise RuntimeError(f"failed {session.session_id}") + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 2}}, + ) + + + with pytest.raises(RuntimeError, match="All rollouts failed at global_steps=9"): + await framework.generate_sequences(_build_prompts(count=1, global_steps=9, validate=True)) + + assert replay_buffer.adds == [ + {"partition_id": "val", "items": {"uid-0": {"global_steps": 9, "status": "running"}}} + ] + assert fake_tq.batch_puts == [] + assert fake_tq.puts == [{"key": "uid-0", "partition_id": "val", "tag": {"status": "failure"}}] + + +@pytest.mark.asyncio +async def test_generate_sequences_zero_fills_rm_scores_when_no_reward_handles(monkeypatch): + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime({"session-0-0": [_trajectory()]}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + return None + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=10)) + + # rm_scores is always written (zero-filled when no reward) so the trainer's + # KVBatchMeta select_fields never hits a missing field across the batch. + rm_scores = fake_tq.batch_puts[0]["fields"]["rm_scores"] + assert rm_scores[0].tolist() == [0.0, 0.0] + + +@pytest.mark.asyncio +async def test_generate_sequences_keeps_other_prompts_when_prompt_task_raises(monkeypatch, caplog): + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + replay_buffer = _FakeReplayBuffer() + runtime = _FakeSessionRuntime( + { + "session-1-0": [_trajectory()], + } + ) + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=lambda **_: None, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + async def fake_run_prompt_sessions_to_tq(*, sample_index, **kwargs): + if sample_index == 0: + raise RuntimeError("prompt 0 exploded") + return { + "num_success_sessions": 1, + "num_failed_sessions": 0, + "num_success_outputs": 1, + "num_failed_uids": 0, + "failure_reasons": [], + } + + monkeypatch.setattr(framework, "_run_prompt_sessions_to_tq", fake_run_prompt_sessions_to_tq) + + caplog.set_level("INFO") + await framework.generate_sequences(_build_prompts(count=2, global_steps=11)) + + assert replay_buffer.adds == [ + { + "partition_id": "train", + "items": { + "uid-0": {"global_steps": 11, "status": "running"}, + "uid-1": {"global_steps": 11, "status": "running"}, + }, + } + ] + assert "num_failed_uids=1" in caplog.text + assert "prompt 0 exploded" in caplog.text + + +@pytest.mark.asyncio +async def test_generate_sequences_zero_fills_rollout_log_probs_when_missing(monkeypatch): + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + # Trajectory without response_logprobs (e.g. backend returned no logprobs). + runtime = _FakeSessionRuntime({"session-0-0": [_trajectory(response_logprobs=None)]}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + return None + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=10)) + + # rollout_log_probs is zero-filled rather than omitted so the trainer's + # bypass-mode select_fields(["rollout_log_probs"]) never KeyErrors. + rollout_log_probs = fake_tq.batch_puts[0]["fields"]["rollout_log_probs"] + assert rollout_log_probs[0].tolist() == [0.0, 0.0] + + +@pytest.mark.asyncio +async def test_max_concurrent_sessions_caps_in_flight_sessions(monkeypatch): + import asyncio + + from uni_agent.trainer.framework import framework as framework_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime( + {f"session-{i}-0": [_trajectory()] for i in range(4)} + ) + + in_flight = 0 + max_observed = 0 + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + nonlocal in_flight, max_observed + in_flight += 1 + max_observed = max(max_observed, in_flight) + await asyncio.sleep(0.01) + in_flight -= 1 + return None + + _install_fake_score(monkeypatch, default_score=1.0) + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + max_concurrent_sessions=2, + ) + + + await framework.generate_sequences(_build_prompts(count=4, global_steps=10)) + + assert max_observed <= 2 + + +# --------------------------------------------------------------------------- +# _score_trajectories method-level tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def ray_runtime(): + import ray + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_score_trajectories_dispatches_only_final_trajectory_and_broadcasts(ray_runtime): + """_score_trajectories scores trajectories[-1] only, broadcasts to all (matches AgentLoopWorkerTQ).""" + import ray as ray_module + from uni_agent.trainer.framework.framework import OpenAICompatibleAgentFramework + from uni_agent.trainer.framework.types import Trajectory + + @ray_module.remote + class _StubWorker: + def __init__(self): + self.calls = [] + + def compute_score(self, data): + self.calls.append(data) + return {"reward_score": 0.42, "reward_extra_info": {"acc": 1.0, "format": 0.8}} + + def get_call_count(self): + return len(self.calls) + + worker = _StubWorker.remote() + + runtime = _FakeSessionRuntime({}) # not used in this test + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=lambda **_: None, + reward_loop_worker_handles=[worker], + replay_buffer=_FakeReplayBuffer(), + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + trajectories = [ + Trajectory(prompt_ids=[1, 2], response_ids=[3, 4], response_mask=[1, 1], num_turns=1), + Trajectory(prompt_ids=[5, 6], response_ids=[7, 8], response_mask=[1, 1], num_turns=2), + Trajectory(prompt_ids=[9, 10], response_ids=[11, 12], response_mask=[1, 1], num_turns=3), + ] + sample_fields = {"data_source": "test", "raw_prompt": [{"role": "user", "content": "hi"}]} + annotations = await framework._score_trajectories(trajectories, sample_fields) + + # Score-last + broadcast: 3 trajectories, but only 1 worker call + assert ray_module.get(worker.get_call_count.remote()) == 1 + # All 3 trajectories get the same score and extra_info + assert annotations == [ + (0.42, {"acc": 1.0, "format": 0.8}), + (0.42, {"acc": 1.0, "format": 0.8}), + (0.42, {"acc": 1.0, "format": 0.8}), + ] diff --git a/tests/uni_agent/trainer/framework/test_multi_modal_postprocess_on_cpu.py b/tests/uni_agent/trainer/framework/test_multi_modal_postprocess_on_cpu.py new file mode 100644 index 0000000..351e953 --- /dev/null +++ b/tests/uni_agent/trainer/framework/test_multi_modal_postprocess_on_cpu.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import torch + +from tests.uni_agent.trainer.support import FakeProcessor + + +def test_compute_multi_modal_inputs_returns_empty_dict_without_processor(): + from uni_agent.trainer.framework.multi_modal_postprocess import compute_multi_modal_inputs + + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + assert compute_multi_modal_inputs(None, input_ids, {"images": ["image://a.png"]}) == {} + + +def test_compute_multi_modal_inputs_returns_image_tensors_and_images_seqlens(): + from uni_agent.trainer.framework.multi_modal_postprocess import compute_multi_modal_inputs + + processor = FakeProcessor() + input_ids = torch.tensor([[11, processor.image_token_id, 12]], dtype=torch.long) + + multi_modal_inputs = compute_multi_modal_inputs( + processor, + input_ids, + {"images": ["image://a.png"]}, + ) + + assert "input_ids" not in multi_modal_inputs + assert "attention_mask" not in multi_modal_inputs + assert tuple(multi_modal_inputs["pixel_values"].shape) == (1, 3, 2, 2) + assert multi_modal_inputs["image_grid_thw"].tolist() == [[1, 2, 3]] + assert multi_modal_inputs["images_seqlens"].tolist() == [6] + assert "mm_token_type_ids" in multi_modal_inputs + + +def test_compute_position_ids_returns_text_shape_without_processor(): + from uni_agent.trainer.framework.multi_modal_postprocess import compute_position_ids + + input_ids = torch.tensor([[7, 8, 9, 10]], dtype=torch.long) + attention_mask = torch.tensor([[0, 1, 1, 1]], dtype=torch.long) + + position_ids = compute_position_ids(None, input_ids, attention_mask, {}) + + assert tuple(position_ids.shape) == (1, 4) + assert position_ids.tolist() == [[0, 0, 1, 2]] + + +def test_compute_position_ids_returns_multimodal_shape_with_processor(): + from uni_agent.trainer.framework.multi_modal_postprocess import compute_position_ids + + processor = FakeProcessor() + input_ids = torch.tensor( + [[11, processor.image_token_id, processor.video_token_id, 12]], + dtype=torch.long, + ) + attention_mask = torch.ones_like(input_ids) + multi_modal_inputs = { + "image_grid_thw": torch.tensor([[1, 2, 3]], dtype=torch.long), + "video_grid_thw": torch.tensor([[1, 3, 4]], dtype=torch.long), + "mm_token_type_ids": torch.ones_like(input_ids), + } + + position_ids = compute_position_ids(processor, input_ids, attention_mask, multi_modal_inputs) + + assert tuple(position_ids.shape) == (1, 4, 4) + assert position_ids[0, 0].tolist() == [0, 1, 2, 3] + assert processor.last_get_rope_index_call["mm_token_type_ids"].tolist() == [[0, 1, 2, 0]] diff --git a/tests/uni_agent/trainer/gateway/test_gateway_actor_on_cpu.py b/tests/uni_agent/trainer/gateway/test_gateway_actor_on_cpu.py new file mode 100644 index 0000000..db2608a --- /dev/null +++ b/tests/uni_agent/trainer/gateway/test_gateway_actor_on_cpu.py @@ -0,0 +1,1019 @@ +import asyncio +import copy +import json + +import httpx +import pytest +import ray + +from tests.uni_agent.trainer.support import ( + FailingBackend, + FakeProcessor, + FakeTokenizer, + InspectingBackend, + InspectingSequencedBackend, + QueuedBackend, + RejectConcurrentSessionBackend, + RejectRequestEnvelopeBackend, + RejectToolsSamplingParamsBackend, + SequencedBackend, + SingleUseVisionInfoExtractor, + SlowBackend, + fake_vision_info_extractor, +) + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_actor_abort_session_does_not_wait_for_backend_generate(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=SlowBackend(delay_s=1.5), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-abort-during-generate")) + + async with httpx.AsyncClient(timeout=5.0) as client: + request_task = asyncio.create_task( + client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "slow path"}]}, + ) + ) + await asyncio.sleep(0.1) + + abort_ref = actor.abort_session.remote("session-abort-during-generate") + await asyncio.wait_for(asyncio.wrap_future(abort_ref.future()), timeout=0.3) + + request_task.cancel() + try: + await request_task + except (asyncio.CancelledError, httpx.HTTPError): + pass + + ray.get(actor.shutdown.remote()) + + +def test_normalize_request_context_preserves_multimodal_blocks_for_later_extraction(): + from uni_agent.trainer.gateway.gateway import _normalize_request_context + + context = _normalize_request_context( + { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "file://image.png"}}, + ], + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": \"weather\"}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call-1", + "content": [{"type": "text", "text": "sunny"}], + }, + ], + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + } + ) + + assert context["tools"][0]["function"]["name"] == "search" + assert context["messages"][0]["content"][1]["type"] == "image_url" + assert context["messages"][1]["tool_calls"][0]["id"] == "call-1" + assert context["messages"][2]["tool_call_id"] == "call-1" + + +@pytest.mark.asyncio +async def test_gateway_actor_forwards_image_data_on_initial_multimodal_request(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor, _normalize_request_context + + processor = FakeProcessor() + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=processor, + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-mm-initial")) + payload = { + "model": "dummy-model", + "temperature": 0.25, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe this image"}, + ], + } + ], + } + + normalized = _normalize_request_context(payload) + raw_prompt = processor.apply_chat_template( + normalized["messages"], + tokenize=False, + add_generation_prompt=True, + tools=normalized["tools"], + ) + expected_prompt_ids = processor( + text=[raw_prompt], + images=["image://a.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json=payload, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-initial")) + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + backend_request = json.loads(response.json()["choices"][0]["message"]["content"]) + assert backend_request["image_data"] == ["image://a.png"] + assert backend_request["video_data"] is None + assert backend_request["prompt_ids"] == expected_prompt_ids + assert backend_request["sampling_params"] == {"temperature": 0.25} + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == {"images": ["image://a.png"]} + + +@pytest.mark.asyncio +async def test_gateway_actor_complete_wait_and_finalize(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["ANSWER: A"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-0")) + wait_ref = actor.wait_for_completion.remote("session-0", timeout=2.0) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "Pick label A"}], + }, + ) + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "ANSWER: A" + + complete = await client.post( + f"{session.base_url.removesuffix('/v1')}/complete", + json={"reward_info": {"score": 1.0, "label": "A"}}, + ) + assert complete.status_code == 200 + + ray.get(wait_ref) + trajectories = ray.get(actor.finalize_session.remote("session-0")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {"score": 1.0, "label": "A"} + assert trajectories[0].response_ids + assert all(mask == 1 for mask in trajectories[0].response_mask) + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_reuses_accumulated_media_context(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=FakeProcessor(), + vision_info_extractor=SingleUseVisionInfoExtractor(), + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-continuation")) + + initial_message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe this image"}, + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [initial_message], + }, + ) + assert first.status_code == 200 + assistant_message = first.json()["choices"][0]["message"] + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + initial_message, + assistant_message, + {"role": "user", "content": "follow up"}, + ], + }, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-continuation")) + ray.get(actor.shutdown.remote()) + + assert second.status_code == 200 + first_call = json.loads(first.json()["choices"][0]["message"]["content"]) + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert first_call["image_data"] == ["image://a.png"] + assert second_call["image_data"] == ["image://a.png"] + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == {"images": ["image://a.png"]} + + +@pytest.mark.asyncio +async def test_gateway_actor_multimodal_reference_change_splits_trajectory(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=FakeProcessor(), + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-split")) + + first_payload = { + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe image a"}, + ], + } + ], + } + second_payload = { + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://b.png"}}, + {"type": "text", "text": "describe image b"}, + ], + } + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post(f"{session.base_url}/chat/completions", json=first_payload) + second = await client.post(f"{session.base_url}/chat/completions", json=second_payload) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-split")) + ray.get(actor.shutdown.remote()) + + assert first.status_code == 200 + assert second.status_code == 200 + first_call = json.loads(first.json()["choices"][0]["message"]["content"]) + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert first_call["image_data"] == ["image://a.png"] + assert second_call["image_data"] == ["image://b.png"] + assert len(trajectories) == 2 + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_with_tool_returned_image_appends_media(ray_runtime): + from verl.utils.chat_template import apply_chat_template, initialize_system_prompt + from uni_agent.trainer.gateway.gateway import GatewayActor + + processor = FakeProcessor() + tool_call_text = '\n{"name": "search", "arguments": {"query": "crop"}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=processor, + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingSequencedBackend([tool_call_text, "__inspect__"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-tool-image")) + + tools = [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}] + initial_message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "find a crop"}, + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": tools, + "messages": [initial_message], + }, + ) + assert first.status_code == 200 + assistant_message = first.json()["choices"][0]["message"] + tool_message = { + "role": "tool", + "tool_call_id": assistant_message["tool_calls"][0]["id"], + "content": [ + {"type": "image_url", "image_url": {"url": "image://tool-b.png"}}, + {"type": "text", "text": "zoomed crop"}, + ], + } + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": tools, + "messages": [initial_message, assistant_message, tool_message], + }, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-tool-image")) + ray.get(actor.shutdown.remote()) + + assert second.status_code == 200 + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert second_call["image_data"] == ["image://a.png", "image://tool-b.png"] + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == { + "images": ["image://a.png", "image://tool-b.png"], + } + + initial_raw_prompt = apply_chat_template( + processor, + [initial_message], + tools=tools, + tokenize=False, + add_generation_prompt=True, + ) + initial_prompt_ids = processor( + text=[initial_raw_prompt], + images=["image://a.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + + incremental_raw_prompt = apply_chat_template( + processor, + [tool_message], + tokenize=False, + add_generation_prompt=True, + ) + incremental_prompt_ids = processor( + text=[incremental_raw_prompt], + images=["image://tool-b.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + system_prompt = initialize_system_prompt(processor) + expected_incremental_ids = incremental_prompt_ids[len(system_prompt) :] + expected_prompt_ids = initial_prompt_ids + [ord(char) for char in tool_call_text] + expected_incremental_ids + assert second_call["prompt_ids"] == expected_prompt_ids + + +@pytest.mark.asyncio +async def test_gateway_actor_prefix_mismatch_splits_trajectories(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-1")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "replacement context"}], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-1")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 2 + assert trajectories[0].prompt_ids != trajectories[1].prompt_ids + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_context_change_splits_trajectory(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-tools")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-tools")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 2 + + +@pytest.mark.asyncio +async def test_gateway_actor_does_not_forward_tools_in_sampling_params(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectToolsSamplingParamsBackend("SAFE"), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-no-tools-sampling")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_strips_request_envelope_but_keeps_sampling_params(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectRequestEnvelopeBackend( + "SAFE", + expected_sampling_params={"temperature": 0.25, "top_p": 0.8, "max_tokens": 128}, + ), + base_sampling_params={"temperature": 0.1, "top_p": 0.8, "max_tokens": 64}, + allowed_request_sampling_param_keys={"temperature", "max_tokens"}, + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-envelope-boundary")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "temperature": 0.25, + "max_tokens": 128, + "presence_penalty": 1.5, + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_ignores_non_whitelisted_request_sampling_params(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectRequestEnvelopeBackend( + "SAFE", + expected_sampling_params={"temperature": 0.1, "top_p": 0.9}, + ), + base_sampling_params={"temperature": 0.1, "top_p": 0.9}, + allowed_request_sampling_param_keys={"temperature"}, + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-non-whitelist")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "presence_penalty": 1.5, + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_preserves_prompt_and_generation_masks(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-continuation-mask")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": "first turn", + } + ], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-continuation-mask")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert 0 in trajectories[0].response_mask + assert trajectories[0].response_mask[-len("SECOND") :] == [1] * len("SECOND") + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_argument_json_equivalence_does_not_split_after_valid_continuation(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + tool_call_text = '\n{"name": "search", "arguments": {"b": 2, "a": 1}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=QueuedBackend([tool_call_text, "SECOND", "THIRD"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-tool-arg-drift")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "what is the weather?"}], + }, + ) + assert first.status_code == 200 + assistant_tool_message = first.json()["choices"][0]["message"] + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + assistant_tool_message, + {"role": "tool", "tool_call_id": assistant_tool_message["tool_calls"][0]["id"], "content": "sunny"}, + ], + }, + ) + assert second.status_code == 200 + + drifted_tool_message = copy.deepcopy(assistant_tool_message) + drifted_tool_message["tool_calls"][0]["function"]["arguments"] = json.dumps({"a": 1, "b": 2}) + third = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + drifted_tool_message, + {"role": "tool", "tool_call_id": assistant_tool_message["tool_calls"][0]["id"], "content": "sunny"}, + {"role": "assistant", "content": "SECOND"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert third.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-tool-arg-drift")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert 0 in trajectories[0].response_mask + assert trajectories[0].response_mask[-len("THIRD") :] == [1] * len("THIRD") + + +def test_message_prefix_falls_back_to_raw_tool_argument_value_comparison_when_arguments_are_invalid_json(): + from uni_agent.trainer.gateway.gateway import _is_message_prefix + + prefix = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": weather}"}, + } + ], + } + ] + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": sunny}"}, + } + ], + } + ] + + assert _is_message_prefix(prefix, messages) is False + + +@pytest.mark.asyncio +async def test_gateway_actor_serializes_same_session_concurrent_requests(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectConcurrentSessionBackend(["FIRST", "SECOND"]), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-concurrent")) + + async with httpx.AsyncClient(timeout=5.0) as client: + async def send_request(): + return await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "same session prompt"}], + }, + ) + + first, second = await asyncio.gather(send_request(), send_request()) + + trajectories = ray.get(actor.finalize_session.remote("session-concurrent")) + ray.get(actor.shutdown.remote()) + + assert first.status_code == 200 + assert second.status_code == 200 + assert len(trajectories) == 2 + assert trajectories[0].response_ids == [ord(char) for char in "FIRST"] + assert trajectories[1].response_ids == [ord(char) for char in "SECOND"] + assert trajectories[0].response_mask == [1] * len("FIRST") + assert trajectories[1].response_mask == [1] * len("SECOND") + + +@pytest.mark.asyncio +async def test_gateway_actor_rejects_chat_after_complete(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-completed-chat")) + ray.get(actor.complete_session.remote("session-completed-chat")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "after complete"}]}, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 409 + + +@pytest.mark.asyncio +async def test_gateway_actor_finalizes_without_complete(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-finalize-without-complete")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "finish directly"}]}, + ) + assert response.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-finalize-without-complete")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {} + + +@pytest.mark.parametrize( + ("payload", "detail_fragment"), + [ + ({"model": "dummy-model", "messages": []}, "messages must be non-empty"), + ( + {"model": "dummy-model", "messages": [{"role": "user", "name": 123, "content": "hello"}]}, + "message.name must be a string", + ), + ( + {"model": "dummy-model", "messages": [{"role": "user", "content": 123}]}, + "Unsupported content type", + ), + ( + { + "model": "dummy-model", + "messages": [{"role": "assistant", "content": "", "tool_calls": {"id": "call-1"}}], + }, + "tool_calls must be a list", + ), + ( + { + "model": "dummy-model", + "tools": {"type": "function"}, + "messages": [{"role": "user", "content": "hello"}], + }, + "tools must be a list", + ), + ], +) +@pytest.mark.asyncio +async def test_gateway_actor_rejects_malformed_requests_with_bad_request(ray_runtime, payload, detail_fragment): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-validation")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json=payload, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 400 + assert detail_fragment in response.text + + +@pytest.mark.asyncio +async def test_gateway_actor_backend_failure_does_not_commit_partial_state(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=FailingBackend("boom")) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-backend-failure")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "first turn"}]}, + ) + + state = ray.get(actor.get_session_state.remote("session-backend-failure")) + ray.get(actor.shutdown.remote()) + + assert response.status_code == 500 + assert state["num_trajectories"] == 0 + assert state["has_active_trajectory"] is False + + +@pytest.mark.asyncio +async def test_gateway_actor_backend_failure_after_tool_mismatch_does_not_split(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=SequencedBackend(["FIRST", RuntimeError("boom")]), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-failure-mismatch")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + async with httpx.AsyncClient(timeout=5.0) as client: + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 500 + + state = ray.get(actor.get_session_state.remote("session-failure-mismatch")) + trajectories = ray.get(actor.finalize_session.remote("session-failure-mismatch")) + ray.get(actor.shutdown.remote()) + + assert state["num_trajectories"] == 0 + assert len(trajectories) == 1 + assert trajectories[0].response_ids == [ord(char) for char in "FIRST"] + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_call_decode_returns_openai_format(ray_runtime): + """When tool_parser_name is set and model outputs tool call tokens, + the HTTP response should contain tool_calls in OpenAI format.""" + from uni_agent.trainer.gateway.gateway import GatewayActor + + tool_call_text = '\n{"name": "search", "arguments": {"query": "weather"}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=QueuedBackend([tool_call_text, "sunny today"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-tool-call")) + + async with httpx.AsyncClient(timeout=5.0) as client: + # First request: model returns a tool call + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "what is the weather?"}], + }, + ) + assert first.status_code == 200 + first_data = first.json() + assert first_data["choices"][0]["finish_reason"] == "tool_calls" + tool_calls = first_data["choices"][0]["message"].get("tool_calls") + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["name"] == "search" + assert tool_calls[0]["type"] == "function" + assert "id" in tool_calls[0] + # HTTP response arguments should be a JSON string (OpenAI compatible) + assert isinstance(tool_calls[0]["function"]["arguments"], str) + + # Second request: agent sends back tool result as continuation + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + {"role": "assistant", "content": None, "tool_calls": tool_calls}, + {"role": "tool", "tool_call_id": tool_calls[0]["id"], "content": "sunny and warm"}, + ], + }, + ) + assert second.status_code == 200 + assert second.json()["choices"][0]["message"]["content"] == "sunny today" + + trajectories = ray.get(actor.finalize_session.remote("session-tool-call")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + # Should have both mask=0 (incremental) and mask=1 (model output) tokens + assert 0 in trajectories[0].response_mask + assert 1 in trajectories[0].response_mask + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("stop_reason", "expected_finish_reason"), + [ + ("completed", "stop"), + ("length", "length"), + ("abort", "stop"), + ("matched_stop", "stop"), + (None, "stop"), + ], +) +async def test_decode_response_normalizes_backend_stop_reasons(stop_reason, expected_finish_reason): + """Gateway must map backend-specific stop_reason values to OpenAI-spec + finish_reason values so downstream OpenAI/litellm parsers stay compatible. + """ + from uni_agent.trainer.gateway.gateway import _GatewayActor + + actor = _GatewayActor(tokenizer=FakeTokenizer(), backend=QueuedBackend(["IGNORED"])) + response_ids = [ord(char) for char in "hello"] + + _message, finish_reason = await actor._decode_response( + response_ids, tools=None, stop_reason=stop_reason + ) + + assert finish_reason == expected_finish_reason + + +@pytest.mark.asyncio +async def test_decode_response_preserves_unknown_stop_reasons(): + """Unknown backend stop_reason values should be forwarded unchanged so a + future reader can spot a new backend value rather than silently coerce it. + """ + from uni_agent.trainer.gateway.gateway import _GatewayActor + + actor = _GatewayActor(tokenizer=FakeTokenizer(), backend=QueuedBackend(["IGNORED"])) + response_ids = [ord(char) for char in "hello"] + + _message, finish_reason = await actor._decode_response( + response_ids, tools=None, stop_reason="unknown_future_value" + ) + + assert finish_reason == "unknown_future_value" diff --git a/tests/uni_agent/trainer/gateway/test_gateway_manager_on_cpu.py b/tests/uni_agent/trainer/gateway/test_gateway_manager_on_cpu.py new file mode 100644 index 0000000..89c19fd --- /dev/null +++ b/tests/uni_agent/trainer/gateway/test_gateway_manager_on_cpu.py @@ -0,0 +1,106 @@ +import httpx +import pytest +import ray + +from tests.uni_agent.trainer.support import FakeTokenizer, QueuedBackend, TrackingGatewayActor + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_manager_routes_sessions_stickily(ray_runtime): + from uni_agent.trainer.gateway.gateway import GatewayActor + from uni_agent.trainer.gateway.manager import GatewayManager + + gateways = [ + GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["A"])), + GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["B"])), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + session_a = await manager.create_session("session-a") + session_b = await manager.create_session("session-b") + + assert manager.gateway_count == 2 + assert session_a.base_url != session_b.base_url + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session_a.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "route a"}]}, + ) + second = await client.post( + f"{session_b.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "route b"}]}, + ) + assert first.status_code == 200 + assert second.status_code == 200 + + trajectories_a = await manager.finalize_session("session-a") + trajectories_b = await manager.finalize_session("session-b") + + assert len(trajectories_a) == 1 + assert len(trajectories_b) == 1 + + ray.get([gateway.shutdown.remote() for gateway in gateways]) + + +@pytest.mark.asyncio +async def test_gateway_manager_uses_least_active_sessions_routing(ray_runtime): + from uni_agent.trainer.gateway.manager import GatewayManager + + gateways = [ + TrackingGatewayActor.remote("gw-0"), + TrackingGatewayActor.remote("gw-1"), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + session_a = await manager.create_session("session-a") + session_b = await manager.create_session("session-b") + session_c = await manager.create_session("session-c") + + assert manager.active_sessions_per_gateway == [2, 1] + assert session_a.base_url.startswith("http://gw-0/") + assert session_b.base_url.startswith("http://gw-1/") + assert session_c.base_url.startswith("http://gw-0/") + + await manager.finalize_session("session-a") + assert manager.active_sessions_per_gateway == [1, 1] + + session_d = await manager.create_session("session-d") + assert session_d.base_url.startswith("http://gw-0/") + assert manager.active_sessions_per_gateway == [2, 1] + + ray.get([gateway.shutdown.remote() for gateway in gateways]) + + +@pytest.mark.asyncio +async def test_gateway_manager_wait_for_completion_delegates_to_session_owner(ray_runtime): + from uni_agent.trainer.gateway.manager import GatewayManager + + gateways = [ + TrackingGatewayActor.remote("gw-0"), + TrackingGatewayActor.remote("gw-1"), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + await manager.create_session("session-a") + await manager.create_session("session-b") + + await manager.wait_for_completion("session-a", timeout=1.5) + + stats_0 = ray.get(gateways[0].stats.remote()) + stats_1 = ray.get(gateways[1].stats.remote()) + + assert stats_0["waited"] == [("session-a", 1.5)] + assert stats_1["waited"] == [] + + ray.get([gateway.shutdown.remote() for gateway in gateways]) diff --git a/tests/uni_agent/trainer/gateway/test_session_runtime_on_cpu.py b/tests/uni_agent/trainer/gateway/test_session_runtime_on_cpu.py new file mode 100644 index 0000000..0f987d2 --- /dev/null +++ b/tests/uni_agent/trainer/gateway/test_session_runtime_on_cpu.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import httpx +import pytest +import ray + +from tests.uni_agent.trainer.support import FakeTokenizer, RecordingLLMClient + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_owns_gateway_lifecycle_and_session_runtime(ray_runtime): + from uni_agent.trainer.gateway.runtime import GatewayServingRuntime + + llm_client = RecordingLLMClient("OWNER") + runtime = GatewayServingRuntime( + llm_client=llm_client, + gateway_count=1, + gateway_actor_kwargs={ + "tokenizer": FakeTokenizer(), + }, + ) + + session = await runtime.create_session("session-owner") + wait_task = runtime.wait_for_completion("session-owner", timeout=2.0) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "owner path"}]}, + ) + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "OWNER" + + complete = await client.post( + f"{session.base_url.removesuffix('/v1')}/complete", + json={"reward_info": {"score": 0.5, "label": "owner"}}, + ) + assert complete.status_code == 200 + + await wait_task + trajectories = await runtime.finalize_session("session-owner") + await runtime.shutdown() + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {"score": 0.5, "label": "owner"} + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_delegates_generate_to_llm_client(ray_runtime): + from uni_agent.trainer.gateway.runtime import GatewayServingRuntime + + llm_client = RecordingLLMClient("DELEGATED") + runtime = GatewayServingRuntime(llm_client=llm_client, gateway_count=0) + + output = await runtime.generate( + "request-direct", + prompt_ids=[4, 5, 6], + sampling_params={"temperature": 0.2}, + image_data=["image://direct.png"], + ) + + await runtime.shutdown() + + assert output.token_ids == [ord(char) for char in "DELEGATED"] + assert llm_client.calls == [ + { + "request_id": "request-direct", + "prompt_ids": [4, 5, 6], + "sampling_params": {"temperature": 0.2}, + "image_data": ["image://direct.png"], + "video_data": None, + "kwargs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_complete_session_forwards_reward_info(ray_runtime): + from uni_agent.trainer.gateway.runtime import GatewayServingRuntime + + runtime = GatewayServingRuntime( + llm_client=RecordingLLMClient("PYTHON-API"), + gateway_count=1, + gateway_actor_kwargs={ + "tokenizer": FakeTokenizer(), + }, + ) + + session = await runtime.create_session("session-python-complete") + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "python complete"}]}, + ) + assert response.status_code == 200 + + await runtime.complete_session( + "session-python-complete", + reward_info={"score": 0.75, "label": "python-api"}, + ) + trajectories = await runtime.finalize_session("session-python-complete") + await runtime.shutdown() + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {"score": 0.75, "label": "python-api"} + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_round_robins_actors_across_alive_nodes(ray_runtime, monkeypatch): + """gateway_count > 1 should distribute actors across alive CPU nodes round-robin.""" + from uni_agent.trainer.gateway.runtime import GatewayServingRuntime + + fake_nodes = [ + {"NodeID": "a" * 56, "Alive": True, "Resources": {"CPU": 8.0}}, + {"NodeID": "b" * 56, "Alive": True, "Resources": {"CPU": 8.0}}, + {"NodeID": "c" * 56, "Alive": False, "Resources": {"CPU": 8.0}}, + {"NodeID": "d" * 56, "Alive": True, "Resources": {"GPU": 1.0}}, + ] + monkeypatch.setattr("uni_agent.trainer.gateway.runtime.ray.nodes", lambda: fake_nodes) + + captured_node_ids = [] + + class _StubStartHandle: + @staticmethod + def remote(): + return ray.put(None) + + class _StubActorHandle: + start = _StubStartHandle() + + class _RecordingActor: + @classmethod + def options(cls, *, scheduling_strategy): + captured_node_ids.append(scheduling_strategy.node_id) + return cls + + @classmethod + def remote(cls, **kwargs): + return _StubActorHandle() + + monkeypatch.setattr("uni_agent.trainer.gateway.gateway.GatewayActor", _RecordingActor) + + runtime = GatewayServingRuntime( + llm_client=object(), + gateway_count=5, + gateway_actor_kwargs={"tokenizer": object()}, + ) + + assert captured_node_ids == ["a" * 56, "b" * 56, "a" * 56, "b" * 56, "a" * 56], captured_node_ids + assert len(runtime.owned_gateway_actors) == 5 + # No shutdown call needed: stub actors have no real Ray state. diff --git a/tests/uni_agent/trainer/support.py b/tests/uni_agent/trainer/support.py new file mode 100644 index 0000000..57820d3 --- /dev/null +++ b/tests/uni_agent/trainer/support.py @@ -0,0 +1,439 @@ +import asyncio +import json + +import ray +import torch + +from uni_agent.trainer.framework.types import SessionHandle, Trajectory +from verl.workers.rollout.replica import TokenOutput + + +class FakeTokenizer: + def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): + parts = [] + for message in messages: + parts.append(f"{message['role']}:{self._normalize_content(message.get('content', ''))}\n") + if add_generation_prompt: + parts.append("assistant:") + text = "".join(parts) + if tokenize: + return [ord(char) for char in text] + return text + + def decode(self, token_ids, skip_special_tokens=True): + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + normalized = [] + for token_id in token_ids: + if hasattr(token_id, "item"): + token_id = token_id.item() + normalized.append(int(token_id)) + return "".join(chr(token_id) for token_id in normalized) + + def encode(self, text, add_special_tokens=False): + return [ord(char) for char in text] + + def _normalize_content(self, content): + if isinstance(content, list): + return "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in content) + if content is None: + return "" + return str(content) + + +class FakeProcessor: + class _ImageProcessor: + patch_size = 16 + + image_token_id = 32001 + video_token_id = 32002 + + def __init__(self): + self.image_processor = self._ImageProcessor() + self.tokenizer = FakeTokenizer() + self.last_processor_call = None + self.last_get_rope_index_call = None + + def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): + return self.tokenizer.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + tools=tools, + **kwargs, + ) + + def __call__( + self, + *, + text, + images=None, + videos=None, + video_metadata=None, + return_tensors=None, + do_sample_frames=False, + **kwargs, + ): + assert len(text) == 1 + self.last_processor_call = { + "text": list(text), + "images": None if images is None else list(images), + "videos": None if videos is None else list(videos), + "video_metadata": None if video_metadata is None else list(video_metadata), + "return_tensors": return_tensors, + "do_sample_frames": do_sample_frames, + } + + prompt_ids = self.tokenizer.encode(text[0], add_special_tokens=False) + if images: + prompt_ids.extend([self.image_token_id] * len(images)) + if videos: + prompt_ids.extend([self.video_token_id] * len(videos)) + + input_ids = torch.tensor([prompt_ids], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + output = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + if images: + image_count = len(images) + output["pixel_values"] = torch.arange(image_count * 12, dtype=torch.float32).reshape(image_count, 3, 2, 2) + output["image_grid_thw"] = torch.tensor([[1, 2, 3]] * image_count, dtype=torch.long) + output["mm_token_type_ids"] = torch.ones_like(input_ids) + + if videos: + video_count = len(videos) + output["pixel_values_videos"] = torch.arange(video_count * 24, dtype=torch.float32).reshape( + video_count, 3, 2, 4 + ) + output["video_grid_thw"] = torch.tensor([[1, 3, 4]] * video_count, dtype=torch.long) + output["mm_token_type_ids"] = torch.ones_like(input_ids) + + return output + + def get_rope_index( + self, + *, + input_ids, + attention_mask, + image_grid_thw=None, + video_grid_thw=None, + mm_token_type_ids=None, + **kwargs, + ): + self.last_get_rope_index_call = { + "input_ids": input_ids.clone(), + "attention_mask": attention_mask.clone(), + "image_grid_thw": None if image_grid_thw is None else image_grid_thw.clone(), + "video_grid_thw": None if video_grid_thw is None else video_grid_thw.clone(), + "mm_token_type_ids": None if mm_token_type_ids is None else mm_token_type_ids.clone(), + } + seq_len = input_ids.shape[1] + base = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + vision_position_ids = torch.stack( + [ + base + 100, + base + 200, + base + 300, + ], + dim=0, + ).unsqueeze(1) + return vision_position_ids, None + + +async def fake_vision_info_extractor(messages, image_patch_size, config=None): + assert image_patch_size == 16 + images = [] + videos = [] + for message in messages: + content = message.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "image_url": + image_url = part.get("image_url", {}) + if isinstance(image_url, dict) and image_url.get("url"): + images.append(image_url["url"]) + elif part.get("type") == "video_url": + video_url = part.get("video_url", {}) + if isinstance(video_url, dict) and video_url.get("url"): + videos.append(video_url["url"]) + return images or None, videos or None + + +class SingleUseVisionInfoExtractor: + def __init__(self): + self.calls = 0 + + async def __call__(self, messages, image_patch_size, config=None): + self.calls += 1 + if self.calls > 1: + raise AssertionError("vision_info_extractor should not be called again on continuation") + return await fake_vision_info_extractor(messages, image_patch_size=image_patch_size, config=config) + + +class InspectingBackend: + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + payload = json.dumps( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + }, + sort_keys=True, + ) + token_ids = [ord(char) for char in payload] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class InspectingSequencedBackend: + def __init__(self, steps): + self.steps = list(steps) + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + step = self.steps.pop(0) + if step == "__inspect__": + text = json.dumps( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + }, + sort_keys=True, + ) + elif isinstance(step, Exception): + raise step + else: + text = step + + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class QueuedBackend: + def __init__(self, responses): + self._responses = list(responses) + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + text = self._responses.pop(0) + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class SlowBackend: + def __init__(self, response_text="SLOW", delay_s=1.5): + self.response_text = response_text + self.delay_s = delay_s + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + await asyncio.sleep(self.delay_s) + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RecordingLLMClient: + def __init__(self, response_text="OK"): + self.response_text = response_text + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None, **kwargs): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + "kwargs": dict(kwargs), + } + ) + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectToolsSamplingParamsBackend: + def __init__(self, response_text: str = "OK"): + self.response_text = response_text + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + if "tools" in sampling_params: + raise RuntimeError("tools leaked into sampling_params") + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectRequestEnvelopeBackend: + def __init__(self, response_text: str = "OK", expected_sampling_params: dict | None = None): + self.response_text = response_text + self.expected_sampling_params = expected_sampling_params + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + assert "messages" not in sampling_params + assert "model" not in sampling_params + assert "tools" not in sampling_params + if self.expected_sampling_params is None: + assert sampling_params["temperature"] == 0.25 + else: + assert sampling_params == self.expected_sampling_params + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class FailingBackend: + def __init__(self, error_message: str = "backend failure"): + self.error_message = error_message + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + } + ) + raise RuntimeError(self.error_message) + + +class SequencedBackend: + def __init__(self, steps): + self.steps = list(steps) + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + } + ) + step = self.steps.pop(0) + if isinstance(step, Exception): + raise step + token_ids = [ord(char) for char in step] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectConcurrentSessionBackend: + def __init__(self, responses, delay: float = 0.05): + self._responses = list(responses) + self._delay = delay + self._active_request_ids: set[str] = set() + self.call_windows = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + if request_id in self._active_request_ids: + raise RuntimeError(f"concurrent request for session {request_id}") + + self._active_request_ids.add(request_id) + started_at = asyncio.get_running_loop().time() + try: + await asyncio.sleep(self._delay) + text = self._responses.pop(0) + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + finally: + finished_at = asyncio.get_running_loop().time() + self.call_windows.append((request_id, started_at, finished_at)) + self._active_request_ids.remove(request_id) + + +@ray.remote +class TrackingGatewayActor: + def __init__(self, name: str): + self.name = name + self.sessions = {} + self.created = [] + self.finalized = [] + self.aborted = [] + self.waited = [] + + async def start(self): + return None + + async def shutdown(self): + return None + + async def create_session(self, session_id: str, metadata: dict | None = None): + handle = SessionHandle(session_id=session_id, base_url=f"http://{self.name}/{session_id}/v1") + self.sessions[session_id] = {"metadata": metadata or {}} + self.created.append(session_id) + return handle + + async def finalize_session(self, session_id: str): + self.finalized.append(session_id) + self.sessions.pop(session_id, None) + return [ + Trajectory( + prompt_ids=[1], + response_ids=[2], + response_mask=[1], + ) + ] + + async def abort_session(self, session_id: str): + self.aborted.append(session_id) + self.sessions.pop(session_id, None) + return None + + async def wait_for_completion(self, session_id: str, timeout: float | None = None): + self.waited.append((session_id, timeout)) + return None + + async def stats(self): + return { + "name": self.name, + "created": list(self.created), + "finalized": list(self.finalized), + "aborted": list(self.aborted), + "waited": list(self.waited), + } diff --git a/verl b/verl index 460ccf3..f6bb7a3 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 460ccf3c2f19ae537f7de53243c89df69d746f05 +Subproject commit f6bb7a3327c98942e366f40cf4cd796415e31919 From f56c03ee362d6a4e6f953597cdbc31e97c64c999 Mon Sep 17 00:00:00 2001 From: zackcxb Date: Sun, 24 May 2026 14:38:52 +0000 Subject: [PATCH 05/22] feat(trainer): add deepeyes gateway recipe Co-Authored-By: Claude Opus 4.6 --- .../agent_train/deepeyes_gateway/README.md | 72 +++++++ .../configs/deepeyes_gateway_grpo.yaml | 87 +++++++++ .../configs/image_zoom_in_tool_config.yaml | 26 +++ .../run_deepeyes_gateway_grpo.sh | 62 ++++++ uni_agent/recipes/__init__.py | 1 + .../recipes/deepeyes_gateway/__init__.py | 1 + .../recipes/deepeyes_gateway/agent_runner.py | 172 +++++++++++++++++ uni_agent/recipes/deepeyes_gateway/dataset.py | 127 +++++++++++++ uni_agent/recipes/deepeyes_gateway/reward.py | 179 ++++++++++++++++++ 9 files changed, 727 insertions(+) create mode 100644 examples/agent_train/deepeyes_gateway/README.md create mode 100644 examples/agent_train/deepeyes_gateway/configs/deepeyes_gateway_grpo.yaml create mode 100644 examples/agent_train/deepeyes_gateway/configs/image_zoom_in_tool_config.yaml create mode 100755 examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh create mode 100644 uni_agent/recipes/__init__.py create mode 100644 uni_agent/recipes/deepeyes_gateway/__init__.py create mode 100644 uni_agent/recipes/deepeyes_gateway/agent_runner.py create mode 100644 uni_agent/recipes/deepeyes_gateway/dataset.py create mode 100644 uni_agent/recipes/deepeyes_gateway/reward.py diff --git a/examples/agent_train/deepeyes_gateway/README.md b/examples/agent_train/deepeyes_gateway/README.md new file mode 100644 index 0000000..3d06033 --- /dev/null +++ b/examples/agent_train/deepeyes_gateway/README.md @@ -0,0 +1,72 @@ +# DeepEyes Gateway Training Example + +This example wires the DeepEyes multimodal tool-use recipe into the Uni-Agent +gateway framework path on `verl.trainer.main_ppo_sync`. + +## Layout + +- `uni_agent.recipes.deepeyes_gateway.agent_runner`: gateway-backed DeepEyes + tool loop. +- `uni_agent.recipes.deepeyes_gateway.dataset`: dataset adapter that emits + `raw_prompt`, `tools_kwargs`, and reward fields without local prompt + tokenization. +- `uni_agent.recipes.deepeyes_gateway.reward`: self-contained `compute_score` + wrapper for the DeepEyes LLM-as-a-judge reward. +- `configs/deepeyes_gateway_grpo.yaml`: recipe config using + `uni_agent.trainer.framework.entry.AgentFrameworkRolloutAdapter`. +- `configs/image_zoom_in_tool_config.yaml`: image zoom-in tool config. +- `run_deepeyes_gateway_grpo.sh`: example full-data launch script. + +## Prerequisites + +- Run from the Uni-Agent repository with the `verl` trainer dependencies + available. +- Launch an OpenAI-compatible judge service and set `LLM_AS_A_JUDGE_BASE`. +- Prepare a DeepEyes parquet dataset with image payloads. +- Reserve training GPUs separately from the judge GPU. + +Example judge service: + +```bash +CUDA_VISIBLE_DEVICES=7 \ +python3 -m vllm.entrypoints.openai.api_server \ + --model /path/to/judge-model \ + --host 127.0.0.1 \ + --port 18901 \ + --served-model-name qwen3-4b-judge \ + --dtype float16 \ + --trust-remote-code \ + --max-model-len 4096 \ + --gpu-memory-utilization 0.75 \ + --enforce-eager +``` + +## Launch + +```bash +bash examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh +``` + +Common overrides: + +```bash +MODEL_PATH=/path/to/policy-model \ +TRAIN_FILE=/path/to/train.parquet \ +VAL_FILE=/path/to/val.parquet \ +LLM_AS_A_JUDGE_BASE=http://127.0.0.1:18901/v1 \ +PROJECT_NAME=my_project \ +EXPERIMENT_NAME=my_run \ +TOTAL_TRAINING_STEPS=20 \ +bash examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh +``` + +The script resolves the config directory relative to its own location, then +launches from the repository root so `uni_agent.*` recipe imports are stable. + +## Notes + +- No parquet data files are included in this example. +- The image tool implementation is still loaded from `verl.tools` by the tool + config; the gateway framework adapter and recipe imports use `uni_agent.*`. +- Reward scoring returns `0.0` if the judge service or reward dependencies are + unavailable. diff --git a/examples/agent_train/deepeyes_gateway/configs/deepeyes_gateway_grpo.yaml b/examples/agent_train/deepeyes_gateway/configs/deepeyes_gateway_grpo.yaml new file mode 100644 index 0000000..84aa6f2 --- /dev/null +++ b/examples/agent_train/deepeyes_gateway/configs/deepeyes_gateway_grpo.yaml @@ -0,0 +1,87 @@ +hydra: + searchpath: + - pkg://verl.trainer.config + +defaults: + - ppo_trainer + - _self_ + +data: + max_prompt_length: 2048 + max_response_length: 2048 + train_batch_size: 8 + return_raw_chat: True + return_multi_modal_inputs: False + filter_overlong_prompts: False + custom_cls: + path: pkg://uni_agent.recipes.deepeyes_gateway.dataset + name: DeepEyesGatewayDataset + +algorithm: + adv_estimator: grpo + kl_ctrl: + kl_coef: 0.0 + +actor_rollout_ref: + hybrid_engine: True + model: + path: /data1/models/Qwen/Qwen3.5-4B + trust_remote_code: True + use_remove_padding: True + use_fused_kernels: True + custom_chat_template: "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{%- if tools %}{{- '<|im_start|>system\\n' }}{%- if messages[0]['role'] == 'system' %}{%- if messages[0]['content'] is string %}{{- messages[0]['content'] }}{%- else %}{{- messages[0]['content'][0]['text'] }}{%- endif %}{%- else %}{{- 'You are a helpful assistant.' }}{%- endif %}{{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}{%- for tool in tools %}{{- \"\\n\" }}{{- tool | tojson }}{%- endfor %}{{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}{% for message in messages %}{% if message['role'] != 'system' or loop.first == false %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endif %}{% endfor %}{%- else %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{%- elif message.role == \"assistant\" %}{{- '<|im_start|>' + message.role }}{%- if message.content %}{{- '\\n' + message.content }}{%- endif %}{%- for tool_call in message.tool_calls %}{%- if tool_call.function is defined %}{%- set tool_call = tool_call.function %}{%- endif %}{{- '\\n\\n{\"name\": \"' }}{{- tool_call.name }}{{- '\", \"arguments\": ' }}{{- tool_call.arguments | tojson }}{{- '}\\n' }}{%- endfor %}{{- '<|im_end|>\\n' }}{%- elif message.role == \"tool\" %}{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}{{- '<|im_start|>user' }}{%- endif %}{{- '\\n\\n' }}{% if message['content'] is string %}{{ message.content }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif content['type'] == 'text' or 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{- '\\n' }}{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}{{- '<|im_end|>\\n' }}{%- endif %}{%- endif %}{% endfor %}{%- endif %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + actor: + optim: + lr: 1e-6 + ppo_mini_batch_size: 8 + ppo_micro_batch_size_per_gpu: 1 + use_kl_loss: False + kl_loss_coef: 0.0 + kl_loss_type: low_var_kl + entropy_coeff: 0.0 + ulysses_sequence_parallel_size: 1 + checkpoint: + save_contents: ['model', 'hf_model', 'optimizer', 'extra'] + fsdp_config: + param_offload: True + optimizer_offload: True + rollout: + name: sglang + multi_turn: + format: hermes + n: 2 + log_prob_micro_batch_size_per_gpu: 1 + tensor_model_parallel_size: 1 + max_num_batched_tokens: 8192 + gpu_memory_utilization: 0.7 + enforce_eager: True + free_cache_engine: True + enable_chunked_prefill: True + agent: + agent_loop_manager_class: uni_agent.trainer.framework.entry.AgentFrameworkRolloutAdapter + custom: + agent_framework: + agent_runner_fqn: uni_agent.recipes.deepeyes_gateway.agent_runner.deepeyes_agent_runner + gateway_count: 8 + agent_runner_kwargs: + max_turns: 5 + tool_config_path: examples/agent_train/deepeyes_gateway/configs/image_zoom_in_tool_config.yaml + ref: + log_prob_micro_batch_size_per_gpu: 1 + fsdp_config: + param_offload: True + +trainer: + critic_warmup: 0 + logger: ['console'] + val_before_train: False + save_freq: -1 + test_freq: -1 + project_name: deepeyes_gateway_smoke + experiment_name: qwen35_4b_phase34_smoke + total_epochs: 1 + +reward: + custom_reward_function: + path: pkg://uni_agent.recipes.deepeyes_gateway.reward + name: compute_score diff --git a/examples/agent_train/deepeyes_gateway/configs/image_zoom_in_tool_config.yaml b/examples/agent_train/deepeyes_gateway/configs/image_zoom_in_tool_config.yaml new file mode 100644 index 0000000..b048c17 --- /dev/null +++ b/examples/agent_train/deepeyes_gateway/configs/image_zoom_in_tool_config.yaml @@ -0,0 +1,26 @@ +tools: + - class_name: "verl.tools.image_zoom_in_tool.ImageZoomInTool" + config: + num_workers: 256 + rate_limit: 256 + timeout: 60 + type: native + tool_schema: + type: "function" + function: + name: "image_zoom_in_tool" + description: "Zoom in on a specific region of an image by cropping it based on a bounding box (bbox) and an optional object label." + parameters: + type: "object" + properties: + bbox_2d: + type: "array" + items: + type: "number" + minItems: 4 + maxItems: 4 + description: "The bounding box of the region to zoom in, as [x1, y1, x2, y2], where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner." + label: + type: "string" + description: "The name or label of the object in the specified bounding box (optional)." + required: ["bbox_2d"] diff --git a/examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh b/examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh new file mode 100755 index 0000000..c5d1b70 --- /dev/null +++ b/examples/agent_train/deepeyes_gateway/run_deepeyes_gateway_grpo.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +set -xeuo pipefail + +SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +UNI_AGENT_REPO_ROOT=$(cd "${SCRIPT_DIR}/../../.." && pwd) +CONFIG_DIR="${UNI_AGENT_REPO_ROOT}/examples/agent_train/deepeyes_gateway/configs" +cd "${UNI_AGENT_REPO_ROOT}" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6}" +export VERL_FORCE_TQ_NESTED_READBACK="${VERL_FORCE_TQ_NESTED_READBACK:-1}" +export LLM_AS_A_JUDGE_BASE="${LLM_AS_A_JUDGE_BASE:-http://127.0.0.1:18901/v1}" +export WANDB_MODE="${WANDB_MODE:-offline}" +export NCCL_P2P_DISABLE="${NCCL_P2P_DISABLE:-1}" +export NCCL_SHM_DISABLE="${NCCL_SHM_DISABLE:-1}" +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" +export PYTHONUNBUFFERED="${PYTHONUNBUFFERED:-1}" +export HYDRA_FULL_ERROR="${HYDRA_FULL_ERROR:-1}" + +PROJECT_NAME="${PROJECT_NAME:-deepeyes_gateway_sync_real_data}" +EXPERIMENT_NAME="${EXPERIMENT_NAME:-qwen35_4b_deepeyes_gateway_grpo}" +MODEL_PATH="${MODEL_PATH:-/data1/models/Qwen/Qwen3.5-4B}" +TRAIN_FILE="${TRAIN_FILE:-/data1/datasets/deepeyes/data/data_0.1.2_visual_toolbox_v2.parquet}" +VAL_FILE="${VAL_FILE:-${TRAIN_FILE}}" +TOTAL_TRAINING_STEPS="${TOTAL_TRAINING_STEPS:-50}" + +python3 -m verl.trainer.main_ppo_sync \ + --config-path="${CONFIG_DIR}" \ + --config-name=deepeyes_gateway_grpo \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + data.train_files="${TRAIN_FILE}" \ + "data.val_files=[${VAL_FILE}]" \ + data.train_batch_size=14 \ + data.max_prompt_length=4096 \ + data.max_response_length=1024 \ + trainer.total_training_steps="${TOTAL_TRAINING_STEPS}" \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=7 \ + trainer.nnodes=1 \ + 'trainer.logger=[console,wandb,tensorboard]' \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name="${EXPERIMENT_NAME}" \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.model.use_remove_padding=True \ + '+actor_rollout_ref.model.override_config.attn_implementation=eager' \ + actor_rollout_ref.actor.ppo_mini_batch_size=14 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.response_length=1024 \ + actor_rollout_ref.rollout.max_model_len=8192 \ + actor_rollout_ref.rollout.max_num_seqs=4 \ + actor_rollout_ref.rollout.max_num_batched_tokens=16384 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.55 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.dtype=float16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.custom.agent_framework.gateway_count=7 \ + actor_rollout_ref.rollout.custom.agent_framework.tool_config_path="${CONFIG_DIR}/image_zoom_in_tool_config.yaml" \ + actor_rollout_ref.rollout.custom.agent_framework.agent_runner_kwargs.max_turns=5 diff --git a/uni_agent/recipes/__init__.py b/uni_agent/recipes/__init__.py new file mode 100644 index 0000000..a2a8918 --- /dev/null +++ b/uni_agent/recipes/__init__.py @@ -0,0 +1 @@ +"""Importable Uni-Agent training recipes.""" diff --git a/uni_agent/recipes/deepeyes_gateway/__init__.py b/uni_agent/recipes/deepeyes_gateway/__init__.py new file mode 100644 index 0000000..297372c --- /dev/null +++ b/uni_agent/recipes/deepeyes_gateway/__init__.py @@ -0,0 +1 @@ +"""DeepEyes gateway recipe.""" diff --git a/uni_agent/recipes/deepeyes_gateway/agent_runner.py b/uni_agent/recipes/deepeyes_gateway/agent_runner.py new file mode 100644 index 0000000..454fc0f --- /dev/null +++ b/uni_agent/recipes/deepeyes_gateway/agent_runner.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import base64 +import json +from io import BytesIO +from typing import TYPE_CHECKING, Any + +import httpx +from PIL import Image + +if TYPE_CHECKING: + from uni_agent.trainer.framework.types import SessionHandle + from verl.tools.schemas import ToolResponse +else: + SessionHandle = Any + ToolResponse = Any + + +IMAGE_ZOOM_IN_TOOL_NAME = "image_zoom_in_tool" +GATEWAY_REQUEST_TIMEOUT_SECONDS = 300.0 + + +def _json_ready(value: Any) -> Any: + if isinstance(value, Image.Image): + buffer = BytesIO() + value.convert("RGB").save(buffer, format="PNG") + encoded = base64.b64encode(buffer.getvalue()).decode("ascii") + return f"data:image/png;base64,{encoded}" + if isinstance(value, bytes): + encoded = base64.b64encode(value).decode("ascii") + return f"data:image/png;base64,{encoded}" + if isinstance(value, dict): + if "bytes" in value: + return _json_ready(value["bytes"]) + return {key: _json_ready(item) for key, item in value.items()} + if isinstance(value, list): + return [_json_ready(item) for item in value] + if isinstance(value, tuple): + return [_json_ready(item) for item in value] + return value + + +def _tool_kwargs_for_name(tools_kwargs: dict | None) -> dict[str, Any]: + if not isinstance(tools_kwargs, dict): + return {} + + maybe_tool_kwargs = tools_kwargs.get(IMAGE_ZOOM_IN_TOOL_NAME) + return maybe_tool_kwargs if isinstance(maybe_tool_kwargs, dict) else {} + + +def _parse_tool_arguments(arguments: object) -> dict[str, Any]: + if isinstance(arguments, dict): + return arguments + if not isinstance(arguments, str) or not arguments: + return {} + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + return {} + return parsed if isinstance(parsed, dict) else {} + + +def _assistant_message_from_response(payload: dict[str, Any]) -> dict[str, Any]: + choices = payload.get("choices") + if not choices: + raise ValueError("chat completion response did not include choices") + + message = choices[0].get("message") + if not isinstance(message, dict): + raise ValueError("chat completion response choice did not include a message") + return message + + +def _tool_response_to_openai_tool_message(*, tool_call_id: str, tool_response: ToolResponse) -> dict[str, Any]: + content: list[dict[str, Any]] = [] + + if tool_response.video: + raise NotImplementedError("ToolResponse video content is not supported by the DeepEyes gateway recipe") + + if tool_response.text is not None: + content.append({"type": "text", "text": str(tool_response.text)}) + for image in tool_response.image or []: + content.append({"type": "image", "image": _json_ready(image)}) + if not content: + content.append({"type": "text", "text": ""}) + + return { + "role": "tool", + "tool_call_id": tool_call_id, + "content": content, + } + + +def _select_tool(tool_config: list[Any] | None): + if not tool_config: + raise ValueError("tool_config is required for deepeyes_agent_runner") + + for tool in tool_config: + if getattr(tool, "name", None) == IMAGE_ZOOM_IN_TOOL_NAME: + return tool + raise ValueError(f"tool_config must include {IMAGE_ZOOM_IN_TOOL_NAME}") + + +async def deepeyes_agent_runner( + *, + raw_prompt: list[dict], + session: SessionHandle, + sample_index: int, + tools_kwargs: dict | None = None, + tool_config: list[Any] | None = None, + max_turns: int = 5, + **kwargs, +) -> None: + """Run a DeepEyes multi-turn image zoom-in tool loop against the gateway.""" + del sample_index, kwargs + if session.base_url is None: + raise ValueError("session.base_url is required for deepeyes_agent_runner") + + image_tool = _select_tool(tool_config) + image_tool_kwargs = _tool_kwargs_for_name(tools_kwargs) + create_kwargs = dict(image_tool_kwargs.get("create_kwargs") or {}) + if "image" not in create_kwargs and "image" in image_tool_kwargs: + create_kwargs["image"] = image_tool_kwargs["image"] + execute_kwargs = dict(image_tool_kwargs.get("execute_kwargs") or {}) + release_kwargs = dict(image_tool_kwargs.get("release_kwargs") or {}) + + tool_instance_id: str | None = None + messages = _json_ready(list(raw_prompt)) + + try: + tool_instance_id, _ = await image_tool.create( + instance_id=f"{session.session_id}-image_zoom_in_tool", + create_kwargs=create_kwargs, + ) + tool_schema = image_tool.get_openai_tool_schema().model_dump(exclude_none=True) + + async with httpx.AsyncClient(timeout=GATEWAY_REQUEST_TIMEOUT_SECONDS) as client: + for turn_index in range(max(0, max_turns)): + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "deepeyes", + "messages": messages, + "tools": [tool_schema], + }, + ) + response.raise_for_status() + + assistant_message = _assistant_message_from_response(response.json()) + messages.append(dict(assistant_message)) + + tool_calls = assistant_message.get("tool_calls") or [] + if not tool_calls or turn_index + 1 >= max_turns: + break + + for tool_call in tool_calls: + function = tool_call.get("function") or {} + parameters = _parse_tool_arguments(function.get("arguments")) + tool_response, _, _ = await image_tool.execute( + tool_instance_id, + parameters=parameters, + **execute_kwargs, + ) + messages.append( + _tool_response_to_openai_tool_message( + tool_call_id=tool_call.get("id", ""), + tool_response=tool_response, + ) + ) + finally: + if tool_instance_id is not None: + await image_tool.release(tool_instance_id, **release_kwargs) diff --git a/uni_agent/recipes/deepeyes_gateway/dataset.py b/uni_agent/recipes/deepeyes_gateway/dataset.py new file mode 100644 index 0000000..f3854d3 --- /dev/null +++ b/uni_agent/recipes/deepeyes_gateway/dataset.py @@ -0,0 +1,127 @@ +"""Minimal dataset for the DeepEyes gateway recipe. + +Produces ``raw_prompt`` and reward-related fields only. +It does not perform tokenization or vision processing. +""" + +from __future__ import annotations + +import copy +import io +import logging +import re + +import torch +from PIL import Image + +from verl.utils.dataset.rl_dataset import RLHFDataset + +logger = logging.getLogger(__name__) + + +class DeepEyesGatewayDataset(RLHFDataset): + """Thin dataset that leaves prompt encoding and vision extraction to the gateway.""" + + def _build_messages(self, example: dict, key: str) -> tuple[list[dict], object | None]: + messages = copy.deepcopy(example[key]) + images = example.get(self.image_key, None) or [] + videos = example.get(self.video_key, None) or [] + first_image = None + image_offset = 0 + video_offset = 0 + + for message in messages: + content = message.get("content") + if isinstance(content, list): + normalized = [] + for part in content: + normalized_part = _normalize_content_part(part) + if ( + first_image is None + and isinstance(normalized_part, dict) + and normalized_part.get("type") in {"image", "image_url"} + ): + first_image = _decode_image_payload(normalized_part.get("image", normalized_part)) + normalized_part = dict(normalized_part) + normalized_part["image"] = first_image + normalized.append(normalized_part) + message["content"] = normalized + continue + if not isinstance(content, str) or ("" not in content and "