diff --git a/docs/advance/agent_framework.rst b/docs/advance/agent_framework.rst new file mode 100644 index 00000000000..59e294881a1 --- /dev/null +++ b/docs/advance/agent_framework.rst @@ -0,0 +1,184 @@ +Agent Framework +=============== + +Last updated: 05/21/2026. + +.. versionadded:: 0.8.0 + [status: alpha] + +.. warning:: + Agent Framework is ready for use, but the API may change in future releases. + +Agent Framework is a session-based orchestration layer for agentic RL training. +It runs user-defined agent logic (tool calls, multi-turn reasoning, environment +interaction) inside gateway-managed sessions, collects token-level trajectories, +and writes them to the TransferQueue for sync GRPO/PPO training. + +Agent Framework coexists with the legacy :doc:`Agent Loop ` path. +Both produce the same trainer-consumable output; Agent Framework adds +session-level isolation, an OpenAI-compatible HTTP interface per session, and +structured reward dispatch. + + +Overview +-------- + +**Design goals:** + +- Black-box agent runner: any async function that speaks OpenAI chat completions +- Session isolation: each rollout sample gets its own HTTP endpoint +- Reward flexibility: inline scoring via ``reward_loop_worker_handles`` or + framework-level ``reward.custom_reward_function`` bridge +- Subclass extensibility: ``AgentFramework`` is abstract; ship your own + +**Non-goals:** + +- Defining tool semantics (that is the agent runner's job) +- Replacing Agent Loop for single-turn or simple multi-turn use cases + + +System Architecture +------------------- + +.. code-block:: text + + ┌─────────────────────────────────────────────────────────────┐ + │ Trainer (main_ppo_sync.py) │ + │ └── AgentFrameworkRolloutAdapter.generate_sequences(batch) │ + └────────────────────────────┬────────────────────────────────┘ + │ TensorDict prompts + ▼ + ┌─────────────────────────────────────────────────────────────┐ + │ OpenAICompatibleAgentFramework │ + │ ├── create sessions (1 per sample × rollout.n) │ + │ ├── launch agent_runner coroutines │ + │ ├── wait for completion / finalize │ + │ ├── score trajectories (reward dispatch) │ + │ └── write to TransferQueue │ + └────────────────────────────┬────────────────────────────────┘ + │ session lifecycle + ▼ + ┌─────────────────────────────────────────────────────────────┐ + │ GatewayServingRuntime │ + │ ├── GatewayManager (round-robin session routing) │ + │ └── GatewayActor ×N (HTTP /v1/chat/completions per session)│ + │ └── backend: LLMServerClient.generate(token-level) │ + └─────────────────────────────────────────────────────────────┘ + + +System Components +----------------- + ++--------------------------------------+-----------------------------------------------------------------------+ +| Component | Role | ++======================================+=======================================================================+ +| ``AgentFramework`` | Abstract base class. Subclasses implement ``from_config`` and | +| | ``generate_sequences``. | ++--------------------------------------+-----------------------------------------------------------------------+ +| ``OpenAICompatibleAgentFramework`` | Default subclass. Manages sessions, runs agent_runner coroutines, | +| | dispatches reward scoring, writes TQ output. | ++--------------------------------------+-----------------------------------------------------------------------+ +| ``GatewayServingRuntime`` | Owns gateway actor lifecycle. ``gateway_count=0`` degrades to a thin | +| | LLM client passthrough (no HTTP layer). | ++--------------------------------------+-----------------------------------------------------------------------+ +| ``GatewayActor`` | Ray actor running an HTTP server. Exposes ``/v1/chat/completions`` | +| | to the agent runner and collects token-level trajectories. | ++--------------------------------------+-----------------------------------------------------------------------+ +| ``AgentFrameworkRolloutAdapter`` | Trainer-facing glue in ``entry.py``. Satisfies the | +| | ``agent_loop_manager_class`` extension point contract. | ++--------------------------------------+-----------------------------------------------------------------------+ + + +Writing a Custom Agent Runner +----------------------------- + +An agent runner is any async callable with this signature: + +.. code:: python + + async def my_agent_runner( + *, + raw_prompt: list[dict], # OpenAI-format messages + session: SessionHandle, # .base_url is the per-session endpoint + sample_index: int, + **kwargs, # extra fields from dataset non_tensor columns + ) -> None: + """Run agent logic against the gateway session.""" + import httpx + + async with httpx.AsyncClient() as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "any", "messages": raw_prompt}, + ) + # ... tool calls, multi-turn loops, etc. + + # Signal that the session is complete (triggers trajectory finalization) + await client.post(session.base_url.removesuffix("/v1") + "/complete") + +The framework handles session creation, trajectory collection, reward scoring, +and TQ writes. The agent runner only needs to make HTTP requests and signal +completion. + + +Configuration Reference +----------------------- + +All fields live under ``actor_rollout_ref.rollout.custom.agent_framework``: + +.. code:: yaml + + actor_rollout_ref: + rollout: + agent: + agent_loop_manager_class: verl.agent.framework.entry.AgentFrameworkRolloutAdapter + custom: + agent_framework: + # Required: FQN of your agent runner function + agent_runner_fqn: my_package.my_module.my_agent_runner + + # Number of gateway actors (HTTP servers). 0 = no gateway, passthrough only. + gateway_count: 8 + + # Optional: kwargs passed to agent_runner via functools.partial + agent_runner_kwargs: + max_turns: 5 + + # Optional: tool config yaml for tool initialization + tool_config_path: path/to/tool_config.yaml + + # Optional: timeout for session completion (seconds). null = no wait. + completion_timeout_seconds: 30 + + # Optional: max concurrent sessions (0 = unlimited) + max_concurrent_sessions: 0 + + # Optional: FQN of framework subclass (default: OpenAICompatibleAgentFramework) + framework_class_fqn: verl.agent.framework.framework.OpenAICompatibleAgentFramework + + +Usage Example +------------- + +**Full training run** (requires GPU cluster + judge model): + +.. code:: bash + + bash examples/grpo_trainer/run_deepeyes_gateway_grpo.sh + +**Minimal CPU-only tutorial** (no GPU required): + +.. code:: bash + + python examples/tutorial/agent_framework_get_started/minimal_e2e.py + +The tutorial demonstrates the runtime → framework → generate_sequences path +with a fake rollout server, real gateway actor, and real framework orchestration. + + +See Also +-------- + +- :doc:`Agent Loop ` — legacy single/multi-turn rollout path +- :doc:`Agentic RL overview <../start/agentic_rl>` — high-level introduction +- :doc:`Reward Loop ` — reward worker integration diff --git a/docs/index.rst b/docs/index.rst index 6d9714acbe1..fda17bc07d9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -144,6 +144,7 @@ verl is fast with: advance/rollout_trace.rst advance/rollout_skip.rst advance/agent_loop + advance/agent_framework advance/reward_loop data/transfer_queue.md advance/grafana_prometheus.md diff --git a/docs/start/agentic_rl.rst b/docs/start/agentic_rl.rst index 46ca53d447f..0d42e052364 100644 --- a/docs/start/agentic_rl.rst +++ b/docs/start/agentic_rl.rst @@ -109,18 +109,28 @@ Follow :doc:`Rollout trace<../advance/rollout_trace>` to known more about trace Agent Framework --------------- +For the session-based Agent Framework (``verl.agent.framework``), which provides +per-session HTTP isolation and structured reward dispatch for agentic RL, see +:doc:`Agent Framework <../advance/agent_framework>`. + +The LangGraph-based agent path below is a separate recipe that uses LangChain +abstractions on top of the same inference backend. + +LangGraph Agent +~~~~~~~~~~~~~~~ + System Architecture -~~~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^^^ .. image:: https://github.com/eric-haibin-lin/verl-community/blob/main/docs/langgraph_agent.png?raw=true System Components -~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^ +--------------------------+-----------------------------------------------------------------------------------------------+ | Component | Role | +==========================+===============================================================================================+ -| ChatModel | LLM object of LangChain, used to adapt to the “generate” api provided by LLMServerClient | +| ChatModel | LLM object of LangChain, used to adapt to the "generate" api provided by LLMServerClient | +--------------------------+-----------------------------------------------------------------------------------------------+ | ReactAgentLoop | Agent adaptation layer, which by default supports a naive LangGraph Agentic. | | | New classes can be derived to support user-defined Agents, and the run function needs to be | diff --git a/examples/grpo_trainer/run_deepeyes_gateway_grpo.sh b/examples/grpo_trainer/run_deepeyes_gateway_grpo.sh new file mode 100755 index 00000000000..9630ea401fd --- /dev/null +++ b/examples/grpo_trainer/run_deepeyes_gateway_grpo.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# GRPO | Agent Framework + Gateway | DeepEyes multimodal tool-use +# +# This script trains a vision-language model with agentic tool-use rollouts +# using the Agent Framework + Gateway stack. Each rollout sample gets its own +# HTTP session where the agent runner can make multi-turn chat completions +# requests and invoke tools (e.g., image zoom). +# +# Prerequisites: +# - A judge/reward model serving at LLM_AS_A_JUDGE_BASE (default: localhost:18901) +# - DeepEyes dataset parquet file at TRAIN_FILE +# - Model checkpoint at MODEL_PATH (or HuggingFace model ID) +# +# See docs/advance/agent_framework.rst for architecture details. + +set -xeuo pipefail + +########################### user-adjustable ########################### +MODEL_PATH=${MODEL_PATH:-/data1/models/Qwen/Qwen3.5-4B} +TRAIN_FILE=${TRAIN_FILE:-/data1/datasets/deepeyes/data/data_0.1.2_visual_toolbox_v2.parquet} +VAL_FILE=${VAL_FILE:-${TRAIN_FILE}} + +NGPUS_PER_NODE=${NGPUS_PER_NODE:-7} +TOTAL_TRAINING_STEPS=${TOTAL_TRAINING_STEPS:-50} +TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-14} + +# Agent Framework specific +GATEWAY_COUNT=${GATEWAY_COUNT:-7} +MAX_TURNS=${MAX_TURNS:-5} +COMPLETION_TIMEOUT=${COMPLETION_TIMEOUT:-} + +# Reward judge endpoint +LLM_AS_A_JUDGE_BASE=${LLM_AS_A_JUDGE_BASE:-http://127.0.0.1:18901/v1} + +PROJECT_NAME=${PROJECT_NAME:-deepeyes_gateway_grpo} +EXPERIMENT_NAME=${EXPERIMENT_NAME:-qwen35_4b_deepeyes_gateway_grpo} +########################### end user-adjustable ########################### + +VERL_REPO_ROOT=$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd) +cd "${VERL_REPO_ROOT}" + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6}" +export VERL_FORCE_TQ_NESTED_READBACK="${VERL_FORCE_TQ_NESTED_READBACK:-1}" +export LLM_AS_A_JUDGE_BASE +export WANDB_MODE="${WANDB_MODE:-offline}" +export NCCL_P2P_DISABLE="${NCCL_P2P_DISABLE:-1}" +export NCCL_SHM_DISABLE="${NCCL_SHM_DISABLE:-1}" +export PYTHONUNBUFFERED=1 +export HYDRA_FULL_ERROR=1 + +python3 -m verl.trainer.main_ppo_sync \ + --config-path="${VERL_REPO_ROOT}/recipe/deepeyes_with_gateway/configs" \ + --config-name=deepeyes_gateway_grpo \ + data.train_files="${TRAIN_FILE}" \ + "data.val_files=[${VAL_FILE}]" \ + data.train_batch_size="${TRAIN_BATCH_SIZE}" \ + data.max_prompt_length=4096 \ + data.max_response_length=1024 \ + trainer.total_training_steps="${TOTAL_TRAINING_STEPS}" \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes=1 \ + 'trainer.logger=[console,wandb,tensorboard]' \ + trainer.project_name="${PROJECT_NAME}" \ + trainer.experiment_name="${EXPERIMENT_NAME}" \ + trainer.save_freq=-1 \ + trainer.test_freq=-1 \ + actor_rollout_ref.model.use_fused_kernels=False \ + actor_rollout_ref.model.use_remove_padding=True \ + '+actor_rollout_ref.model.override_config.attn_implementation=eager' \ + actor_rollout_ref.actor.ppo_mini_batch_size="${TRAIN_BATCH_SIZE}" \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=4 \ + actor_rollout_ref.rollout.response_length=1024 \ + actor_rollout_ref.rollout.max_model_len=8192 \ + actor_rollout_ref.rollout.max_num_seqs=4 \ + actor_rollout_ref.rollout.max_num_batched_tokens=16384 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.55 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.dtype=float16 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.custom.agent_framework.gateway_count="${GATEWAY_COUNT}" \ + actor_rollout_ref.rollout.custom.agent_framework.agent_runner_kwargs.max_turns="${MAX_TURNS}" diff --git a/examples/tutorial/agent_framework_get_started/README.md b/examples/tutorial/agent_framework_get_started/README.md new file mode 100644 index 00000000000..32be822e1ac --- /dev/null +++ b/examples/tutorial/agent_framework_get_started/README.md @@ -0,0 +1,60 @@ +# Agent Framework Get Started + +Minimal runnable entry for the `verl.agent.framework` + `verl.agent.gateway` +stack (PR #6299). + +It demonstrates three boundaries: + +1. The caller creates `GatewayServingRuntime` externally (entry.py does this + in production; here we do it manually for visibility). +2. `GatewayServingRuntime` is injected into `OpenAICompatibleAgentFramework`. +3. The framework is exercised with one `generate_sequences(...)` call on a + minimal `TensorDict`. + +Inside the script, the agent side is split into two layers: + +- `agent_runner(...)`: the framework-facing adapter that receives a + `SessionHandle` and extracts `session.base_url` +- `run_mock_agent(base_url, raw_prompt)`: an external-agent-style function + that only knows an OpenAI-compatible backend URL plus prompt messages + +That keeps the gateway-specific lifecycle shim visible, while showing how a +normal agent can treat the gateway as its backend URL. + +This is intentionally **not** a trainer integration example. It uses: + +- a tiny fake rollout server actor (Ray remote), +- the real `GlobalRequestLoadBalancer`, +- the real `GatewayServingRuntime` with `gateway_count=1`, +- the real `GatewayActor` (HTTP server), +- the real `OpenAICompatibleAgentFramework`. + +The example runs CPU-only and requires no GPU. `reward_loop_worker_handles=None` +means reward scoring is skipped; `rm_scores` is zero-filled in the TQ output +(matching the framework's default behavior when no reward workers are available). + +## Run + +```bash +python examples/tutorial/agent_framework_get_started/minimal_e2e.py +``` + +The script will: + +1. Start Ray (local mode). +2. Start one fake rollout server actor. +3. Create a `GlobalRequestLoadBalancer`. +4. Create a `GatewayServingRuntime` with one gateway actor. +5. Construct `OpenAICompatibleAgentFramework` with the runtime. +6. Send one chat-completions request through the gateway. +7. Call `generate_sequences(...)` which writes to a fake TransferQueue. +8. Print a JSON summary of the output. +9. Shut down the runtime and Ray. + +## Architecture Reference + +For the full architecture, configuration reference, and production usage, see +[docs/advance/agent_framework.rst](../../../docs/advance/agent_framework.rst). + +For a full training run with GPU cluster, see +[examples/grpo_trainer/run_deepeyes_gateway_grpo.sh](../../grpo_trainer/run_deepeyes_gateway_grpo.sh). diff --git a/examples/tutorial/agent_framework_get_started/minimal_e2e.py b/examples/tutorial/agent_framework_get_started/minimal_e2e.py new file mode 100644 index 00000000000..2a3f92e0956 --- /dev/null +++ b/examples/tutorial/agent_framework_get_started/minimal_e2e.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import json + +import httpx +import ray + +from verl.agent.framework import framework as framework_module +from verl.agent.framework.framework import OpenAICompatibleAgentFramework +from verl.agent.gateway.runtime import GatewayServingRuntime +from verl.utils import tensordict_utils as tu +from verl.workers.rollout.llm_server import GlobalRequestLoadBalancer, LLMServerClient +from verl.workers.rollout.replica import TokenOutput + + +class MinimalTokenizer: + """Small tokenizer stub for the gateway tutorial example.""" + + def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): + del tools, kwargs + parts = [] + for message in messages: + parts.append("{}:{}\n".format(message["role"], self._normalize_content(message.get("content", "")))) + if add_generation_prompt: + parts.append("assistant:") + text = "".join(parts) + if tokenize: + return [ord(char) for char in text] + return text + + def decode(self, token_ids, skip_special_tokens=True): + del skip_special_tokens + return "".join(chr(token_id) for token_id in token_ids) + + def encode(self, text, add_special_tokens=False): + del add_special_tokens + return [ord(char) for char in text] + + def _normalize_content(self, content): + if isinstance(content, list): + return "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in content) + if content is None: + return "" + return str(content) + + +@ray.remote +class MinimalRolloutServer: + def __init__(self, response_text: str = "MINIMAL"): + self.response_text = response_text + self.calls = [] + + async def generate( + self, + request_id, + *, + prompt_ids, + sampling_params, + image_data=None, + video_data=None, + ): + del image_data, video_data + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + } + ) + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + def get_calls(self): + return list(self.calls) + + +class MinimalTransferQueue: + def __init__(self): + self.puts = [] + self.batch_puts = [] + + async def async_kv_put(self, *, key, partition_id, tag): + self.puts.append({"key": key, "partition_id": partition_id, "tag": dict(tag)}) + + async def async_kv_batch_put(self, *, keys, fields, tags, partition_id): + self.batch_puts.append( + { + "keys": list(keys), + "fields": fields, + "tags": [dict(tag) for tag in tags], + "partition_id": partition_id, + } + ) + + +class MinimalReplayBuffer: + def __init__(self): + self.adds = [] + + def add(self, partition_id, items): + self.adds.append({"partition_id": partition_id, "items": dict(items)}) + + +def _build_prompts(): + return tu.get_tensordict( + tensor_dict={ + "raw_prompt": [[{"role": "user", "content": "Say MINIMAL"}]], + "uid": ["sample-0"], + }, + non_tensor_dict={"global_steps": 1}, + ) + + +async def run_mock_agent(*, base_url: str, raw_prompt) -> tuple[str, dict[str, object]]: + """Mimic an external agent that only knows an OpenAI-compatible backend URL.""" + + async with httpx.AsyncClient(timeout=5.0) as client: + chat_response = await client.post( + f"{base_url}/chat/completions", + json={ + "model": "minimal-model", + "messages": raw_prompt, + "temperature": 0.0, + }, + ) + chat_response.raise_for_status() + response_payload = chat_response.json() + + reward_info = {"score": 0.5, "label": "minimal-example"} + complete_response = await client.post( + base_url.removesuffix("/v1") + "/complete", + json={"reward_info": reward_info}, + ) + complete_response.raise_for_status() + + return response_payload["choices"][0]["message"]["content"], reward_info + + +async def run_example() -> dict[str, object]: + """Run the minimal end-to-end path through runtime -> framework -> generate_sequences.""" + + started_ray_here = False + runtime: GatewayServingRuntime | None = None + gateway_response_text = "" + fake_tq = MinimalTransferQueue() + replay_buffer = MinimalReplayBuffer() + original_tq = framework_module.tq + + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, include_dashboard=False) + started_ray_here = True + + try: + framework_module.tq = fake_tq + rollout_server = MinimalRolloutServer.remote("MINIMAL") + load_balancer = GlobalRequestLoadBalancer.remote({"server-0": rollout_server}) + llm_client = LLMServerClient( + config=None, + servers={"server-0": rollout_server}, + load_balancer_handle=load_balancer, + ) + + runtime = GatewayServingRuntime( + llm_client=llm_client, + gateway_count=1, + gateway_actor_kwargs={ + "tokenizer": MinimalTokenizer(), + }, + ) + + async def agent_runner(*, raw_prompt, session, sample_index): + nonlocal gateway_response_text + + assert session.base_url is not None + gateway_response_text, _reward_info = await run_mock_agent( + base_url=session.base_url, + raw_prompt=raw_prompt, + ) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=None, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + wait_for_completion_after_agent_run=True, + completion_timeout=5.0, + ) + + await framework.generate_sequences(_build_prompts()) + rollout_calls = ray.get(rollout_server.get_calls.remote()) + fields = fake_tq.batch_puts[0]["fields"] + + # Everything returned here is example evidence for reviewers/tests, + # not a suggested public API shape for framework consumers. + return { + "runtime_class": type(runtime).__name__, + "framework_class": type(framework).__name__, + "agent_runner_contract": "session_to_base_url_adapter", + "gateway_response_text": gateway_response_text, + "replay_buffer_adds": replay_buffer.adds, + "tq_keys": fake_tq.batch_puts[0]["keys"], + "finished_tags": fake_tq.puts, + "uid_values": tu.get(fields, "uid"), + # Tutorial intentionally omits reward computation + # (reward_loop_worker_handles=None); rm_scores is zero-filled. + "has_rm_scores": "rm_scores" in fields.keys(), + "rollout_calls": rollout_calls, + } + finally: + framework_module.tq = original_tq + if runtime is not None: + await runtime.shutdown() + if started_ray_here and ray.is_initialized(): + ray.shutdown() + + +def main() -> None: + import asyncio + + print(json.dumps(asyncio.run(run_example()), indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/recipe b/recipe index e7f889574b8..fba9b21dbb5 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit e7f889574b8301cc0f0fc1d57c6d67f31ffeb689 +Subproject commit fba9b21dbb599ac1a3431eef93900fca96543bdc diff --git a/tests/agent/README.md b/tests/agent/README.md new file mode 100644 index 00000000000..d037c69bd7e --- /dev/null +++ b/tests/agent/README.md @@ -0,0 +1,160 @@ +# Agent tests + +This directory contains CPU-only tests for the `verl.agent` framework and +gateway packages. The suite focuses on behavior that reviewers need to trust: +TransferQueue output schema, multimodal postprocessing, OpenAI-compatible +gateway sessions, session routing, and runtime ownership. + +## Naming and CI routing + +Executable test modules use the `*_on_cpu.py` suffix so VERL's CPU unit-test +workflow can discover them without pulling GPU-only rollout infrastructure. +Run the current suite with: + +```bash +pytest tests/agent/ -q +``` + +## Coverage inventory + +### Framework + +- `framework/test_generate_sequences_on_cpu.py` + - `test_generate_sequences_writes_tq_schema_for_each_session` + - Verifies `generate_sequences()` runs multiple sessions per prompt and + writes the TransferQueue key, tag, tensor, nested-tensor, and non-tensor + schema consumed by sync training. + - `test_generate_sequences_keeps_successful_sessions_when_one_session_fails` + - Verifies a failed session is aborted and reported without dropping + successful sessions for the same prompt. + - `test_generate_sequences_marks_prompt_failure_when_all_sessions_fail` + - Verifies all-failed prompts write a failure status and no trajectory + batch. + - `test_generate_sequences_omits_rm_scores_when_reward_fn_is_none` + - Verifies reward-free generation omits `rm_scores` instead of inventing + scores. + - `test_generate_sequences_keeps_other_prompts_when_prompt_task_raises` + - Verifies an unexpected prompt-level exception is counted as one failed + uid while other prompt tasks still contribute results. +- `framework/test_multi_modal_postprocess_on_cpu.py` + - `test_compute_multi_modal_inputs_returns_empty_dict_without_processor` + - Verifies text-only execution produces no multimodal processor inputs. + - `test_compute_multi_modal_inputs_returns_image_tensors_and_images_seqlens` + - Verifies processor image outputs drop duplicate text tensors and add + `images_seqlens`. + - `test_compute_position_ids_returns_text_shape_without_processor` + - Verifies text-only position ids keep the standard 2-D shape. + - `test_compute_position_ids_returns_multimodal_shape_with_processor` + - Verifies processor-aware position ids include text and vision channels + and derive `mm_token_type_ids` from image/video token ids. + +### Gateway + +- `gateway/test_gateway_actor_on_cpu.py` + - `test_gateway_actor_abort_session_does_not_wait_for_backend_generate` + - Verifies `abort_session()` can complete while a backend generation is + still in flight. + - `test_normalize_request_context_preserves_multimodal_blocks_for_later_extraction` + - Verifies request normalization preserves multimodal content, `tools`, + `tool_calls`, and `tool_call_id` fields needed by later gateway stages. + - `test_gateway_actor_forwards_image_data_on_initial_multimodal_request` + - Verifies the initial multimodal request extracts image data, forwards it + to the backend, and materializes it into trajectory metadata. + - `test_gateway_actor_complete_wait_and_finalize` + - Verifies `/complete`, `wait_for_completion()`, and `finalize_session()` + cooperate on the happy path and attach reward info. + - `test_gateway_actor_continuation_reuses_accumulated_media_context` + - Verifies continuation turns reuse existing session media without + re-extracting the original image. + - `test_gateway_actor_multimodal_reference_change_splits_trajectory` + - Verifies changing multimodal request context starts a new trajectory. + - `test_gateway_actor_continuation_with_tool_returned_image_appends_media` + - Verifies a tool-returned image is appended to accumulated media and + encoded into the incremental prompt. + - `test_gateway_actor_prefix_mismatch_splits_trajectories` + - Verifies message-history prefix mismatch materializes the active + trajectory and starts the next one. + - `test_gateway_actor_tool_context_change_splits_trajectory` + - Verifies tool-schema changes split trajectories. + - `test_gateway_actor_does_not_forward_tools_in_sampling_params` + - Verifies `tools` do not leak into backend sampling params. + - `test_gateway_actor_strips_request_envelope_but_keeps_sampling_params` + - Verifies backend sampling params come from gateway base params plus + whitelisted request overrides, not request-envelope fields. + - `test_gateway_actor_ignores_non_whitelisted_request_sampling_params` + - Verifies non-whitelisted request sampling fields are ignored. + - `test_gateway_actor_continuation_preserves_prompt_and_generation_masks` + - Verifies continuation context uses mask `0` and new model output uses + mask `1`. + - `test_gateway_actor_tool_argument_json_equivalence_does_not_split_after_valid_continuation` + - Verifies JSON-equivalent tool-call argument strings do not split a valid + continuation. + - `test_message_prefix_falls_back_to_raw_tool_argument_value_comparison_when_arguments_are_invalid_json` + - Verifies invalid tool-call argument strings compare by raw value. + - `test_gateway_actor_serializes_same_session_concurrent_requests` + - Verifies concurrent requests for one session are serialized before they + reach the backend. + - `test_gateway_actor_rejects_chat_after_complete` + - Verifies chat requests after completion return HTTP 409. + - `test_gateway_actor_finalizes_without_complete` + - Verifies finalization can materialize an active trajectory even when + `/complete` was never called. + - `test_gateway_actor_rejects_malformed_requests_with_bad_request` + - Verifies representative malformed OpenAI request shapes return HTTP 400. + - `test_gateway_actor_backend_failure_does_not_commit_partial_state` + - Verifies backend failure returns HTTP 500 without committing a partial + trajectory. + - `test_gateway_actor_backend_failure_after_tool_mismatch_does_not_split` + - Verifies a failed split attempt leaves the previous active trajectory + intact. + - `test_gateway_actor_tool_call_decode_returns_openai_format` + - Verifies tool-parser output is decoded into OpenAI-compatible + `tool_calls` and can be continued with a tool-result turn. +- `gateway/test_gateway_manager_on_cpu.py` + - `test_gateway_manager_routes_sessions_stickily` + - Verifies created sessions remain routed to their owning gateway through + chat and finalization. + - `test_gateway_manager_uses_least_active_sessions_routing` + - Verifies new sessions are assigned to the gateway with the fewest active + sessions and counters are decremented on finalization. + - `test_gateway_manager_wait_for_completion_delegates_to_session_owner` + - Verifies completion waits are delegated to the gateway that owns the + session. +- `gateway/test_session_runtime_on_cpu.py` + - `test_gateway_serving_runtime_owns_gateway_lifecycle_and_session_runtime` + - Verifies `GatewayServingRuntime` can own gateway actors, expose the + session runtime, delegate backend generation through itself, and shut down. + - `test_gateway_serving_runtime_delegates_generate_to_llm_client` + - Verifies generate-only mode delegates directly to the supplied LLM client + when no gateway actors are configured. + +## Mocking boundaries + +- Real code under test: `verl.agent.framework.*` and `verl.agent.gateway.*`. + Gateway actor tests also use real Ray actors, FastAPI routing, and HTTPX + requests against local in-process servers. +- External systems intentionally excluded: real `LLMServer`, model weights, + GPU rollout engines, recipe submodules, external smoke tests, and trainer + integration jobs. +- `tests/agent/support.py` provides shared fakes: + - Tokenization and processors: `FakeTokenizer`, `FakeProcessor`. + - Multimodal extraction: `fake_vision_info_extractor`, + `SingleUseVisionInfoExtractor`. + - Backend behavior: `InspectingBackend`, `InspectingSequencedBackend`, + `QueuedBackend`, `SlowBackend`, `RecordingLLMClient`, + `RejectToolsSamplingParamsBackend`, `RejectRequestEnvelopeBackend`, + `FailingBackend`, `SequencedBackend`, `RejectConcurrentSessionBackend`. + - Manager/runtime actors: `TrackingGatewayActor`. + +## Intentional gaps + +- Backend-fatal layering is intentionally not covered here. Risk analysis + classifies it as P0 follow-up work because current framework/gateway code + does not yet distinguish backend-fatal failures from recoverable session + failures. +- `abort_session` backend propagation is intentionally not covered here. The + current tests only verify gateway-side session cleanup/non-blocking behavior; + request-level backend abort is P1 follow-up work. +- Framework timeout and health behavior is intentionally not covered here. + Current code has only optional completion waiting and no health/heartbeat + contract, so tests for that would describe future code rather than this PR. diff --git a/tests/agent/framework/test_generate_sequences_on_cpu.py b/tests/agent/framework/test_generate_sequences_on_cpu.py new file mode 100644 index 00000000000..7eae488cac3 --- /dev/null +++ b/tests/agent/framework/test_generate_sequences_on_cpu.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import pytest + +from verl.agent.framework.types import SessionHandle, Trajectory +from verl.utils import tensordict_utils as tu + + +class _FakeTransferQueue: + def __init__(self): + self.puts = [] + self.batch_puts = [] + + async def async_kv_put(self, *, key, partition_id, tag): + self.puts.append({"key": key, "partition_id": partition_id, "tag": dict(tag)}) + + async def async_kv_batch_put(self, *, keys, fields, tags, partition_id): + self.batch_puts.append( + { + "keys": list(keys), + "fields": fields, + "tags": [dict(tag) for tag in tags], + "partition_id": partition_id, + } + ) + + +class _FakeReplayBuffer: + def __init__(self): + self.adds = [] + + def add(self, partition_id, items): + self.adds.append({"partition_id": partition_id, "items": dict(items)}) + + +class _FakeSessionRuntime: + """Fake runtime that matches session IDs by prefix (``session-{sample}-{session}``) + to support the real uuid-suffixed IDs produced by the framework.""" + + def __init__(self, finalized_by_session_prefix: dict[str, list[Trajectory]]): + self._finalized_by_prefix = finalized_by_session_prefix + self.created_sessions = [] + self.finalized_sessions = [] + self.aborted_sessions = [] + + def _lookup(self, session_id: str) -> list[Trajectory]: + for prefix, trajectories in self._finalized_by_prefix.items(): + if session_id.startswith(prefix): + return trajectories + raise KeyError(f"No prefix match for session_id={session_id}") + + async def create_session(self, session_id: str, **kwargs): + self.created_sessions.append(session_id) + return SessionHandle(session_id=session_id, base_url=f"http://fake/{session_id}/v1") + + async def finalize_session(self, session_id: str): + self.finalized_sessions.append(session_id) + return self._lookup(session_id) + + async def abort_session(self, session_id: str) -> None: + self.aborted_sessions.append(session_id) + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + return None + + +def _build_prompts(count: int = 2, *, global_steps: int = 7, validate: bool = False): + non_tensor_dict = {"global_steps": global_steps} + if validate: + non_tensor_dict["validate"] = True + return tu.get_tensordict( + tensor_dict={ + "raw_prompt": [[{"role": "user", "content": f"sample {i}"}] for i in range(count)], + "uid": [f"uid-{i}" for i in range(count)], + "data_source": ["deepeyes"] * count, + "reward_model": [{"ground_truth": f"answer-{i}"} for i in range(count)], + "extra_info": [{"index": i} for i in range(count)], + "tools_kwargs": [{"tool": i} for i in range(count)], + "agent_name": ["deepeyes"] * count, + }, + non_tensor_dict=non_tensor_dict, + ) + + +def _trajectory( + *, + prompt_ids: list[int] | None = None, + response_ids: list[int] | None = None, + response_logprobs: list[float] | None = None, + num_turns: int = 2, +): + prompt_ids = prompt_ids or [10, 11] + response_ids = response_ids or [20, 21] + return Trajectory( + prompt_ids=prompt_ids, + response_ids=response_ids, + response_mask=[1] * len(response_ids), + response_logprobs=response_logprobs, + reward_score=None, + num_turns=num_turns, + multi_modal_data={"images": ["raw-image-should-not-be-written"]}, + ) + + +def _install_fake_score(monkeypatch, *, score_from_sample_fields=None, default_score=1.0): + """Replace OpenAICompatibleAgentFramework._score_trajectories with a fake. + + Mirrors the production "score-last + broadcast" behavior: returns the same + (score, extra_info) for every trajectory in the session. The score is + derived from sample_fields if a callable is provided; otherwise default_score. + """ + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + async def fake_score(self, trajectories, sample_fields): + if score_from_sample_fields is not None: + score = float(score_from_sample_fields(sample_fields)) + else: + score = float(default_score) + return [(score, {})] * len(trajectories) + + monkeypatch.setattr(OpenAICompatibleAgentFramework, "_score_trajectories", fake_score) + + +@pytest.mark.asyncio +async def test_generate_sequences_writes_tq_schema_for_each_session(monkeypatch): + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + + runtime = _FakeSessionRuntime( + { + "session-0-0": [_trajectory(response_logprobs=[-0.1, -0.2])], + "session-0-1": [_trajectory(response_logprobs=[-0.3, -0.4])], + "session-1-0": [_trajectory(response_logprobs=[-0.5, -0.6])], + "session-1-1": [_trajectory(response_logprobs=[-0.7, -0.8])], + } + ) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + assert raw_prompt == [{"role": "user", "content": f"sample {sample_index}"}] + assert tools_kwargs == {"tool": sample_index} + + # Score derived from sample_fields["extra_info"]["index"] + 0.25 (same as legacy lambda) + _install_fake_score( + monkeypatch, + score_from_sample_fields=lambda sf: sf["extra_info"]["index"] + 0.25, + ) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 2, "val_kwargs": {"n": 2}}, + ) + + + await framework.generate_sequences(_build_prompts(global_steps=7)) + + assert replay_buffer.adds == [ + { + "partition_id": "train", + "items": { + "uid-0": {"global_steps": 7, "status": "running"}, + "uid-1": {"global_steps": 7, "status": "running"}, + }, + } + ] + assert fake_tq.batch_puts[0]["keys"] == ["uid-0_0_0"] + assert fake_tq.batch_puts[1]["keys"] == ["uid-0_1_0"] + assert fake_tq.batch_puts[2]["keys"] == ["uid-1_0_0"] + assert fake_tq.batch_puts[3]["keys"] == ["uid-1_1_0"] + assert fake_tq.puts == [ + {"key": "uid-0", "partition_id": "train", "tag": {"status": "finished"}}, + {"key": "uid-1", "partition_id": "train", "tag": {"status": "finished"}}, + ] + + first = fake_tq.batch_puts[0] + fields = first["fields"] + assert first["partition_id"] == "train" + assert first["tags"] == [ + {"global_steps": 7, "status": "success", "prompt_len": 2, "response_len": 2, "seq_len": 4} + ] + assert fields["input_ids"].is_nested + assert fields["response_mask"].is_nested + assert fields["position_ids"].is_nested + assert fields["prompts"][0].tolist() == [10, 11] + assert fields["responses"][0].tolist() == [20, 21] + assert fields["response_mask"][0].tolist() == [1, 1] + assert fields["loss_mask"][0].tolist() == [1, 1] + assert fields["input_ids"][0].tolist() == [10, 11, 20, 21] + assert fields["attention_mask"][0].tolist() == [1, 1, 1, 1] + assert fields["position_ids"][0].tolist() == [0, 1, 2, 3] + assert fields["rollout_log_probs"][0].tolist() == pytest.approx([-0.1, -0.2]) + assert fields["rm_scores"][0].tolist() == [0.0, 0.25] + assert tu.get(fields, "multi_modal_inputs") == [{}] + assert tu.get(fields, "uid") == ["uid-0"] + assert tu.get(fields, "raw_prompt") == [[{"role": "user", "content": "sample 0"}]] + assert tu.get(fields, "data_source") == ["deepeyes"] + assert tu.get(fields, "reward_model") == [{"ground_truth": "answer-0"}] + assert tu.get(fields, "extra_info") == [{"index": 0}] + assert tu.get(fields, "tools_kwargs") == [{"tool": 0}] + assert tu.get(fields, "agent_name") == ["deepeyes"] + assert tu.get(fields, "session_id") == [0] + assert tu.get(fields, "global_steps") == [7] + assert fields["num_turns"].tolist() == [2] + assert "multi_modal_data" not in fields.keys() + + +@pytest.mark.asyncio +async def test_generate_sequences_keeps_successful_sessions_when_one_session_fails(monkeypatch): + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime( + { + "session-0-0": [_trajectory()], + "session-0-1": [_trajectory()], + } + ) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + if session.session_id.startswith("session-0-1-"): + raise RuntimeError("gateway failed once") + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 2, "val_kwargs": {"n": 2}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=8)) + + assert replay_buffer.adds == [ + {"partition_id": "train", "items": {"uid-0": {"global_steps": 8, "status": "running"}}} + ] + assert fake_tq.batch_puts[0]["keys"] == ["uid-0_0_0"] + assert fake_tq.puts == [{"key": "uid-0", "partition_id": "train", "tag": {"status": "finished"}}] + assert len(runtime.aborted_sessions) == 1 + assert runtime.aborted_sessions[0].startswith("session-0-1-") + + +@pytest.mark.asyncio +async def test_generate_sequences_marks_prompt_failure_when_all_sessions_fail(monkeypatch): + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime({"session-0-0": [], "session-0-1": []}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + raise RuntimeError(f"failed {session.session_id}") + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 2}}, + ) + + + with pytest.raises(RuntimeError, match="All rollouts failed at global_steps=9"): + await framework.generate_sequences(_build_prompts(count=1, global_steps=9, validate=True)) + + assert replay_buffer.adds == [ + {"partition_id": "val", "items": {"uid-0": {"global_steps": 9, "status": "running"}}} + ] + assert fake_tq.batch_puts == [] + assert fake_tq.puts == [{"key": "uid-0", "partition_id": "val", "tag": {"status": "failure"}}] + + +@pytest.mark.asyncio +async def test_generate_sequences_zero_fills_rm_scores_when_no_reward_handles(monkeypatch): + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime({"session-0-0": [_trajectory()]}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + return None + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=10)) + + # rm_scores is always written (zero-filled when no reward) so the trainer's + # KVBatchMeta select_fields never hits a missing field across the batch. + rm_scores = fake_tq.batch_puts[0]["fields"]["rm_scores"] + assert rm_scores[0].tolist() == [0.0, 0.0] + + +@pytest.mark.asyncio +async def test_generate_sequences_keeps_other_prompts_when_prompt_task_raises(monkeypatch, caplog): + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + replay_buffer = _FakeReplayBuffer() + runtime = _FakeSessionRuntime( + { + "session-1-0": [_trajectory()], + } + ) + + _install_fake_score(monkeypatch, default_score=1.0) + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=lambda **_: None, + reward_loop_worker_handles=["sentinel"], + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + async def fake_run_prompt_sessions_to_tq(*, sample_index, **kwargs): + if sample_index == 0: + raise RuntimeError("prompt 0 exploded") + return { + "num_success_sessions": 1, + "num_failed_sessions": 0, + "num_success_outputs": 1, + "num_failed_uids": 0, + "failure_reasons": [], + } + + monkeypatch.setattr(framework, "_run_prompt_sessions_to_tq", fake_run_prompt_sessions_to_tq) + + caplog.set_level("INFO") + await framework.generate_sequences(_build_prompts(count=2, global_steps=11)) + + assert replay_buffer.adds == [ + { + "partition_id": "train", + "items": { + "uid-0": {"global_steps": 11, "status": "running"}, + "uid-1": {"global_steps": 11, "status": "running"}, + }, + } + ] + assert "num_failed_uids=1" in caplog.text + assert "prompt 0 exploded" in caplog.text + + +@pytest.mark.asyncio +async def test_generate_sequences_zero_fills_rollout_log_probs_when_missing(monkeypatch): + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + # Trajectory without response_logprobs (e.g. backend returned no logprobs). + runtime = _FakeSessionRuntime({"session-0-0": [_trajectory(response_logprobs=None)]}) + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + return None + + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + + await framework.generate_sequences(_build_prompts(count=1, global_steps=10)) + + # rollout_log_probs is zero-filled rather than omitted so the trainer's + # bypass-mode select_fields(["rollout_log_probs"]) never KeyErrors. + rollout_log_probs = fake_tq.batch_puts[0]["fields"]["rollout_log_probs"] + assert rollout_log_probs[0].tolist() == [0.0, 0.0] + + +@pytest.mark.asyncio +async def test_max_concurrent_sessions_caps_in_flight_sessions(monkeypatch): + import asyncio + + from verl.agent.framework import framework as framework_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + + fake_tq = _FakeTransferQueue() + replay_buffer = _FakeReplayBuffer() + monkeypatch.setattr(framework_module, "tq", fake_tq) + runtime = _FakeSessionRuntime( + {f"session-{i}-0": [_trajectory()] for i in range(4)} + ) + + in_flight = 0 + max_observed = 0 + + async def agent_runner(*, raw_prompt, session, sample_index, tools_kwargs, **kwargs): + nonlocal in_flight, max_observed + in_flight += 1 + max_observed = max(max_observed, in_flight) + await asyncio.sleep(0.01) + in_flight -= 1 + return None + + _install_fake_score(monkeypatch, default_score=1.0) + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=agent_runner, + replay_buffer=replay_buffer, + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + max_concurrent_sessions=2, + ) + + + await framework.generate_sequences(_build_prompts(count=4, global_steps=10)) + + assert max_observed <= 2 + + +# --------------------------------------------------------------------------- +# _score_trajectories method-level tests +# --------------------------------------------------------------------------- + + +@pytest.fixture +def ray_runtime(): + import ray + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_score_trajectories_dispatches_only_final_trajectory_and_broadcasts(ray_runtime): + """_score_trajectories scores trajectories[-1] only, broadcasts to all (matches AgentLoopWorkerTQ).""" + import ray as ray_module + from verl.agent.framework.framework import OpenAICompatibleAgentFramework + from verl.agent.framework.types import Trajectory + + @ray_module.remote + class _StubWorker: + def __init__(self): + self.calls = [] + + def compute_score(self, data): + self.calls.append(data) + return {"reward_score": 0.42, "reward_extra_info": {"acc": 1.0, "format": 0.8}} + + def get_call_count(self): + return len(self.calls) + + worker = _StubWorker.remote() + + runtime = _FakeSessionRuntime({}) # not used in this test + framework = OpenAICompatibleAgentFramework( + session_runtime=runtime, + agent_runner=lambda **_: None, + reward_loop_worker_handles=[worker], + replay_buffer=_FakeReplayBuffer(), + rollout_config={"n": 1, "val_kwargs": {"n": 1}}, + ) + + trajectories = [ + Trajectory(prompt_ids=[1, 2], response_ids=[3, 4], response_mask=[1, 1], num_turns=1), + Trajectory(prompt_ids=[5, 6], response_ids=[7, 8], response_mask=[1, 1], num_turns=2), + Trajectory(prompt_ids=[9, 10], response_ids=[11, 12], response_mask=[1, 1], num_turns=3), + ] + sample_fields = {"data_source": "test", "raw_prompt": [{"role": "user", "content": "hi"}]} + annotations = await framework._score_trajectories(trajectories, sample_fields) + + # Score-last + broadcast: 3 trajectories, but only 1 worker call + assert ray_module.get(worker.get_call_count.remote()) == 1 + # All 3 trajectories get the same score and extra_info + assert annotations == [ + (0.42, {"acc": 1.0, "format": 0.8}), + (0.42, {"acc": 1.0, "format": 0.8}), + (0.42, {"acc": 1.0, "format": 0.8}), + ] diff --git a/tests/agent/framework/test_multi_modal_postprocess_on_cpu.py b/tests/agent/framework/test_multi_modal_postprocess_on_cpu.py new file mode 100644 index 00000000000..934161faa4c --- /dev/null +++ b/tests/agent/framework/test_multi_modal_postprocess_on_cpu.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import torch + +from tests.agent.support import FakeProcessor + + +def test_compute_multi_modal_inputs_returns_empty_dict_without_processor(): + from verl.agent.framework.multi_modal_postprocess import compute_multi_modal_inputs + + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + + assert compute_multi_modal_inputs(None, input_ids, {"images": ["image://a.png"]}) == {} + + +def test_compute_multi_modal_inputs_returns_image_tensors_and_images_seqlens(): + from verl.agent.framework.multi_modal_postprocess import compute_multi_modal_inputs + + processor = FakeProcessor() + input_ids = torch.tensor([[11, processor.image_token_id, 12]], dtype=torch.long) + + multi_modal_inputs = compute_multi_modal_inputs( + processor, + input_ids, + {"images": ["image://a.png"]}, + ) + + assert "input_ids" not in multi_modal_inputs + assert "attention_mask" not in multi_modal_inputs + assert tuple(multi_modal_inputs["pixel_values"].shape) == (1, 3, 2, 2) + assert multi_modal_inputs["image_grid_thw"].tolist() == [[1, 2, 3]] + assert multi_modal_inputs["images_seqlens"].tolist() == [6] + assert "mm_token_type_ids" in multi_modal_inputs + + +def test_compute_position_ids_returns_text_shape_without_processor(): + from verl.agent.framework.multi_modal_postprocess import compute_position_ids + + input_ids = torch.tensor([[7, 8, 9, 10]], dtype=torch.long) + attention_mask = torch.tensor([[0, 1, 1, 1]], dtype=torch.long) + + position_ids = compute_position_ids(None, input_ids, attention_mask, {}) + + assert tuple(position_ids.shape) == (1, 4) + assert position_ids.tolist() == [[0, 0, 1, 2]] + + +def test_compute_position_ids_returns_multimodal_shape_with_processor(): + from verl.agent.framework.multi_modal_postprocess import compute_position_ids + + processor = FakeProcessor() + input_ids = torch.tensor( + [[11, processor.image_token_id, processor.video_token_id, 12]], + dtype=torch.long, + ) + attention_mask = torch.ones_like(input_ids) + multi_modal_inputs = { + "image_grid_thw": torch.tensor([[1, 2, 3]], dtype=torch.long), + "video_grid_thw": torch.tensor([[1, 3, 4]], dtype=torch.long), + "mm_token_type_ids": torch.ones_like(input_ids), + } + + position_ids = compute_position_ids(processor, input_ids, attention_mask, multi_modal_inputs) + + assert tuple(position_ids.shape) == (1, 4, 4) + assert position_ids[0, 0].tolist() == [0, 1, 2, 3] + assert processor.last_get_rope_index_call["mm_token_type_ids"].tolist() == [[0, 1, 2, 0]] diff --git a/tests/agent/gateway/test_gateway_actor_on_cpu.py b/tests/agent/gateway/test_gateway_actor_on_cpu.py new file mode 100644 index 00000000000..5e2754a89f2 --- /dev/null +++ b/tests/agent/gateway/test_gateway_actor_on_cpu.py @@ -0,0 +1,1019 @@ +import asyncio +import copy +import json + +import httpx +import pytest +import ray + +from tests.agent.support import ( + FailingBackend, + FakeProcessor, + FakeTokenizer, + InspectingBackend, + InspectingSequencedBackend, + QueuedBackend, + RejectConcurrentSessionBackend, + RejectRequestEnvelopeBackend, + RejectToolsSamplingParamsBackend, + SequencedBackend, + SingleUseVisionInfoExtractor, + SlowBackend, + fake_vision_info_extractor, +) + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_actor_abort_session_does_not_wait_for_backend_generate(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=SlowBackend(delay_s=1.5), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-abort-during-generate")) + + async with httpx.AsyncClient(timeout=5.0) as client: + request_task = asyncio.create_task( + client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "slow path"}]}, + ) + ) + await asyncio.sleep(0.1) + + abort_ref = actor.abort_session.remote("session-abort-during-generate") + await asyncio.wait_for(asyncio.wrap_future(abort_ref.future()), timeout=0.3) + + request_task.cancel() + try: + await request_task + except (asyncio.CancelledError, httpx.HTTPError): + pass + + ray.get(actor.shutdown.remote()) + + +def test_normalize_request_context_preserves_multimodal_blocks_for_later_extraction(): + from verl.agent.gateway.gateway import _normalize_request_context + + context = _normalize_request_context( + { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "look"}, + {"type": "image_url", "image_url": {"url": "file://image.png"}}, + ], + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": \"weather\"}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call-1", + "content": [{"type": "text", "text": "sunny"}], + }, + ], + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + } + ) + + assert context["tools"][0]["function"]["name"] == "search" + assert context["messages"][0]["content"][1]["type"] == "image_url" + assert context["messages"][1]["tool_calls"][0]["id"] == "call-1" + assert context["messages"][2]["tool_call_id"] == "call-1" + + +@pytest.mark.asyncio +async def test_gateway_actor_forwards_image_data_on_initial_multimodal_request(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor, _normalize_request_context + + processor = FakeProcessor() + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=processor, + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-mm-initial")) + payload = { + "model": "dummy-model", + "temperature": 0.25, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe this image"}, + ], + } + ], + } + + normalized = _normalize_request_context(payload) + raw_prompt = processor.apply_chat_template( + normalized["messages"], + tokenize=False, + add_generation_prompt=True, + tools=normalized["tools"], + ) + expected_prompt_ids = processor( + text=[raw_prompt], + images=["image://a.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json=payload, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-initial")) + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + backend_request = json.loads(response.json()["choices"][0]["message"]["content"]) + assert backend_request["image_data"] == ["image://a.png"] + assert backend_request["video_data"] is None + assert backend_request["prompt_ids"] == expected_prompt_ids + assert backend_request["sampling_params"] == {"temperature": 0.25} + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == {"images": ["image://a.png"]} + + +@pytest.mark.asyncio +async def test_gateway_actor_complete_wait_and_finalize(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["ANSWER: A"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-0")) + wait_ref = actor.wait_for_completion.remote("session-0", timeout=2.0) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "Pick label A"}], + }, + ) + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "ANSWER: A" + + complete = await client.post( + f"{session.base_url.removesuffix('/v1')}/complete", + json={"reward_info": {"score": 1.0, "label": "A"}}, + ) + assert complete.status_code == 200 + + ray.get(wait_ref) + trajectories = ray.get(actor.finalize_session.remote("session-0")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {"score": 1.0, "label": "A"} + assert trajectories[0].response_ids + assert all(mask == 1 for mask in trajectories[0].response_mask) + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_reuses_accumulated_media_context(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=FakeProcessor(), + vision_info_extractor=SingleUseVisionInfoExtractor(), + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-continuation")) + + initial_message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe this image"}, + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [initial_message], + }, + ) + assert first.status_code == 200 + assistant_message = first.json()["choices"][0]["message"] + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + initial_message, + assistant_message, + {"role": "user", "content": "follow up"}, + ], + }, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-continuation")) + ray.get(actor.shutdown.remote()) + + assert second.status_code == 200 + first_call = json.loads(first.json()["choices"][0]["message"]["content"]) + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert first_call["image_data"] == ["image://a.png"] + assert second_call["image_data"] == ["image://a.png"] + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == {"images": ["image://a.png"]} + + +@pytest.mark.asyncio +async def test_gateway_actor_multimodal_reference_change_splits_trajectory(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=FakeProcessor(), + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingBackend(), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-split")) + + first_payload = { + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "describe image a"}, + ], + } + ], + } + second_payload = { + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://b.png"}}, + {"type": "text", "text": "describe image b"}, + ], + } + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post(f"{session.base_url}/chat/completions", json=first_payload) + second = await client.post(f"{session.base_url}/chat/completions", json=second_payload) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-split")) + ray.get(actor.shutdown.remote()) + + assert first.status_code == 200 + assert second.status_code == 200 + first_call = json.loads(first.json()["choices"][0]["message"]["content"]) + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert first_call["image_data"] == ["image://a.png"] + assert second_call["image_data"] == ["image://b.png"] + assert len(trajectories) == 2 + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_with_tool_returned_image_appends_media(ray_runtime): + from verl.utils.chat_template import apply_chat_template, initialize_system_prompt + from verl.agent.gateway.gateway import GatewayActor + + processor = FakeProcessor() + tool_call_text = '\n{"name": "search", "arguments": {"query": "crop"}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + processor=processor, + vision_info_extractor=fake_vision_info_extractor, + backend=InspectingSequencedBackend([tool_call_text, "__inspect__"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-mm-tool-image")) + + tools = [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}] + initial_message = { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "image://a.png"}}, + {"type": "text", "text": "find a crop"}, + ], + } + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": tools, + "messages": [initial_message], + }, + ) + assert first.status_code == 200 + assistant_message = first.json()["choices"][0]["message"] + tool_message = { + "role": "tool", + "tool_call_id": assistant_message["tool_calls"][0]["id"], + "content": [ + {"type": "image_url", "image_url": {"url": "image://tool-b.png"}}, + {"type": "text", "text": "zoomed crop"}, + ], + } + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": tools, + "messages": [initial_message, assistant_message, tool_message], + }, + ) + + trajectories = ray.get(actor.finalize_session.remote("session-mm-tool-image")) + ray.get(actor.shutdown.remote()) + + assert second.status_code == 200 + second_call = json.loads(second.json()["choices"][0]["message"]["content"]) + assert second_call["image_data"] == ["image://a.png", "image://tool-b.png"] + assert len(trajectories) == 1 + assert trajectories[0].multi_modal_data == { + "images": ["image://a.png", "image://tool-b.png"], + } + + initial_raw_prompt = apply_chat_template( + processor, + [initial_message], + tools=tools, + tokenize=False, + add_generation_prompt=True, + ) + initial_prompt_ids = processor( + text=[initial_raw_prompt], + images=["image://a.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + + incremental_raw_prompt = apply_chat_template( + processor, + [tool_message], + tokenize=False, + add_generation_prompt=True, + ) + incremental_prompt_ids = processor( + text=[incremental_raw_prompt], + images=["image://tool-b.png"], + videos=None, + return_tensors="pt", + do_sample_frames=False, + )["input_ids"][0].tolist() + system_prompt = initialize_system_prompt(processor) + expected_incremental_ids = incremental_prompt_ids[len(system_prompt) :] + expected_prompt_ids = initial_prompt_ids + [ord(char) for char in tool_call_text] + expected_incremental_ids + assert second_call["prompt_ids"] == expected_prompt_ids + + +@pytest.mark.asyncio +async def test_gateway_actor_prefix_mismatch_splits_trajectories(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-1")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "replacement context"}], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-1")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 2 + assert trajectories[0].prompt_ids != trajectories[1].prompt_ids + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_context_change_splits_trajectory(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + + session = ray.get(actor.create_session.remote("session-tools")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-tools")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 2 + + +@pytest.mark.asyncio +async def test_gateway_actor_does_not_forward_tools_in_sampling_params(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectToolsSamplingParamsBackend("SAFE"), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-no-tools-sampling")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_strips_request_envelope_but_keeps_sampling_params(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectRequestEnvelopeBackend( + "SAFE", + expected_sampling_params={"temperature": 0.25, "top_p": 0.8, "max_tokens": 128}, + ), + base_sampling_params={"temperature": 0.1, "top_p": 0.8, "max_tokens": 64}, + allowed_request_sampling_param_keys={"temperature", "max_tokens"}, + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-envelope-boundary")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "temperature": 0.25, + "max_tokens": 128, + "presence_penalty": 1.5, + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_ignores_non_whitelisted_request_sampling_params(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectRequestEnvelopeBackend( + "SAFE", + expected_sampling_params={"temperature": 0.1, "top_p": 0.9}, + ), + base_sampling_params={"temperature": 0.1, "top_p": 0.9}, + allowed_request_sampling_param_keys={"temperature"}, + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-non-whitelist")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "presence_penalty": 1.5, + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_gateway_actor_continuation_preserves_prompt_and_generation_masks(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["FIRST", "SECOND"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-continuation-mask")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + { + "role": "user", + "content": "first turn", + } + ], + }, + ) + assert first.status_code == 200 + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-continuation-mask")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert 0 in trajectories[0].response_mask + assert trajectories[0].response_mask[-len("SECOND") :] == [1] * len("SECOND") + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_argument_json_equivalence_does_not_split_after_valid_continuation(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + tool_call_text = '\n{"name": "search", "arguments": {"b": 2, "a": 1}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=QueuedBackend([tool_call_text, "SECOND", "THIRD"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-tool-arg-drift")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "what is the weather?"}], + }, + ) + assert first.status_code == 200 + assistant_tool_message = first.json()["choices"][0]["message"] + + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + assistant_tool_message, + {"role": "tool", "tool_call_id": assistant_tool_message["tool_calls"][0]["id"], "content": "sunny"}, + ], + }, + ) + assert second.status_code == 200 + + drifted_tool_message = copy.deepcopy(assistant_tool_message) + drifted_tool_message["tool_calls"][0]["function"]["arguments"] = json.dumps({"a": 1, "b": 2}) + third = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + drifted_tool_message, + {"role": "tool", "tool_call_id": assistant_tool_message["tool_calls"][0]["id"], "content": "sunny"}, + {"role": "assistant", "content": "SECOND"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert third.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-tool-arg-drift")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert 0 in trajectories[0].response_mask + assert trajectories[0].response_mask[-len("THIRD") :] == [1] * len("THIRD") + + +def test_message_prefix_falls_back_to_raw_tool_argument_value_comparison_when_arguments_are_invalid_json(): + from verl.agent.gateway.gateway import _is_message_prefix + + prefix = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": weather}"}, + } + ], + } + ] + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "search", "arguments": "{\"query\": sunny}"}, + } + ], + } + ] + + assert _is_message_prefix(prefix, messages) is False + + +@pytest.mark.asyncio +async def test_gateway_actor_serializes_same_session_concurrent_requests(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=RejectConcurrentSessionBackend(["FIRST", "SECOND"]), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-concurrent")) + + async with httpx.AsyncClient(timeout=5.0) as client: + async def send_request(): + return await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "messages": [{"role": "user", "content": "same session prompt"}], + }, + ) + + first, second = await asyncio.gather(send_request(), send_request()) + + trajectories = ray.get(actor.finalize_session.remote("session-concurrent")) + ray.get(actor.shutdown.remote()) + + assert first.status_code == 200 + assert second.status_code == 200 + assert len(trajectories) == 2 + assert trajectories[0].response_ids == [ord(char) for char in "FIRST"] + assert trajectories[1].response_ids == [ord(char) for char in "SECOND"] + assert trajectories[0].response_mask == [1] * len("FIRST") + assert trajectories[1].response_mask == [1] * len("SECOND") + + +@pytest.mark.asyncio +async def test_gateway_actor_rejects_chat_after_complete(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-completed-chat")) + ray.get(actor.complete_session.remote("session-completed-chat")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "after complete"}]}, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 409 + + +@pytest.mark.asyncio +async def test_gateway_actor_finalizes_without_complete(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-finalize-without-complete")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "finish directly"}]}, + ) + assert response.status_code == 200 + + trajectories = ray.get(actor.finalize_session.remote("session-finalize-without-complete")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {} + + +@pytest.mark.parametrize( + ("payload", "detail_fragment"), + [ + ({"model": "dummy-model", "messages": []}, "messages must be non-empty"), + ( + {"model": "dummy-model", "messages": [{"role": "user", "name": 123, "content": "hello"}]}, + "message.name must be a string", + ), + ( + {"model": "dummy-model", "messages": [{"role": "user", "content": 123}]}, + "Unsupported content type", + ), + ( + { + "model": "dummy-model", + "messages": [{"role": "assistant", "content": "", "tool_calls": {"id": "call-1"}}], + }, + "tool_calls must be a list", + ), + ( + { + "model": "dummy-model", + "tools": {"type": "function"}, + "messages": [{"role": "user", "content": "hello"}], + }, + "tools must be a list", + ), + ], +) +@pytest.mark.asyncio +async def test_gateway_actor_rejects_malformed_requests_with_bad_request(ray_runtime, payload, detail_fragment): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["DONE"])) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-validation")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json=payload, + ) + + ray.get(actor.shutdown.remote()) + + assert response.status_code == 400 + assert detail_fragment in response.text + + +@pytest.mark.asyncio +async def test_gateway_actor_backend_failure_does_not_commit_partial_state(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote(tokenizer=FakeTokenizer(), backend=FailingBackend("boom")) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-backend-failure")) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "first turn"}]}, + ) + + state = ray.get(actor.get_session_state.remote("session-backend-failure")) + ray.get(actor.shutdown.remote()) + + assert response.status_code == 500 + assert state["num_trajectories"] == 0 + assert state["has_active_trajectory"] is False + + +@pytest.mark.asyncio +async def test_gateway_actor_backend_failure_after_tool_mismatch_does_not_split(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=SequencedBackend(["FIRST", RuntimeError("boom")]), + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-failure-mismatch")) + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "first turn"}], + }, + ) + assert first.status_code == 200 + + async with httpx.AsyncClient(timeout=5.0) as client: + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "FIRST"}, + {"role": "user", "content": "follow up"}, + ], + }, + ) + assert second.status_code == 500 + + state = ray.get(actor.get_session_state.remote("session-failure-mismatch")) + trajectories = ray.get(actor.finalize_session.remote("session-failure-mismatch")) + ray.get(actor.shutdown.remote()) + + assert state["num_trajectories"] == 0 + assert len(trajectories) == 1 + assert trajectories[0].response_ids == [ord(char) for char in "FIRST"] + + +@pytest.mark.asyncio +async def test_gateway_actor_tool_call_decode_returns_openai_format(ray_runtime): + """When tool_parser_name is set and model outputs tool call tokens, + the HTTP response should contain tool_calls in OpenAI format.""" + from verl.agent.gateway.gateway import GatewayActor + + tool_call_text = '\n{"name": "search", "arguments": {"query": "weather"}}\n' + actor = GatewayActor.remote( + tokenizer=FakeTokenizer(), + backend=QueuedBackend([tool_call_text, "sunny today"]), + tool_parser_name="hermes", + ) + ray.get(actor.start.remote()) + session = ray.get(actor.create_session.remote("session-tool-call")) + + async with httpx.AsyncClient(timeout=5.0) as client: + # First request: model returns a tool call + first = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [{"role": "user", "content": "what is the weather?"}], + }, + ) + assert first.status_code == 200 + first_data = first.json() + assert first_data["choices"][0]["finish_reason"] == "tool_calls" + tool_calls = first_data["choices"][0]["message"].get("tool_calls") + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0]["function"]["name"] == "search" + assert tool_calls[0]["type"] == "function" + assert "id" in tool_calls[0] + # HTTP response arguments should be a JSON string (OpenAI compatible) + assert isinstance(tool_calls[0]["function"]["arguments"], str) + + # Second request: agent sends back tool result as continuation + second = await client.post( + f"{session.base_url}/chat/completions", + json={ + "model": "dummy-model", + "tools": [{"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}}], + "messages": [ + {"role": "user", "content": "what is the weather?"}, + {"role": "assistant", "content": None, "tool_calls": tool_calls}, + {"role": "tool", "tool_call_id": tool_calls[0]["id"], "content": "sunny and warm"}, + ], + }, + ) + assert second.status_code == 200 + assert second.json()["choices"][0]["message"]["content"] == "sunny today" + + trajectories = ray.get(actor.finalize_session.remote("session-tool-call")) + ray.get(actor.shutdown.remote()) + + assert len(trajectories) == 1 + # Should have both mask=0 (incremental) and mask=1 (model output) tokens + assert 0 in trajectories[0].response_mask + assert 1 in trajectories[0].response_mask + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("stop_reason", "expected_finish_reason"), + [ + ("completed", "stop"), + ("length", "length"), + ("abort", "stop"), + ("matched_stop", "stop"), + (None, "stop"), + ], +) +async def test_decode_response_normalizes_backend_stop_reasons(stop_reason, expected_finish_reason): + """Gateway must map backend-specific stop_reason values to OpenAI-spec + finish_reason values so downstream OpenAI/litellm parsers stay compatible. + """ + from verl.agent.gateway.gateway import _GatewayActor + + actor = _GatewayActor(tokenizer=FakeTokenizer(), backend=QueuedBackend(["IGNORED"])) + response_ids = [ord(char) for char in "hello"] + + _message, finish_reason = await actor._decode_response( + response_ids, tools=None, stop_reason=stop_reason + ) + + assert finish_reason == expected_finish_reason + + +@pytest.mark.asyncio +async def test_decode_response_preserves_unknown_stop_reasons(): + """Unknown backend stop_reason values should be forwarded unchanged so a + future reader can spot a new backend value rather than silently coerce it. + """ + from verl.agent.gateway.gateway import _GatewayActor + + actor = _GatewayActor(tokenizer=FakeTokenizer(), backend=QueuedBackend(["IGNORED"])) + response_ids = [ord(char) for char in "hello"] + + _message, finish_reason = await actor._decode_response( + response_ids, tools=None, stop_reason="unknown_future_value" + ) + + assert finish_reason == "unknown_future_value" diff --git a/tests/agent/gateway/test_gateway_manager_on_cpu.py b/tests/agent/gateway/test_gateway_manager_on_cpu.py new file mode 100644 index 00000000000..9bb39f045da --- /dev/null +++ b/tests/agent/gateway/test_gateway_manager_on_cpu.py @@ -0,0 +1,106 @@ +import httpx +import pytest +import ray + +from tests.agent.support import FakeTokenizer, QueuedBackend, TrackingGatewayActor + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_manager_routes_sessions_stickily(ray_runtime): + from verl.agent.gateway.gateway import GatewayActor + from verl.agent.gateway.manager import GatewayManager + + gateways = [ + GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["A"])), + GatewayActor.remote(tokenizer=FakeTokenizer(), backend=QueuedBackend(["B"])), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + session_a = await manager.create_session("session-a") + session_b = await manager.create_session("session-b") + + assert manager.gateway_count == 2 + assert session_a.base_url != session_b.base_url + + async with httpx.AsyncClient(timeout=5.0) as client: + first = await client.post( + f"{session_a.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "route a"}]}, + ) + second = await client.post( + f"{session_b.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "route b"}]}, + ) + assert first.status_code == 200 + assert second.status_code == 200 + + trajectories_a = await manager.finalize_session("session-a") + trajectories_b = await manager.finalize_session("session-b") + + assert len(trajectories_a) == 1 + assert len(trajectories_b) == 1 + + ray.get([gateway.shutdown.remote() for gateway in gateways]) + + +@pytest.mark.asyncio +async def test_gateway_manager_uses_least_active_sessions_routing(ray_runtime): + from verl.agent.gateway.manager import GatewayManager + + gateways = [ + TrackingGatewayActor.remote("gw-0"), + TrackingGatewayActor.remote("gw-1"), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + session_a = await manager.create_session("session-a") + session_b = await manager.create_session("session-b") + session_c = await manager.create_session("session-c") + + assert manager.active_sessions_per_gateway == [2, 1] + assert session_a.base_url.startswith("http://gw-0/") + assert session_b.base_url.startswith("http://gw-1/") + assert session_c.base_url.startswith("http://gw-0/") + + await manager.finalize_session("session-a") + assert manager.active_sessions_per_gateway == [1, 1] + + session_d = await manager.create_session("session-d") + assert session_d.base_url.startswith("http://gw-0/") + assert manager.active_sessions_per_gateway == [2, 1] + + ray.get([gateway.shutdown.remote() for gateway in gateways]) + + +@pytest.mark.asyncio +async def test_gateway_manager_wait_for_completion_delegates_to_session_owner(ray_runtime): + from verl.agent.gateway.manager import GatewayManager + + gateways = [ + TrackingGatewayActor.remote("gw-0"), + TrackingGatewayActor.remote("gw-1"), + ] + ray.get([gateway.start.remote() for gateway in gateways]) + + manager = GatewayManager(gateways) + await manager.create_session("session-a") + await manager.create_session("session-b") + + await manager.wait_for_completion("session-a", timeout=1.5) + + stats_0 = ray.get(gateways[0].stats.remote()) + stats_1 = ray.get(gateways[1].stats.remote()) + + assert stats_0["waited"] == [("session-a", 1.5)] + assert stats_1["waited"] == [] + + ray.get([gateway.shutdown.remote() for gateway in gateways]) diff --git a/tests/agent/gateway/test_session_runtime_on_cpu.py b/tests/agent/gateway/test_session_runtime_on_cpu.py new file mode 100644 index 00000000000..31c6e45523c --- /dev/null +++ b/tests/agent/gateway/test_session_runtime_on_cpu.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import httpx +import pytest +import ray + +from tests.agent.support import FakeTokenizer, RecordingLLMClient + + +@pytest.fixture +def ray_runtime(): + ray.init(ignore_reinit_error=True) + yield + ray.shutdown() + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_owns_gateway_lifecycle_and_session_runtime(ray_runtime): + from verl.agent.gateway.runtime import GatewayServingRuntime + + llm_client = RecordingLLMClient("OWNER") + runtime = GatewayServingRuntime( + llm_client=llm_client, + gateway_count=1, + gateway_actor_kwargs={ + "tokenizer": FakeTokenizer(), + }, + ) + + session = await runtime.create_session("session-owner") + wait_task = runtime.wait_for_completion("session-owner", timeout=2.0) + + async with httpx.AsyncClient(timeout=5.0) as client: + response = await client.post( + f"{session.base_url}/chat/completions", + json={"model": "dummy-model", "messages": [{"role": "user", "content": "owner path"}]}, + ) + assert response.status_code == 200 + assert response.json()["choices"][0]["message"]["content"] == "OWNER" + + complete = await client.post( + f"{session.base_url.removesuffix('/v1')}/complete", + json={"reward_info": {"score": 0.5, "label": "owner"}}, + ) + assert complete.status_code == 200 + + await wait_task + trajectories = await runtime.finalize_session("session-owner") + await runtime.shutdown() + + assert len(trajectories) == 1 + assert trajectories[0].reward_info == {"score": 0.5, "label": "owner"} + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_delegates_generate_to_llm_client(ray_runtime): + from verl.agent.gateway.runtime import GatewayServingRuntime + + llm_client = RecordingLLMClient("DELEGATED") + runtime = GatewayServingRuntime(llm_client=llm_client, gateway_count=0) + + output = await runtime.generate( + "request-direct", + prompt_ids=[4, 5, 6], + sampling_params={"temperature": 0.2}, + image_data=["image://direct.png"], + ) + + await runtime.shutdown() + + assert output.token_ids == [ord(char) for char in "DELEGATED"] + assert llm_client.calls == [ + { + "request_id": "request-direct", + "prompt_ids": [4, 5, 6], + "sampling_params": {"temperature": 0.2}, + "image_data": ["image://direct.png"], + "video_data": None, + "kwargs": {}, + } + ] + + +@pytest.mark.asyncio +async def test_gateway_serving_runtime_round_robins_actors_across_alive_nodes(ray_runtime, monkeypatch): + """gateway_count > 1 should distribute actors across alive CPU nodes round-robin.""" + from verl.agent.gateway.runtime import GatewayServingRuntime + + fake_nodes = [ + {"NodeID": "a" * 56, "Alive": True, "Resources": {"CPU": 8.0}}, + {"NodeID": "b" * 56, "Alive": True, "Resources": {"CPU": 8.0}}, + {"NodeID": "c" * 56, "Alive": False, "Resources": {"CPU": 8.0}}, + {"NodeID": "d" * 56, "Alive": True, "Resources": {"GPU": 1.0}}, + ] + monkeypatch.setattr("verl.agent.gateway.runtime.ray.nodes", lambda: fake_nodes) + + captured_node_ids = [] + + class _StubStartHandle: + @staticmethod + def remote(): + return ray.put(None) + + class _StubActorHandle: + start = _StubStartHandle() + + class _RecordingActor: + @classmethod + def options(cls, *, scheduling_strategy): + captured_node_ids.append(scheduling_strategy.node_id) + return cls + + @classmethod + def remote(cls, **kwargs): + return _StubActorHandle() + + monkeypatch.setattr("verl.agent.gateway.gateway.GatewayActor", _RecordingActor) + + runtime = GatewayServingRuntime( + llm_client=object(), + gateway_count=5, + gateway_actor_kwargs={"tokenizer": object()}, + ) + + assert captured_node_ids == ["a" * 56, "b" * 56, "a" * 56, "b" * 56, "a" * 56], captured_node_ids + assert len(runtime.owned_gateway_actors) == 5 + # No shutdown call needed: stub actors have no real Ray state. diff --git a/tests/agent/support.py b/tests/agent/support.py new file mode 100644 index 00000000000..6d7ce0f335f --- /dev/null +++ b/tests/agent/support.py @@ -0,0 +1,439 @@ +import asyncio +import json + +import ray +import torch + +from verl.agent.framework.types import SessionHandle, Trajectory +from verl.workers.rollout.replica import TokenOutput + + +class FakeTokenizer: + def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): + parts = [] + for message in messages: + parts.append(f"{message['role']}:{self._normalize_content(message.get('content', ''))}\n") + if add_generation_prompt: + parts.append("assistant:") + text = "".join(parts) + if tokenize: + return [ord(char) for char in text] + return text + + def decode(self, token_ids, skip_special_tokens=True): + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + normalized = [] + for token_id in token_ids: + if hasattr(token_id, "item"): + token_id = token_id.item() + normalized.append(int(token_id)) + return "".join(chr(token_id) for token_id in normalized) + + def encode(self, text, add_special_tokens=False): + return [ord(char) for char in text] + + def _normalize_content(self, content): + if isinstance(content, list): + return "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in content) + if content is None: + return "" + return str(content) + + +class FakeProcessor: + class _ImageProcessor: + patch_size = 16 + + image_token_id = 32001 + video_token_id = 32002 + + def __init__(self): + self.image_processor = self._ImageProcessor() + self.tokenizer = FakeTokenizer() + self.last_processor_call = None + self.last_get_rope_index_call = None + + def apply_chat_template(self, messages, tokenize=True, add_generation_prompt=True, tools=None, **kwargs): + return self.tokenizer.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + tools=tools, + **kwargs, + ) + + def __call__( + self, + *, + text, + images=None, + videos=None, + video_metadata=None, + return_tensors=None, + do_sample_frames=False, + **kwargs, + ): + assert len(text) == 1 + self.last_processor_call = { + "text": list(text), + "images": None if images is None else list(images), + "videos": None if videos is None else list(videos), + "video_metadata": None if video_metadata is None else list(video_metadata), + "return_tensors": return_tensors, + "do_sample_frames": do_sample_frames, + } + + prompt_ids = self.tokenizer.encode(text[0], add_special_tokens=False) + if images: + prompt_ids.extend([self.image_token_id] * len(images)) + if videos: + prompt_ids.extend([self.video_token_id] * len(videos)) + + input_ids = torch.tensor([prompt_ids], dtype=torch.long) + attention_mask = torch.ones_like(input_ids) + output = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + if images: + image_count = len(images) + output["pixel_values"] = torch.arange(image_count * 12, dtype=torch.float32).reshape(image_count, 3, 2, 2) + output["image_grid_thw"] = torch.tensor([[1, 2, 3]] * image_count, dtype=torch.long) + output["mm_token_type_ids"] = torch.ones_like(input_ids) + + if videos: + video_count = len(videos) + output["pixel_values_videos"] = torch.arange(video_count * 24, dtype=torch.float32).reshape( + video_count, 3, 2, 4 + ) + output["video_grid_thw"] = torch.tensor([[1, 3, 4]] * video_count, dtype=torch.long) + output["mm_token_type_ids"] = torch.ones_like(input_ids) + + return output + + def get_rope_index( + self, + *, + input_ids, + attention_mask, + image_grid_thw=None, + video_grid_thw=None, + mm_token_type_ids=None, + **kwargs, + ): + self.last_get_rope_index_call = { + "input_ids": input_ids.clone(), + "attention_mask": attention_mask.clone(), + "image_grid_thw": None if image_grid_thw is None else image_grid_thw.clone(), + "video_grid_thw": None if video_grid_thw is None else video_grid_thw.clone(), + "mm_token_type_ids": None if mm_token_type_ids is None else mm_token_type_ids.clone(), + } + seq_len = input_ids.shape[1] + base = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) + vision_position_ids = torch.stack( + [ + base + 100, + base + 200, + base + 300, + ], + dim=0, + ).unsqueeze(1) + return vision_position_ids, None + + +async def fake_vision_info_extractor(messages, image_patch_size, config=None): + assert image_patch_size == 16 + images = [] + videos = [] + for message in messages: + content = message.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "image_url": + image_url = part.get("image_url", {}) + if isinstance(image_url, dict) and image_url.get("url"): + images.append(image_url["url"]) + elif part.get("type") == "video_url": + video_url = part.get("video_url", {}) + if isinstance(video_url, dict) and video_url.get("url"): + videos.append(video_url["url"]) + return images or None, videos or None + + +class SingleUseVisionInfoExtractor: + def __init__(self): + self.calls = 0 + + async def __call__(self, messages, image_patch_size, config=None): + self.calls += 1 + if self.calls > 1: + raise AssertionError("vision_info_extractor should not be called again on continuation") + return await fake_vision_info_extractor(messages, image_patch_size=image_patch_size, config=config) + + +class InspectingBackend: + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + payload = json.dumps( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + }, + sort_keys=True, + ) + token_ids = [ord(char) for char in payload] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class InspectingSequencedBackend: + def __init__(self, steps): + self.steps = list(steps) + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + step = self.steps.pop(0) + if step == "__inspect__": + text = json.dumps( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + }, + sort_keys=True, + ) + elif isinstance(step, Exception): + raise step + else: + text = step + + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class QueuedBackend: + def __init__(self, responses): + self._responses = list(responses) + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + text = self._responses.pop(0) + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class SlowBackend: + def __init__(self, response_text="SLOW", delay_s=1.5): + self.response_text = response_text + self.delay_s = delay_s + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + await asyncio.sleep(self.delay_s) + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RecordingLLMClient: + def __init__(self, response_text="OK"): + self.response_text = response_text + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None, **kwargs): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + "kwargs": dict(kwargs), + } + ) + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectToolsSamplingParamsBackend: + def __init__(self, response_text: str = "OK"): + self.response_text = response_text + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + if "tools" in sampling_params: + raise RuntimeError("tools leaked into sampling_params") + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectRequestEnvelopeBackend: + def __init__(self, response_text: str = "OK", expected_sampling_params: dict | None = None): + self.response_text = response_text + self.expected_sampling_params = expected_sampling_params + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + assert "messages" not in sampling_params + assert "model" not in sampling_params + assert "tools" not in sampling_params + if self.expected_sampling_params is None: + assert sampling_params["temperature"] == 0.25 + else: + assert sampling_params == self.expected_sampling_params + token_ids = [ord(char) for char in self.response_text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class FailingBackend: + def __init__(self, error_message: str = "backend failure"): + self.error_message = error_message + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + } + ) + raise RuntimeError(self.error_message) + + +class SequencedBackend: + def __init__(self, steps): + self.steps = list(steps) + self.calls = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + self.calls.append( + { + "request_id": request_id, + "prompt_ids": list(prompt_ids), + "sampling_params": dict(sampling_params), + "image_data": image_data, + "video_data": video_data, + } + ) + step = self.steps.pop(0) + if isinstance(step, Exception): + raise step + token_ids = [ord(char) for char in step] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + + +class RejectConcurrentSessionBackend: + def __init__(self, responses, delay: float = 0.05): + self._responses = list(responses) + self._delay = delay + self._active_request_ids: set[str] = set() + self.call_windows = [] + + async def generate(self, request_id, *, prompt_ids, sampling_params, image_data=None, video_data=None): + if request_id in self._active_request_ids: + raise RuntimeError(f"concurrent request for session {request_id}") + + self._active_request_ids.add(request_id) + started_at = asyncio.get_running_loop().time() + try: + await asyncio.sleep(self._delay) + text = self._responses.pop(0) + token_ids = [ord(char) for char in text] + return TokenOutput( + token_ids=token_ids, + log_probs=[-0.1] * len(token_ids), + stop_reason="completed", + ) + finally: + finished_at = asyncio.get_running_loop().time() + self.call_windows.append((request_id, started_at, finished_at)) + self._active_request_ids.remove(request_id) + + +@ray.remote +class TrackingGatewayActor: + def __init__(self, name: str): + self.name = name + self.sessions = {} + self.created = [] + self.finalized = [] + self.aborted = [] + self.waited = [] + + async def start(self): + return None + + async def shutdown(self): + return None + + async def create_session(self, session_id: str, metadata: dict | None = None): + handle = SessionHandle(session_id=session_id, base_url=f"http://{self.name}/{session_id}/v1") + self.sessions[session_id] = {"metadata": metadata or {}} + self.created.append(session_id) + return handle + + async def finalize_session(self, session_id: str): + self.finalized.append(session_id) + self.sessions.pop(session_id, None) + return [ + Trajectory( + prompt_ids=[1], + response_ids=[2], + response_mask=[1], + ) + ] + + async def abort_session(self, session_id: str): + self.aborted.append(session_id) + self.sessions.pop(session_id, None) + return None + + async def wait_for_completion(self, session_id: str, timeout: float | None = None): + self.waited.append((session_id, timeout)) + return None + + async def stats(self): + return { + "name": self.name, + "created": list(self.created), + "finalized": list(self.finalized), + "aborted": list(self.aborted), + "waited": list(self.waited), + } diff --git a/tests/utils/test_transferqueue_utils_on_cpu.py b/tests/utils/test_transferqueue_utils_on_cpu.py new file mode 100644 index 00000000000..520a1fbcae1 --- /dev/null +++ b/tests/utils/test_transferqueue_utils_on_cpu.py @@ -0,0 +1,56 @@ +import torch +from tensordict import TensorDict + +from verl.utils import transferqueue_utils as tq_utils +from verl.utils.transferqueue_utils import force_tq_sequence_fields_nested + + +def test_force_tq_sequence_fields_nested_converts_dense_sequence_tensors(): + data = TensorDict( + { + "prompts": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]), + "response_mask": torch.tensor([[0, 1], [1, 1]]), + "position_ids": torch.arange(2 * 4 * 3, dtype=torch.long).reshape(2, 4, 3), + "num_turns": torch.tensor([1, 2]), + }, + batch_size=[2], + ) + + normalized = force_tq_sequence_fields_nested(data) + + assert normalized["prompts"].is_nested + assert normalized["attention_mask"].is_nested + assert normalized["response_mask"].is_nested + assert normalized["position_ids"].is_nested + assert normalized["position_ids"][0].shape == (4, 3) + assert not normalized["num_turns"].is_nested + + +def test_install_tq_nested_readback_wrappers_can_be_env_gated(monkeypatch): + class FakeTQ: + def __init__(self): + self.kv_batch_get_calls = 0 + + def kv_batch_get(self, *args, **kwargs): + self.kv_batch_get_calls += 1 + return TensorDict( + { + "prompts": torch.tensor([[1, 2, 3], [4, 5, 6]]), + "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1]]), + "num_turns": torch.tensor([1, 2]), + }, + batch_size=[2], + ) + + fake_tq = FakeTQ() + monkeypatch.setattr(tq_utils, "tq", fake_tq) + monkeypatch.setenv("VERL_FORCE_TQ_NESTED_READBACK", "1") + + tq_utils._install_tq_nested_readback_wrappers() + + normalized = tq_utils.tq.kv_batch_get(keys=["a", "b"], partition_id="train") + assert normalized["prompts"].is_nested + assert normalized["attention_mask"].is_nested + assert not normalized["num_turns"].is_nested + assert fake_tq.kv_batch_get_calls == 1 diff --git a/verl/agent/__init__.py b/verl/agent/__init__.py new file mode 100644 index 00000000000..d8946c87799 --- /dev/null +++ b/verl/agent/__init__.py @@ -0,0 +1,2 @@ +"""Agent framework and gateway packages.""" + diff --git a/verl/agent/framework/__init__.py b/verl/agent/framework/__init__.py new file mode 100644 index 00000000000..5b23f9dc71b --- /dev/null +++ b/verl/agent/framework/__init__.py @@ -0,0 +1,9 @@ +from .framework import AgentFramework, OpenAICompatibleAgentFramework +from .types import SessionHandle, Trajectory + +__all__ = [ + "AgentFramework", + "OpenAICompatibleAgentFramework", + "SessionHandle", + "Trajectory", +] diff --git a/verl/agent/framework/entry.py b/verl/agent/framework/entry.py new file mode 100644 index 00000000000..0ce850a2c9c --- /dev/null +++ b/verl/agent/framework/entry.py @@ -0,0 +1,108 @@ +"""Factory entry + trainer-facing adapter for the agent framework stack. + +`build_agent_framework` owns gateway-universal wiring so framework subclasses +only handle their own agent runner, reward dispatch, and framework-specific +config fields. + +`AgentFrameworkRolloutAdapter` satisfies the trainer's +`agent_loop_manager_class` extension-point contract; recipes wire it in via +yaml without authoring per-recipe glue: + + actor_rollout_ref.rollout.agent.agent_loop_manager_class: + verl.agent.framework.entry.AgentFrameworkRolloutAdapter +""" + +from __future__ import annotations + +from omegaconf import OmegaConf + +from verl.agent.framework.framework import AgentFramework +from verl.agent.gateway.runtime import GatewayServingRuntime +from verl.utils.config import omega_conf_to_dataclass +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.ray_utils import auto_await +from verl.workers.config.model import HFModelConfig + +_DEFAULT_FRAMEWORK_CLASS = "verl.agent.framework.framework.OpenAICompatibleAgentFramework" + + +async def build_agent_framework( + *, + config, + llm_client, + replay_buffer, + reward_loop_worker_handles=None, +) -> AgentFramework: + """Build GatewayServingRuntime, then delegate subclass-specific wiring.""" + # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* + af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} + + # Match AgentLoopWorker pattern: self-load tokenizer/processor via HFModelConfig. + model_config: HFModelConfig = omega_conf_to_dataclass(config.actor_rollout_ref.model) + + gateway_actor_kwargs = { + "tokenizer": model_config.tokenizer, + "processor": model_config.processor, + } + tool_parser_name = config.actor_rollout_ref.rollout.get("multi_turn", {}).get("format") + if tool_parser_name is not None: + gateway_actor_kwargs["tool_parser_name"] = tool_parser_name + + session_runtime = GatewayServingRuntime( + llm_client=llm_client, + gateway_count=int(af_cfg["gateway_count"]), + gateway_actor_kwargs=gateway_actor_kwargs, + ) + + framework_cls = load_class_from_fqn(str(af_cfg.get("framework_class_fqn", _DEFAULT_FRAMEWORK_CLASS))) + return await framework_cls.from_config( + config=config, + session_runtime=session_runtime, + processor=model_config.processor, + replay_buffer=replay_buffer, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + +class AgentFrameworkRolloutAdapter: + """Trainer-facing adapter satisfying the `agent_loop_manager_class` contract. + + Holds zero recipe-specific logic; every agent-framework recipe wires the + same class in yaml. Phase B will let `main_ppo_sync.py` call + `build_agent_framework` directly and this adapter can retire. + """ + + def __init__(self) -> None: + self.framework = None + + @classmethod + @auto_await + async def create( + cls, + *, + config, + llm_client, + teacher_client=None, + reward_loop_worker_handles=None, + replay_buffer=None, + **_, + ) -> "AgentFrameworkRolloutAdapter": + del teacher_client + assert replay_buffer is not None, "AgentFrameworkRolloutAdapter requires replay_buffer" + + framework = await build_agent_framework( + config=config, + llm_client=llm_client, + replay_buffer=replay_buffer, + reward_loop_worker_handles=reward_loop_worker_handles, + ) + + instance = cls() + instance.framework = framework + return instance + + @auto_await + async def generate_sequences(self, prompts) -> None: + if self.framework is None: + raise RuntimeError("framework must be initialized before generate_sequences") + return await self.framework.generate_sequences(prompts) diff --git a/verl/agent/framework/framework.py b/verl/agent/framework/framework.py new file mode 100644 index 00000000000..f672f7546c6 --- /dev/null +++ b/verl/agent/framework/framework.py @@ -0,0 +1,612 @@ +from __future__ import annotations + +import asyncio +import logging +import random +from abc import ABC, abstractmethod +from dataclasses import replace +from functools import partial +from uuid import uuid4 + +from omegaconf import OmegaConf +import torch +from tensordict import TensorDict +from tensordict.tensorclass import NonTensorData, NonTensorStack + +from verl.tools.utils.tool_registry import initialize_tools_from_config +from verl.utils.import_utils import load_class_from_fqn +from verl.utils.transferqueue_utils import tq +from verl.utils import tensordict_utils as tu +from verl.utils.model import compute_position_id_with_mask + +from .multi_modal_postprocess import compute_multi_modal_inputs, compute_position_ids +from .types import SessionRuntime, Trajectory + +logger = logging.getLogger(__name__) + + +class AgentFramework(ABC): + """Abstract base for framework implementations. + + Phase A: entry.py owns session runtime construction and passes it in. + Subclasses receive shared entry resources plus the raw config for + subclass-specific field parsing. + + Phase B: trainer inlines entry; this from_config contract remains. + """ + + @classmethod + @abstractmethod + async def from_config( + cls, + *, + config, + session_runtime, + processor=None, + replay_buffer, + reward_loop_worker_handles=None, + ) -> "AgentFramework": + ... + + @abstractmethod + async def generate_sequences(self, prompts: TensorDict) -> None: + """Run agent sessions and write finalized trajectories to TransferQueue.""" + ... + + +def _short_failure_reason(error: BaseException) -> str: + message = str(error) + if not message: + message = error.__class__.__name__ + return message[:512] + + +_TQ_NESTED_SEQUENCE_FIELDS = { + "prompts", + "responses", + "response_mask", + "loss_mask", + "input_ids", + "attention_mask", + "position_ids", + "rollout_log_probs", + "rm_scores", + "teacher_logprobs", + "teacher_ids", +} + + +def _list_of_tq_fields_to_tensordict(fields: list[dict[str, object]]) -> TensorDict: + td = tu.list_of_dict_to_tensordict(fields) + for key in _TQ_NESTED_SEQUENCE_FIELDS: + if key not in fields[0]: + continue + values = [field[key] for field in fields] + if not all(isinstance(value, torch.Tensor) for value in values): + continue + ragged_idx = 2 if key == "position_ids" and values[0].dim() == 2 else None + td[key] = tu.nested_tensor_from_tensor_list(values, ragged_idx=ragged_idx) + return td + + +def _trajectory_to_reward_dataproto(trajectory, sample_fields): + """Build a single-sample DataProto for RewardLoopWorker.compute_score. + + Field shape matches AgentLoopWorker._compute_score + (verl/experimental/agent_loop/agent_loop.py:753-772). Only fields actually + consumed by NaiveRewardManager.run_single / RewardLoopWorker dispatch are + populated; tool_extra_fields / num_turns are passed via non_tensor_batch + for parity. + """ + import numpy as np + from verl.protocol import DataProto + + prompt_ids = torch.tensor(trajectory.prompt_ids, dtype=torch.long).unsqueeze(0) + response_ids = torch.tensor(trajectory.response_ids, dtype=torch.long).unsqueeze(0) + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + batch = TensorDict( + { + "prompts": prompt_ids, + "responses": response_ids, + "input_ids": input_ids, + "attention_mask": attention_mask, + }, + batch_size=1, + ) + + non_tensor_batch: dict[str, object] = {} + for key in ("raw_prompt", "data_source", "reward_model", "extra_info", "tools_kwargs", "agent_name"): + if key in sample_fields: + non_tensor_batch[key] = np.array([sample_fields[key]], dtype=object) + non_tensor_batch["__num_turns__"] = np.array([trajectory.num_turns]) + + return DataProto(batch=batch, non_tensor_batch=non_tensor_batch) + + +class OpenAICompatibleAgentFramework(AgentFramework): + """Reference AgentFramework implementation for OpenAI-compatible agent loops. + + Each sample in the batch is run as an independent session: the agent + communicates with the Gateway via standard ``/v1/chat/completions`` + requests, and the Gateway collects token-level trajectories. After + finalization, ``_score_trajectories`` dispatches the session's final + trajectory to a RewardLoopWorker and broadcasts the score back to all + trajectories in the session (matching + ``AgentLoopWorkerTQ._agent_loop_postprocess``); the framework then writes + them to the TransferQueue schema consumed by sync training. + """ + + def __init__( + self, + session_runtime: SessionRuntime, + agent_runner, + *, + reward_loop_worker_handles=None, + processor=None, + replay_buffer=None, + rollout_config=None, + completion_timeout: float | None = 30.0, + wait_for_completion_after_agent_run: bool = False, + max_concurrent_sessions: int = 0, + ): + self.session_runtime = session_runtime + self.agent_runner = agent_runner + self.reward_loop_worker_handles = list(reward_loop_worker_handles) if reward_loop_worker_handles else None + self._processor = processor + # TODO(phase-b): once trainer constructs framework directly, these become + # constructor-required and no transitional dual-path is needed. + self._replay_buffer = replay_buffer + self._rollout_config = rollout_config + self.completion_timeout = completion_timeout + self.wait_for_completion_after_agent_run = wait_for_completion_after_agent_run + self._max_concurrent_sessions = max_concurrent_sessions + self._semaphore: asyncio.Semaphore | None = None + self._semaphore_loop: asyncio.AbstractEventLoop | None = None + + @classmethod + async def from_config( + cls, + *, + config, + session_runtime, + processor=None, + replay_buffer, + reward_loop_worker_handles=None, + ) -> "OpenAICompatibleAgentFramework": + # TODO(phase-b): switch this to actor_rollout_ref.rollout.agent_framework.* + af_cfg = OmegaConf.select(config, "actor_rollout_ref.rollout.custom.agent_framework", default={}) or {} + agent_runner_fqn = af_cfg.get("agent_runner_fqn") + if not agent_runner_fqn: + raise ValueError("actor_rollout_ref.rollout.custom.agent_framework.agent_runner_fqn is required") + + agent_runner = load_class_from_fqn(str(agent_runner_fqn), description="agent runner") + runner_kwargs = dict( + OmegaConf.to_container(OmegaConf.create(af_cfg.get("agent_runner_kwargs", {})), resolve=True) or {} + ) + tool_config_path = af_cfg.get("tool_config_path") + if tool_config_path: + tool_config = initialize_tools_from_config(tool_config_path) + if not tool_config: + raise ValueError(f"tool config did not initialize any tools: {tool_config_path}") + runner_kwargs["tool_config"] = tool_config + if runner_kwargs: + agent_runner = partial(agent_runner, **runner_kwargs) + + completion_timeout = af_cfg.get("completion_timeout_seconds") + return cls( + session_runtime=session_runtime, + agent_runner=agent_runner, + reward_loop_worker_handles=reward_loop_worker_handles, + processor=processor, + replay_buffer=replay_buffer, + rollout_config=config.actor_rollout_ref.rollout, + completion_timeout=completion_timeout, + wait_for_completion_after_agent_run=completion_timeout is not None, + max_concurrent_sessions=int(af_cfg.get("max_concurrent_sessions", 0)), + ) + + async def generate_sequences(self, prompts: TensorDict) -> None: + """Run rollout-manager generation and write outputs into TransferQueue.""" + if self._replay_buffer is None: + raise RuntimeError("OpenAICompatibleAgentFramework requires replay_buffer for generate_sequences") + if self._rollout_config is None: + raise RuntimeError("OpenAICompatibleAgentFramework requires rollout_config for generate_sequences") + + global_steps = tu.get(prompts, "global_steps") + if global_steps is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['global_steps']") + + partition_id = "val" if "validate" in prompts.keys() else "train" + if partition_id == "val": + val_kwargs = self._rollout_config.get("val_kwargs", {}) + num_sessions = int(val_kwargs.get("n")) + else: + num_sessions = int(self._rollout_config.get("n")) + + uids = tu.get(prompts, "uid") + if uids is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['uid'] for replay_buffer") + uid_values = uids.tolist() if hasattr(uids, "tolist") else list(uids) + self._replay_buffer.add( + partition_id, + {str(uid): {"global_steps": global_steps, "status": "running"} for uid in uid_values}, + ) + + stats = await self._run_batch_to_tq( + prompts, + global_steps=global_steps, + partition_id=partition_id, + num_sessions=num_sessions, + ) + logger.info( + "generate_sequences summary: num_input_prompts=%s num_success_sessions=%s " + "num_failed_sessions=%s num_success_outputs=%s num_failed_uids=%s failure_reasons=%s", + stats["num_input_prompts"], + stats["num_success_sessions"], + stats["num_failed_sessions"], + stats["num_success_outputs"], + stats["num_failed_uids"], + stats["failure_reasons"][:3], + ) + if stats["num_success_outputs"] == 0: + raise RuntimeError( + f"All rollouts failed at global_steps={global_steps}. " + f"failures={stats['num_failed_uids']}/{stats['num_input_prompts']}" + ) + return None + + async def _run_batch_to_tq( + self, + prompts: TensorDict, + *, + global_steps: int, + partition_id: str, + num_sessions: int = 1, + ) -> dict: + """Run all prompts in a batch and aggregate prompt/session stats.""" + assert len(prompts) > 0, "generate_sequences requires a non-empty batch" + if num_sessions <= 0: + raise ValueError(f"num_sessions must be positive, got {num_sessions}") + + raw_prompts = tu.get(prompts, "raw_prompt") + if raw_prompts is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['raw_prompt']") + + # Batch layer: each sample/prompt owns its own group of rollout.n sessions. + # Prompt tasks are isolated so one prompt failure does not drop the whole batch. + tasks = [ + self._run_prompt_sessions_to_tq( + prompts=prompts, + raw_prompt=raw_prompts[sample_index], + sample_index=sample_index, + global_steps=global_steps, + partition_id=partition_id, + num_sessions=num_sessions, + ) + for sample_index in range(len(prompts)) + ] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + failure_reasons: list[str] = [] + stats = { + "num_input_prompts": len(prompts), + "num_success_sessions": 0, + "num_failed_sessions": 0, + "num_success_outputs": 0, + "num_failed_uids": 0, + "failure_reasons": failure_reasons, + } + for outcome in outcomes: + if isinstance(outcome, Exception): + stats["num_failed_sessions"] += num_sessions + stats["num_failed_uids"] += 1 + failure_reasons.append(_short_failure_reason(outcome)) + continue + stats["num_success_sessions"] += outcome["num_success_sessions"] + stats["num_failed_sessions"] += outcome["num_failed_sessions"] + stats["num_success_outputs"] += outcome["num_success_outputs"] + stats["num_failed_uids"] += outcome["num_failed_uids"] + failure_reasons.extend(outcome["failure_reasons"]) + return stats + + async def _run_prompt_sessions_to_tq( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + global_steps: int, + partition_id: str, + num_sessions: int, + ) -> dict: + sample_fields = self._extract_sample_fields(prompts=prompts, sample_index=sample_index) + uid = sample_fields.get("uid") + if uid is None: + raise ValueError("OpenAICompatibleAgentFramework requires prompts['uid'] for TransferQueue output") + uid = str(uid) + + # Prompt layer: rollout.n sessions race independently for the same uid. + # Successful sessions are written to TQ; failed sessions only affect this uid's stats. + tasks = [ + self._run_session_with_concurrency_limit( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=f"session-{sample_index}-{session_index}-{uuid4().hex}", + runner_kwargs={ + key: sample_fields[key] + for key in ("tools_kwargs", "agent_name") + if key in sample_fields + }, + ) + for session_index in range(num_sessions) + ] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + success_sessions = 0 + failed_sessions = 0 + success_outputs = 0 + failure_reasons: list[str] = [] + for session_index, outcome in enumerate(outcomes): + if isinstance(outcome, Exception): + failed_sessions += 1 + failure_reasons.append(_short_failure_reason(outcome)) + continue + + trajectories, session_sample_fields = outcome + if not trajectories: + failed_sessions += 1 + failure_reasons.append(f"empty trajectories for uid={uid} session_index={session_index}") + continue + + success_sessions += 1 + await self._write_session_trajectories_to_tq( + uid=uid, + session_index=session_index, + trajectories=trajectories, + sample_fields=session_sample_fields, + global_steps=global_steps, + partition_id=partition_id, + ) + success_outputs += len(trajectories) + + if success_sessions > 0: + await tq.async_kv_put(key=uid, partition_id=partition_id, tag={"status": "finished"}) + failed_uids = 0 + else: + await tq.async_kv_put(key=uid, partition_id=partition_id, tag={"status": "failure"}) + failed_uids = 1 + + return { + "num_success_sessions": success_sessions, + "num_failed_sessions": failed_sessions, + "num_success_outputs": success_outputs, + "num_failed_uids": failed_uids, + "failure_reasons": failure_reasons, + } + + async def _run_session_with_concurrency_limit( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + session_id: str | None = None, + runner_kwargs: dict[str, object] | None = None, + ) -> tuple[list[Trajectory], dict[str, object]]: + if self._max_concurrent_sessions <= 0: + return await self._run_session( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=session_id, + runner_kwargs=runner_kwargs, + ) + # Lazy-init Semaphore on first use and rebind if the running loop + # changed: asyncio.Semaphore binds to the loop at construction, but + # Ray actors may run sessions on a different loop than __init__. + loop = asyncio.get_running_loop() + if self._semaphore is None or self._semaphore_loop is not loop: + self._semaphore = asyncio.Semaphore(self._max_concurrent_sessions) + self._semaphore_loop = loop + async with self._semaphore: + return await self._run_session( + prompts=prompts, + raw_prompt=raw_prompt, + sample_index=sample_index, + session_id=session_id, + runner_kwargs=runner_kwargs, + ) + + async def _run_session( + self, + *, + prompts: TensorDict, + raw_prompt, + sample_index: int, + session_id: str | None = None, + runner_kwargs: dict[str, object] | None = None, + ) -> tuple[list[Trajectory], dict[str, object]]: + """Run one gateway session lifecycle and return finalized trajectories.""" + session_id = session_id or f"session-{sample_index}-0-{uuid4().hex}" + sample_fields = self._extract_sample_fields(prompts=prompts, sample_index=sample_index) + session = await self.session_runtime.create_session(session_id) + try: + await self.agent_runner( + raw_prompt=raw_prompt, + session=session, + sample_index=sample_index, + **(runner_kwargs or {}), + ) + if self.wait_for_completion_after_agent_run: + await self.session_runtime.wait_for_completion(session_id, timeout=self.completion_timeout) + session_trajectories = await self.session_runtime.finalize_session(session_id) + except Exception: + await self.session_runtime.abort_session(session_id) + raise + + # Score the session's trajectories immediately after finalization, + # consistent with VERL's per-sample reward path. + if not self.reward_loop_worker_handles or not session_trajectories: + return session_trajectories, sample_fields + + annotations = await self._score_trajectories(session_trajectories, sample_fields) + scored_trajectories = [] + for traj, (score, extra) in zip(session_trajectories, annotations, strict=True): + scored_trajectories.append( + replace( + traj, + reward_score=score, + extra_fields={**traj.extra_fields, "reward_extra_info": extra}, + ) + ) + return scored_trajectories, sample_fields + + async def _score_trajectories( + self, + session_trajectories: list[Trajectory], + sample_fields: dict[str, object], + ) -> list[tuple[float, dict[str, object]]]: + """Score the session's final trajectory and broadcast (score, extra_info) to all. + + Mirrors AgentLoopWorkerTQ._agent_loop_postprocess + (verl/trainer/main_ppo_sync.py:353-396): only the final trajectory (the + session's last interaction segment) is dispatched to RewardLoopWorker; + its score + reward_extra_info are then broadcast to every trajectory in + the session. Subclasses can override this method to implement custom + session-to-trajectory scoring policies. + """ + assert self.reward_loop_worker_handles is not None + assert session_trajectories, "expected non-empty session_trajectories" + + final_trajectory = session_trajectories[-1] + data = _trajectory_to_reward_dataproto(final_trajectory, sample_fields) + worker = random.choice(self.reward_loop_worker_handles) + result = await worker.compute_score.remote(data) + + if "reward_score" not in result: + raise ValueError( + f"RewardLoopWorker result missing 'reward_score' key for uid={sample_fields.get('uid')}" + ) + score = float(result["reward_score"]) + extra = dict(result.get("reward_extra_info") or {}) + return [(score, extra)] * len(session_trajectories) + + def _extract_sample_fields(self, *, prompts: TensorDict, sample_index: int) -> dict[str, object]: + sample_fields = {} + for key, value in prompts.items(): + if isinstance(value, torch.Tensor): + sample_fields[key] = value if value.ndim == 0 else value[sample_index] + elif isinstance(value, NonTensorStack): + sample_fields[key] = tu.get(prompts, key)[sample_index] + else: + assert isinstance(value, NonTensorData) + sample_fields[key] = value.data + return sample_fields + + async def _write_session_trajectories_to_tq( + self, + *, + uid: str, + session_index: int, + trajectories: list[Trajectory], + sample_fields: dict[str, object], + global_steps: int, + partition_id: str, + ) -> None: + keys = [] + fields = [] + tags = [] + for index, trajectory in enumerate(trajectories): + field, tag = self._trajectory_to_tq_field_and_tag( + trajectory=trajectory, + sample_fields=sample_fields, + session_index=session_index, + global_steps=global_steps, + ) + keys.append(f"{uid}_{session_index}_{index}") + fields.append(field) + tags.append(tag) + + await tq.async_kv_batch_put( + keys=keys, + fields=_list_of_tq_fields_to_tensordict(fields), + tags=tags, + partition_id=partition_id, + ) + + def _trajectory_to_tq_field_and_tag( + self, + *, + trajectory: Trajectory, + sample_fields: dict[str, object], + session_index: int, + global_steps: int, + ) -> tuple[dict[str, object], dict[str, object]]: + prompts = torch.tensor(trajectory.prompt_ids, dtype=torch.long) + responses = torch.tensor(trajectory.response_ids, dtype=torch.long) + response_mask = torch.tensor(trajectory.response_mask, dtype=torch.long) + input_ids = torch.cat([prompts, responses], dim=0) + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + multi_modal_inputs = compute_multi_modal_inputs( + self._processor, + input_ids.unsqueeze(0), + trajectory.multi_modal_data, + ) + if self._processor is None: + position_ids = compute_position_id_with_mask(attention_mask.unsqueeze(0)).squeeze(0) + else: + position_ids = compute_position_ids( + self._processor, + input_ids.unsqueeze(0), + attention_mask.unsqueeze(0), + multi_modal_inputs, + ).squeeze(0) + + field: dict[str, object] = { + "prompts": prompts, + "responses": responses, + "response_mask": response_mask, + "loss_mask": response_mask, + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "multi_modal_inputs": multi_modal_inputs, + } + if trajectory.response_logprobs is not None: + field["rollout_log_probs"] = torch.tensor(trajectory.response_logprobs, dtype=torch.float32) + else: + field["rollout_log_probs"] = torch.zeros_like(responses, dtype=torch.float32) + if trajectory.routed_experts is not None: + field["routed_experts"] = ( + torch.from_numpy(trajectory.routed_experts.copy()) + if hasattr(trajectory.routed_experts, "copy") and not isinstance(trajectory.routed_experts, torch.Tensor) + else trajectory.routed_experts + ) + rm_scores = torch.zeros_like(responses, dtype=torch.float32) + if trajectory.reward_score is not None and responses.numel() > 0: + rm_scores[-1] = float(trajectory.reward_score) + field["rm_scores"] = rm_scores + + field.update(trajectory.extra_fields) + field.pop("multi_modal_data", None) + for key in ("uid", "raw_prompt", "data_source", "reward_model", "extra_info", "tools_kwargs", "agent_name"): + if key in sample_fields: + field[key] = sample_fields[key] + field["session_id"] = session_index + field["global_steps"] = global_steps + field["num_turns"] = torch.tensor(int(trajectory.num_turns), dtype=torch.long) + + prompt_len = prompts.size(0) + response_len = responses.size(0) + tag = { + "global_steps": global_steps, + "status": "success", + "prompt_len": prompt_len, + "response_len": response_len, + "seq_len": prompt_len + response_len, + } + return field, tag + diff --git a/verl/agent/framework/multi_modal_postprocess.py b/verl/agent/framework/multi_modal_postprocess.py new file mode 100644 index 00000000000..67694cf2e97 --- /dev/null +++ b/verl/agent/framework/multi_modal_postprocess.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from verl.utils.model import compute_position_id_with_mask + +# Behavior mirrors legacy `_compute_multi_modal_inputs` / `_compute_position_ids` +# in `verl/experimental/agent_loop/agent_loop.py`. + + +def _split_videos_and_metadata(videos: list[Any] | None) -> tuple[list[Any] | None, list[Any] | None]: + if not videos: + return videos, None + + first_video = videos[0] + if isinstance(first_video, tuple) and len(first_video) == 2: + split_videos, video_metadata = zip(*videos, strict=False) + return list(split_videos), list(video_metadata) + + return list(videos), None + + +def _to_plain_tensor_dict(processor_output) -> dict[str, torch.Tensor]: + if hasattr(processor_output, "convert_to_tensors"): + processor_output = processor_output.convert_to_tensors("pt") + return dict(processor_output) + + +def compute_multi_modal_inputs( + processor, + input_ids: torch.Tensor, + multi_modal_data: dict[str, Any] | None, +) -> dict[str, torch.Tensor]: + """Return processor-produced multimodal tensors for a single sample.""" + if processor is None or not multi_modal_data: + return {} + + images = multi_modal_data.get("images") + videos, video_metadata = _split_videos_and_metadata(multi_modal_data.get("videos")) + current_text = processor.tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True) + multi_modal_inputs = _to_plain_tensor_dict( + processor( + text=[current_text], + images=images, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + ) + multi_modal_inputs.pop("input_ids", None) + multi_modal_inputs.pop("attention_mask", None) + + image_grid_thw = multi_modal_inputs.get("image_grid_thw") + if image_grid_thw is not None: + images_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]) + multi_modal_inputs["images_seqlens"] = images_seqlens + return multi_modal_inputs + + +def compute_position_ids( + processor, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + multi_modal_inputs: dict[str, torch.Tensor], +) -> torch.Tensor: + """Return text-only or multimodal-aware position ids for a single sample.""" + if processor is None: + return compute_position_id_with_mask(attention_mask) + + multi_modal_kwargs = { + "image_grid_thw": multi_modal_inputs.get("image_grid_thw"), + "video_grid_thw": multi_modal_inputs.get("video_grid_thw"), + } + if multi_modal_inputs.get("mm_token_type_ids") is not None: + mm_token_type_ids = torch.zeros_like(input_ids) + mm_token_type_ids[0][input_ids[0] == processor.image_token_id] = 1 + mm_token_type_ids[0][input_ids[0] == processor.video_token_id] = 2 + multi_modal_kwargs["mm_token_type_ids"] = mm_token_type_ids + + vision_position_ids, _ = processor.get_rope_index( + input_ids=input_ids, + attention_mask=attention_mask, + **multi_modal_kwargs, + ) + vision_position_ids = vision_position_ids.transpose(0, 1) + + valid_mask = attention_mask[0].bool() + text_position_ids = torch.ones((1, input_ids.shape[1]), dtype=torch.long, device=input_ids.device) + text_position_ids[0, valid_mask] = torch.arange(valid_mask.sum().item(), device=input_ids.device) + text_position_ids = text_position_ids.unsqueeze(0) + return torch.cat((text_position_ids, vision_position_ids), dim=1) diff --git a/verl/agent/framework/types.py b/verl/agent/framework/types.py new file mode 100644 index 00000000000..2abf18a8e50 --- /dev/null +++ b/verl/agent/framework/types.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol + +import numpy as np +import torch + + +@dataclass +class SessionHandle: + session_id: str + base_url: str | None = None + + +@dataclass +class Trajectory: + prompt_ids: list[int] + response_ids: list[int] + response_mask: list[int] + response_logprobs: list[float] | None = None + reward_info: dict[str, Any] = field(default_factory=dict) + reward_score: float | None = None + num_turns: int = 0 + routed_experts: torch.Tensor | np.ndarray | None = None + multi_modal_data: dict[str, Any] | None = None + extra_fields: dict[str, Any] = field(default_factory=dict) + + +class SessionRuntime(Protocol): + """Protocol for gateway-backed session lifecycle. + + Used by OpenAICompatibleAgentFramework to decouple the framework from the + concrete AsyncLLMServerManager / GatewayManager implementation, making it + testable without a Ray cluster. + """ + + async def create_session(self, session_id: str, **kwargs) -> SessionHandle: ... + async def complete_session(self, session_id: str) -> None: ... + async def finalize_session(self, session_id: str) -> list[Trajectory]: ... + async def abort_session(self, session_id: str) -> None: ... + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: ... diff --git a/verl/agent/gateway/__init__.py b/verl/agent/gateway/__init__.py new file mode 100644 index 00000000000..8fb072aa51b --- /dev/null +++ b/verl/agent/gateway/__init__.py @@ -0,0 +1,9 @@ +from .gateway import GatewayActor +from .manager import GatewayManager +from .runtime import GatewayServingRuntime + +__all__ = [ + "GatewayActor", + "GatewayManager", + "GatewayServingRuntime", +] diff --git a/verl/agent/gateway/gateway.py b/verl/agent/gateway/gateway.py new file mode 100644 index 00000000000..bfd7cfd0de1 --- /dev/null +++ b/verl/agent/gateway/gateway.py @@ -0,0 +1,752 @@ +from __future__ import annotations + +import asyncio +import json +import time +from dataclasses import replace +from typing import Any +from uuid import uuid4 + +import ray +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from verl.agent.framework.types import SessionHandle, Trajectory +from verl.agent.gateway.types import GatewaySessionState, SessionPhase, TrajectoryBuffer +from verl.experimental.agent_loop.tool_parser import ToolParser +from verl.utils.chat_template import apply_chat_template as _apply_chat_template, initialize_system_prompt +from verl.utils.tokenizer import normalize_token_ids +from verl.workers.rollout.utils import run_uvicorn + + +class MalformedRequestError(ValueError): + pass + + +_DEFAULT_ALLOWED_REQUEST_SAMPLING_PARAM_KEYS = frozenset({ + "temperature", + "top_p", + "top_k", + "max_tokens", +}) + + +# Map backend stop_reason values to OpenAI-spec finish_reason values. +# OpenAI Chat Completions spec defines finish_reason ∈ +# {"stop", "length", "tool_calls", "content_filter", "function_call"}. +# +# Note on vLLM information loss: the vLLM rollout adapter +# (verl/workers/rollout/vllm_rollout/vllm_async_server.py:538-545) +# collapses vLLM's raw finish_reason "stop" and "length" into a single +# "completed" stop_reason before the gateway sees it. As a result, +# mapping "completed" -> "stop" here cannot recover whether generation +# actually hit max_tokens; recovering that distinction requires the +# vLLM adapter to preserve the raw finish_reason on TokenOutput. +# TODO(phase-c): preserve raw backend finish_reason on TokenOutput +# (e.g. a new TokenOutput.finish_reason field or +# extra_fields["finish_reason"]) so the gateway can distinguish vLLM +# "length" from "stop" instead of mapping both to "stop". +_FINISH_REASON_MAP = { + "completed": "stop", + "stop": "stop", + "matched_stop": "stop", + "eos": "stop", + "length": "length", + "max_tokens": "length", + "aborted": "stop", + "abort": "stop", +} + + +# TODO: double-check if all these validations/normalization are necessary +# Make sure they don't alter messages in unexpected ways. +def _normalize_message_content(content: Any) -> Any: + """Normalize message content: coerce None to empty string, validate type.""" + if isinstance(content, list | dict | str): + return content + if content is None: + return "" + raise MalformedRequestError(f"Unsupported content type: {type(content).__name__}") + + +def _validate_tool_calls(tool_calls: Any) -> None: + """Validate tool_calls structure. Does not modify content.""" + if not isinstance(tool_calls, list): + raise MalformedRequestError("tool_calls must be a list") + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + raise MalformedRequestError("tool_calls entries must be objects") + function = tool_call.get("function") + if not isinstance(function, dict): + raise MalformedRequestError("tool_call.function must be an object") + + +def _normalize_message(message: Any) -> dict[str, Any]: + """Normalize a single message: validate structure, coerce types, filter to known fields. + + Constructs a new dict with only role/content/tool_calls/tool_call_id. + This ensures prefix comparison is not affected by extraneous fields. + """ + if not isinstance(message, dict): + raise MalformedRequestError("messages entries must be objects") + + role = message.get("role") + if not isinstance(role, str) or not role: + raise MalformedRequestError("message.role must be a non-empty string") + + normalized: dict[str, Any] = { + "role": role, + "content": _normalize_message_content(message.get("content", "")), + } + if "name" in message: + name = message["name"] + if not isinstance(name, str): + raise MalformedRequestError("message.name must be a string") + normalized["name"] = name + if "tool_calls" in message: + _validate_tool_calls(message["tool_calls"]) + normalized["tool_calls"] = list(message["tool_calls"]) + if "tool_call_id" in message: + normalized["tool_call_id"] = str(message["tool_call_id"]) + return normalized + + +def _validate_tools(tools: Any) -> list[Any] | None: + """Validate tools structure. Does not modify content.""" + if tools is None: + return None + if not isinstance(tools, list): + raise MalformedRequestError("tools must be a list") + return tools + + +def _normalize_request_context(payload: dict[str, Any]) -> dict[str, Any]: + """Normalize and validate the request payload, extracting messages and tools. + """ + messages = payload.get("messages") + if not isinstance(messages, list) or not messages: + raise MalformedRequestError("messages must be non-empty") + return { + "messages": [_normalize_message(message) for message in messages], + "tools": _validate_tools(payload.get("tools")), + } + + +def _build_sampling_params( + payload: dict[str, Any], + *, + base_sampling_params: dict[str, Any], + allowed_request_sampling_param_keys: frozenset[str], +) -> dict[str, Any]: + sampling_params = dict(base_sampling_params) + for key in allowed_request_sampling_param_keys: + if key in payload: + sampling_params[key] = payload[key] + return sampling_params + + +def _canonicalize_tool_arguments_for_comparison(arguments: Any) -> tuple[str, Any]: + if isinstance(arguments, dict | list): + return ("json", arguments) + if isinstance(arguments, str): + try: + return ("json", json.loads(arguments)) + except json.JSONDecodeError: + return ("raw", arguments) + return ("raw", arguments) + + +def _canonicalize_message_for_prefix_comparison(message: dict[str, Any]) -> dict[str, Any]: + normalized = dict(message) + tool_calls = normalized.get("tool_calls") + if not isinstance(tool_calls, list): + return normalized + + normalized_tool_calls: list[dict[str, Any]] = [] + for tool_call in tool_calls: + normalized_tool_call = dict(tool_call) + function = normalized_tool_call.get("function") + if isinstance(function, dict) and "arguments" in function: + normalized_function = dict(function) + normalized_function["arguments"] = _canonicalize_tool_arguments_for_comparison(function["arguments"]) + normalized_tool_call["function"] = normalized_function + normalized_tool_calls.append(normalized_tool_call) + normalized["tool_calls"] = normalized_tool_calls + return normalized + + +def _is_message_prefix(prefix: list[dict[str, Any]], messages: list[dict[str, Any]]) -> bool: + if len(prefix) > len(messages): + return False + return [ + _canonicalize_message_for_prefix_comparison(message) + for message in prefix + ] == [ + _canonicalize_message_for_prefix_comparison(message) + for message in messages[: len(prefix)] + ] + + +def _is_request_context_prefix( + *, + session: GatewaySessionState, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, +) -> bool: + if session.request_tools != tools: + return False + # TODO: dict equality is not token-level equivalent — two tool schemas with + # different key order compare equal in Python but may tokenize differently. + # This could cause a false prefix match on the tools path. Low practical + # risk (same agent rarely reorders keys within a session), but worth noting. + #TODO: need to improve the prefix check logic, e.g.,how to handle tool lists and multimodal data + return _is_message_prefix(session.message_history, messages) + + +def _copy_trajectory_buffer(buffer: TrajectoryBuffer | None) -> TrajectoryBuffer | None: + if buffer is None: + return None + return TrajectoryBuffer( + prompt_ids=list(buffer.prompt_ids), + response_ids=list(buffer.response_ids), + response_mask=list(buffer.response_mask), + response_logprobs=list(buffer.response_logprobs), + ) + + +def _count_chat_turns(message_history: list[dict[str, Any]]) -> int: + """Count chat turns consistent with ToolAgentLoop semantics. + + ToolAgentLoop computes: user_turns + assistant_turns + 1 (the +1 accounts + for the initial prompt). We count user + assistant role messages and add 1. + System and tool messages are excluded. + """ + return sum(1 for m in message_history if m.get("role") in ("user", "assistant")) + 1 + + +def _materialize_response_logprobs(buffer: TrajectoryBuffer) -> list[float] | None: + if not buffer.response_logprobs: + return None + return list(buffer.response_logprobs) + + +def _build_multi_modal_trajectory_data( + image_data: list[Any] | None, + video_data: list[Any] | None, +) -> dict[str, Any] | None: + multi_modal_data: dict[str, Any] = {} + if image_data: + multi_modal_data["images"] = list(image_data) + if video_data: + multi_modal_data["videos"] = list(video_data) + return multi_modal_data or None + + +class _GatewayActor: + def __init__( + self, + tokenizer, + backend, + *, + processor=None, + vision_info_extractor=None, + vision_info_extractor_kwargs: dict[str, Any] | None = None, + tool_parser_name: str | None = None, + apply_chat_template_kwargs: dict[str, Any] | None = None, + base_sampling_params: dict[str, Any] | None = None, + allowed_request_sampling_param_keys: set[str] | frozenset[str] | None = None, + ): + # Same pattern as vllm_async_server.py / async_sglang_server.py: + # use the node's routable IP for both bind and URL. + self._server_address = ray.util.get_node_ip_address() + self._tokenizer = tokenizer + self._processor = processor + self._backend = backend + self._vision_info_extractor = vision_info_extractor or self._default_vision_info_extractor + self._vision_info_extractor_kwargs = dict(vision_info_extractor_kwargs or {}) + self._apply_chat_template_kwargs = apply_chat_template_kwargs or {} + self._base_sampling_params = dict(base_sampling_params or {}) + allowed_keys = ( + _DEFAULT_ALLOWED_REQUEST_SAMPLING_PARAM_KEYS + if allowed_request_sampling_param_keys is None + else frozenset(allowed_request_sampling_param_keys) + ) + self._allowed_request_sampling_param_keys = allowed_keys + self._system_prompt = initialize_system_prompt( + tokenizer, + **self._apply_chat_template_kwargs, + ) + self._tool_parser = ( + ToolParser.get_tool_parser(tool_parser_name, tokenizer) if tool_parser_name else None + ) + self._sessions: dict[str, GatewaySessionState] = {} + self._app = FastAPI() + self._server_port: int | None = None + self._server_task: asyncio.Task | None = None + self._server_base_url: str | None = None + self._register_routes() + + def _register_routes(self) -> None: + @self._app.post("/sessions/{session_id}/v1/chat/completions") + async def _chat_completions(session_id: str, request: Request): + payload = await request.json() + return await self._handle_chat_completions(session_id=session_id, payload=payload) + + @self._app.post("/sessions/{session_id}/complete") + async def _complete(session_id: str, request: Request): + payload = await request.json() + reward_info = payload.get("reward_info") + try: + await self.complete_session(session_id=session_id, reward_info=reward_info) + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + return JSONResponse({"status": "ok"}) + + def _require_started(self) -> None: + if self._server_base_url is None: + raise RuntimeError("GatewayActor.start() must be called before session creation") + + def _get_session(self, session_id: str) -> GatewaySessionState: + session = self._sessions.get(session_id) + if session is None: + raise KeyError(f"Unknown session_id: {session_id}") + return session + + def _set_phase(self, session: GatewaySessionState, phase: SessionPhase) -> None: + session.phase = phase + self._touch_session(session) + + def _touch_session(self, session: GatewaySessionState) -> None: + session.updated_at = time.time() + + def _materialize_active_trajectory(self, session: GatewaySessionState) -> None: + active = session.active_trajectory + if active is None: + return + + self._touch_session(session) + session.trajectories.append( + self._build_materialized_trajectory( + session=session, + active=active, + ) + ) + session.active_trajectory = None + + def _build_materialized_trajectory( + self, + *, + session: GatewaySessionState, + active: TrajectoryBuffer, + ) -> Trajectory: + return Trajectory( + prompt_ids=list(active.prompt_ids), + response_ids=list(active.response_ids), + response_mask=list(active.response_mask), + response_logprobs=_materialize_response_logprobs(active), + reward_info={}, + num_turns=_count_chat_turns(session.message_history), + multi_modal_data=_build_multi_modal_trajectory_data(session.image_data, session.video_data), + ) + + async def _default_vision_info_extractor( + self, + messages: list[dict[str, Any]], + *, + image_patch_size: int, + ) -> tuple[list[Any] | None, list[Any] | None]: + # Keep the dataset dependency lazy so custom extractors do not pay for + # RLHFDataset imports unless they actually use the default path. + from verl.utils.dataset.rl_dataset import RLHFDataset + + return await RLHFDataset.process_vision_info( + messages, + image_patch_size=image_patch_size, + config=self._vision_info_extractor_kwargs.get("config"), + ) + + async def _extract_multi_modal_data( + self, + messages: list[dict[str, Any]], + ) -> tuple[list[Any] | None, list[Any] | None]: + if self._processor is None: + return None, None + + has_multi_modal_blocks = False + for message in messages: + content = message.get("content") + if not isinstance(content, list): + continue + for part in content: + if isinstance(part, dict) and part.get("type") in {"image", "image_url", "video", "video_url"}: + has_multi_modal_blocks = True + break + if has_multi_modal_blocks: + break + + if not has_multi_modal_blocks: + return None, None + + return await self._vision_info_extractor( + messages, + image_patch_size=self._processor.image_processor.patch_size, + **self._vision_info_extractor_kwargs, + ) + + def _encode_full( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + ) -> list[int]: + """Encode a full conversation for a new trajectory (includes system prompt + generation prompt).""" + if self._processor is not None: + raw_prompt = _apply_chat_template( + self._processor, + messages, + tools=tools, + add_generation_prompt=True, + tokenize=False, + **self._apply_chat_template_kwargs, + ) + videos = video_data + video_metadata = None + if videos is not None: + videos, video_metadata = zip(*videos, strict=False) + videos, video_metadata = list(videos), list(video_metadata) + model_inputs = self._processor( + text=[raw_prompt], + images=image_data, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + return normalize_token_ids(model_inputs["input_ids"]) + + return normalize_token_ids( + _apply_chat_template( + self._tokenizer, messages, tools=tools, add_generation_prompt=True, + **self._apply_chat_template_kwargs, + ) + ) + # TODO: check if delta tokenization is better than remove_system_prompt + def _encode_incremental( + self, + messages: list[dict[str, Any]], + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + ) -> list[int]: + """Encode incremental messages (tool results, user follow-ups) for a continuation turn. + + Uses the remove_system_prompt pattern from ToolAgentLoop: encode the new messages + alone (which prepends a system prompt), then strip the known system_prompt prefix. + No tools parameter — tool schema is already in the initial prompt_ids. + """ + if self._processor is not None: + raw_prompt = _apply_chat_template( + self._processor, + messages, + add_generation_prompt=True, + tokenize=False, + **self._apply_chat_template_kwargs, + ) + videos = video_data + video_metadata = None + if videos is not None: + videos, video_metadata = zip(*videos, strict=False) + videos, video_metadata = list(videos), list(video_metadata) + model_inputs = self._processor( + text=[raw_prompt], + images=image_data, + videos=videos, + video_metadata=video_metadata, + return_tensors="pt", + do_sample_frames=False, + ) + ids = normalize_token_ids(model_inputs["input_ids"]) + else: + ids = normalize_token_ids( + _apply_chat_template( + self._tokenizer, messages, add_generation_prompt=True, + **self._apply_chat_template_kwargs, + ) + ) + return ids[len(self._system_prompt):] + + async def _decode_response( + self, response_ids: list[int], *, tools: list[dict[str, Any]] | None = None, + stop_reason: str | None = None, + ) -> tuple[dict[str, Any], str]: + """Decode model output tokens into an OpenAI-compatible assistant message. + + Returns: + message: OpenAI-compatible assistant message. + finish_reason: "tool_calls" when tool calls are present, else the + OpenAI-spec-normalized stop_reason (see _FINISH_REASON_MAP). + """ + if self._tool_parser is not None and tools: + parsed_tools = None + try: + from verl.tools.schemas import OpenAIFunctionToolSchema + parsed_tools = [ + OpenAIFunctionToolSchema(**t) if isinstance(t, dict) else t + for t in tools + ] + except Exception: + pass + content, function_calls = await self._tool_parser.extract_tool_calls(response_ids, parsed_tools) + if function_calls: + tool_calls = [ + { + "id": f"call_{uuid4().hex[:8]}", + "type": "function", + "function": {"name": fc.name, "arguments": fc.arguments}, + } + for fc in function_calls + ] + message = { + "role": "assistant", + # Use "" instead of None so that prefix comparison with + # _normalize_message_content (which also coerces None → "") + # stays consistent. Both must agree on the None policy. + "content": content or "", + "tool_calls": tool_calls, + } + return message, "tool_calls" + response_text = self._tokenizer.decode(response_ids, skip_special_tokens=True) + finish_reason = _FINISH_REASON_MAP.get(stop_reason, stop_reason) if stop_reason else "stop" + return {"role": "assistant", "content": response_text}, finish_reason + + async def _handle_chat_completions(self, session_id: str, payload: dict[str, Any]) -> JSONResponse: + session = self._sessions.get(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Unknown session_id: {session_id}") + + try: + request_context = _normalize_request_context(payload) + except MalformedRequestError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + async with session.generation_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException(status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}") + + async with session.request_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException( + status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}" + ) + + self._touch_session(session) + messages = request_context["messages"] + tools = request_context["tools"] + materialized_trajectory = None + image_data = None + video_data = None + + if session.active_trajectory is None: + image_data, video_data = await self._extract_multi_modal_data(messages) + active_trajectory = TrajectoryBuffer( + prompt_ids=self._encode_full( + messages, tools=tools, image_data=image_data, video_data=video_data + ) + ) + elif _is_request_context_prefix(session=session, messages=messages, tools=tools): + active_trajectory = _copy_trajectory_buffer(session.active_trajectory) + image_data = list(session.image_data) if session.image_data is not None else None + video_data = list(session.video_data) if session.video_data is not None else None + incremental_messages = messages[len(session.message_history) :] + if incremental_messages: + new_image_data, new_video_data = await self._extract_multi_modal_data(incremental_messages) + if new_image_data: + if image_data is None: + image_data = [] + image_data.extend(new_image_data) + if new_video_data: + if video_data is None: + video_data = [] + video_data.extend(new_video_data) + incremental_ids = self._encode_incremental( + incremental_messages, + image_data=new_image_data, + video_data=new_video_data, + ) + active_trajectory.response_ids.extend(incremental_ids) + active_trajectory.response_mask.extend([0] * len(incremental_ids)) + if active_trajectory.response_logprobs: + active_trajectory.response_logprobs.extend([0.0] * len(incremental_ids)) + else: + materialized_trajectory = self._build_materialized_trajectory( + session=session, + active=session.active_trajectory, + ) + image_data, video_data = await self._extract_multi_modal_data(messages) + active_trajectory = TrajectoryBuffer( + prompt_ids=self._encode_full( + messages, tools=tools, image_data=image_data, video_data=video_data + ) + ) + + generation_context_ids = active_trajectory.prompt_ids + active_trajectory.response_ids + sampling_params = _build_sampling_params( + payload, + base_sampling_params=self._base_sampling_params, + allowed_request_sampling_param_keys=self._allowed_request_sampling_param_keys, + ) + + output = await self._backend.generate( + request_id=session_id, + prompt_ids=generation_context_ids, + sampling_params=sampling_params, + image_data=image_data, + video_data=video_data, + ) + + response_ids = list(output.token_ids) + active_trajectory.response_ids.extend(response_ids) + active_trajectory.response_mask.extend([1] * len(response_ids)) + if output.log_probs is not None: + active_trajectory.response_logprobs.extend(list(output.log_probs)) + + assistant_msg, finish_reason = await self._decode_response( + response_ids, tools=tools, stop_reason=output.stop_reason, + ) + async with session.request_lock: + if session.phase != SessionPhase.ACTIVE: + raise HTTPException( + status_code=409, detail=f"Session {session_id} is {session.phase.value.lower()}" + ) + + if materialized_trajectory is not None: + session.trajectories.append(materialized_trajectory) + session.active_trajectory = active_trajectory + session.image_data = list(image_data) if image_data is not None else None + session.video_data = list(video_data) if video_data is not None else None + session.message_history = messages + [assistant_msg] + session.request_tools = tools + self._touch_session(session) + + return JSONResponse( + { + "id": f"chatcmpl-{uuid4().hex}", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": assistant_msg, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": len(generation_context_ids), + "completion_tokens": len(response_ids), + "total_tokens": len(generation_context_ids) + len(response_ids), + }, + } + ) + + async def start(self) -> None: + if self._server_task is not None: + return + self._server_port, self._server_task = await run_uvicorn(self._app, None, self._server_address) + self._server_base_url = f"http://{self._server_address}:{self._server_port}" + + async def shutdown(self) -> None: + if self._server_task is None: + return + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass + self._server_task = None + self._server_port = None + self._server_base_url = None + + async def create_session(self, session_id: str, metadata: dict[str, Any] | None = None) -> SessionHandle: + self._require_started() + if session_id in self._sessions: + raise RuntimeError(f"Session {session_id} already exists") + + handle = SessionHandle( + session_id=session_id, + base_url=f"{self._server_base_url}/sessions/{session_id}/v1", + ) + self._sessions[session_id] = GatewaySessionState(handle=handle, metadata=dict(metadata or {})) + return handle + + async def complete_session(self, session_id: str, reward_info: dict[str, Any] | None = None) -> None: + session = self._get_session(session_id) + async with session.request_lock: + # Accommodate retry attempts + if session.phase not in {SessionPhase.COMPLETED, SessionPhase.ACTIVE}: + raise RuntimeError(f"Session {session_id} is {session.phase.value.lower()}") + + if reward_info is not None: + session.reward_info = dict(reward_info) + + self._set_phase(session, SessionPhase.COMPLETED) + session.completed.set() + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + session = self._sessions.get(session_id) + if session is None: + # Already finalized or aborted by a concurrent caller — nothing to wait for. + return + if session.phase == SessionPhase.COMPLETED: + # Fast path: agent already called /complete, no need to wait. + return + + await asyncio.wait_for(session.completed.wait(), timeout=timeout) + + # Post-await: the session may have been aborted during the wait. + # The local reference is still valid even if the session was removed from _sessions. + if session.phase == SessionPhase.ABORTED: + raise RuntimeError(f"Session {session_id} is aborted") + + async def finalize_session(self, session_id: str) -> list[Trajectory]: + session = self._get_session(session_id) + async with session.request_lock: + if session.phase == SessionPhase.ABORTED: + raise RuntimeError(f"Session {session_id} is aborted") + if session.phase == SessionPhase.FINALIZED: + raise RuntimeError(f"Session {session_id} is finalized") + + self._touch_session(session) + self._materialize_active_trajectory(session) + self._set_phase(session, SessionPhase.FINALIZED) + session.completed.set() + trajectories = [replace(trajectory, reward_info=dict(session.reward_info)) for trajectory in session.trajectories] + self._sessions.pop(session_id, None) + return trajectories + + async def abort_session(self, session_id: str) -> None: + session = self._sessions.get(session_id) + if session is None: + return # Already finalized or aborted — treat as idempotent. + async with session.request_lock: + if session.phase == SessionPhase.ABORTED: + return # Concurrent abort — idempotent. + if session.phase == SessionPhase.FINALIZED: + raise RuntimeError(f"Session {session_id} is finalized") + + self._set_phase(session, SessionPhase.ABORTED) + session.completed.set() + self._sessions.pop(session_id, None) + + async def get_session_state(self, session_id: str) -> dict[str, Any]: + session = self._get_session(session_id) + return { + "session_id": session.handle.session_id, + "metadata": dict(session.metadata), + "phase": session.phase.value, + "created_at": session.created_at, + "updated_at": session.updated_at, + "num_trajectories": len(session.trajectories), + "has_active_trajectory": session.active_trajectory is not None, + } + + +GatewayActor = ray.remote(_GatewayActor) diff --git a/verl/agent/gateway/manager.py b/verl/agent/gateway/manager.py new file mode 100644 index 00000000000..c439c7e3ea9 --- /dev/null +++ b/verl/agent/gateway/manager.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import asyncio + + +async def _await_object_ref(object_ref): + return await asyncio.wrap_future(object_ref.future()) + + +class GatewayManager: + """Session-routing component owned by the serving runtime.""" + + def __init__(self, gateways: list): + self.gateways = gateways + self.gateway_count = len(gateways) + self.active_sessions_per_gateway = [0 for _ in gateways] + self._session_to_gateway_index: dict[str, int] = {} + + def _select_gateway_index(self) -> int: + if not self.gateways: + raise RuntimeError("No gateway actors configured") + return min(range(len(self.gateways)), key=lambda index: self.active_sessions_per_gateway[index]) + + def _get_gateway_index(self, session_id: str) -> int: + gateway_index = self._session_to_gateway_index.get(session_id) + if gateway_index is None: + raise KeyError(session_id) + return gateway_index + + def _get_gateway(self, session_id: str): + gateway_index = self._get_gateway_index(session_id) + return self.gateways[gateway_index], gateway_index + + async def create_session(self, session_id: str, **kwargs): + gateway_index = self._select_gateway_index() + gateway = self.gateways[gateway_index] + handle = await _await_object_ref(gateway.create_session.remote(session_id=session_id, **kwargs)) + self._session_to_gateway_index[session_id] = gateway_index + self.active_sessions_per_gateway[gateway_index] += 1 + return handle + + async def finalize_session(self, session_id: str): + gateway, gateway_index = self._get_gateway(session_id) + trajectories = await _await_object_ref(gateway.finalize_session.remote(session_id=session_id)) + self._session_to_gateway_index.pop(session_id, None) + self.active_sessions_per_gateway[gateway_index] -= 1 + return trajectories + + async def complete_session(self, session_id: str) -> None: + gateway, _ = self._get_gateway(session_id) + await _await_object_ref(gateway.complete_session.remote(session_id=session_id)) + + async def abort_session(self, session_id: str) -> None: + gateway, gateway_index = self._get_gateway(session_id) + await _await_object_ref(gateway.abort_session.remote(session_id=session_id)) + self._session_to_gateway_index.pop(session_id, None) + self.active_sessions_per_gateway[gateway_index] -= 1 + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + gateway, _ = self._get_gateway(session_id) + await _await_object_ref(gateway.wait_for_completion.remote(session_id=session_id, timeout=timeout)) diff --git a/verl/agent/gateway/runtime.py b/verl/agent/gateway/runtime.py new file mode 100644 index 00000000000..c489ed16a26 --- /dev/null +++ b/verl/agent/gateway/runtime.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import asyncio +from typing import Any + +import ray + +from verl.workers.rollout.llm_server import LLMServerClient + + +async def _await_ray_ref(object_ref): + return await asyncio.wrap_future(object_ref.future()) + + +class GatewayServingRuntime: + """Standalone serving runtime that owns gateway actors and delegates backend routing.""" + + def __init__( + self, + llm_client: LLMServerClient, + *, + gateway_manager=None, + gateway_count: int = 0, + gateway_actor_kwargs: dict[str, Any] | None = None, + ): + self._llm_client = llm_client + self.owned_gateway_actors: list[ray.actor.ActorHandle] = [] + self.gateway_manager = gateway_manager + + if self.gateway_manager is None and gateway_count > 0: + self._initialize_gateway_runtime( + gateway_count=gateway_count, + gateway_actor_kwargs=gateway_actor_kwargs, + ) + + def _initialize_gateway_runtime( + self, + *, + gateway_count: int, + gateway_actor_kwargs: dict[str, Any] | None = None, + ) -> None: + from verl.agent.gateway.gateway import GatewayActor + from verl.agent.gateway.manager import GatewayManager + + gateway_actor_kwargs = dict(gateway_actor_kwargs or {}) + if "backend" not in gateway_actor_kwargs: + gateway_actor_kwargs["backend"] = self + + # Round-robin across alive CPU nodes so gateway actors do not all pack onto + # the driver node under Ray's default PACK scheduling. Mirrors + # AgentLoopWorker placement (verl/experimental/agent_loop/agent_loop.py). + node_ids = [ + node["NodeID"] + for node in ray.nodes() + if node["Alive"] and node["Resources"].get("CPU", 0) > 0 + ] + if not node_ids: + raise RuntimeError("No alive CPU nodes available for GatewayActor placement") + + self.owned_gateway_actors = [ + GatewayActor.options( + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=node_ids[i % len(node_ids)], soft=True, + ), + ).remote(**gateway_actor_kwargs) + for i in range(gateway_count) + ] + ray.get([gateway.start.remote() for gateway in self.owned_gateway_actors]) + self.gateway_manager = GatewayManager(self.owned_gateway_actors) + + def _require_session_runtime(self): + if self.gateway_manager is None: + raise RuntimeError("Session runtime is disabled because gateway_count=0") + return self.gateway_manager + + async def create_session(self, session_id: str, **kwargs): + gateway_manager = self._require_session_runtime() + return await gateway_manager.create_session(session_id=session_id, **kwargs) + + async def finalize_session(self, session_id: str): + gateway_manager = self._require_session_runtime() + return await gateway_manager.finalize_session(session_id=session_id) + + async def complete_session(self, session_id: str) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.complete_session(session_id=session_id) + + async def abort_session(self, session_id: str) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.abort_session(session_id=session_id) + + async def wait_for_completion(self, session_id: str, timeout: float | None = None) -> None: + gateway_manager = self._require_session_runtime() + await gateway_manager.wait_for_completion(session_id=session_id, timeout=timeout) + + async def shutdown(self) -> None: + if self.owned_gateway_actors: + await asyncio.gather(*[_await_ray_ref(gateway.shutdown.remote()) for gateway in self.owned_gateway_actors]) + self.owned_gateway_actors = [] + self.gateway_manager = None + + async def generate( + self, + request_id, + *, + prompt_ids: list[int], + sampling_params: dict[str, Any], + image_data: list[Any] | None = None, + video_data: list[Any] | None = None, + **kwargs: Any, + ) -> Any: + return await self._llm_client.generate( + request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + image_data=image_data, + video_data=video_data, + **kwargs, + ) diff --git a/verl/agent/gateway/types.py b/verl/agent/gateway/types.py new file mode 100644 index 00000000000..b690ab7d4b4 --- /dev/null +++ b/verl/agent/gateway/types.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from verl.agent.framework.types import SessionHandle, Trajectory + + +class SessionPhase(str, Enum): + ACTIVE = "ACTIVE" + COMPLETED = "COMPLETED" + FINALIZED = "FINALIZED" + ABORTED = "ABORTED" + + +@dataclass +class TrajectoryBuffer: + prompt_ids: list[int] + response_ids: list[int] = field(default_factory=list) + response_mask: list[int] = field(default_factory=list) + response_logprobs: list[float] = field(default_factory=list) + + +@dataclass +class GatewaySessionState: + handle: SessionHandle + metadata: dict[str, Any] = field(default_factory=dict) + request_tools: list[dict[str, Any]] | None = None + message_history: list[dict[str, Any]] = field(default_factory=list) + image_data: list[Any] | None = None + video_data: list[Any] | None = None + active_trajectory: TrajectoryBuffer | None = None + trajectories: list[Trajectory] = field(default_factory=list) + reward_info: dict[str, Any] = field(default_factory=dict) + completed: asyncio.Event = field(default_factory=asyncio.Event) + phase: SessionPhase = SessionPhase.ACTIVE + created_at: float = field(default_factory=time.time) + updated_at: float = field(default_factory=time.time) + request_lock: asyncio.Lock = field(default_factory=asyncio.Lock) + generation_lock: asyncio.Lock = field(default_factory=asyncio.Lock) diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py index 6642fd26966..6585046acf8 100644 --- a/verl/utils/transferqueue_utils.py +++ b/verl/utils/transferqueue_utils.py @@ -38,7 +38,10 @@ KVBatchMeta, ) + _TRANSFER_QUEUE_AVAILABLE = True + except ImportError: + _TRANSFER_QUEUE_AVAILABLE = False class BatchMeta: pass @@ -69,6 +72,103 @@ def _raise(*args, **kwargs): TQ_INITIALIZED = False +_TQ_FORCE_NESTED_READBACK_ENV = "VERL_FORCE_TQ_NESTED_READBACK" + +_TQ_SEQUENCE_FIELDS_REQUIRING_NESTED = { + "prompts", + "responses", + "response_mask", + "loss_mask", + "input_ids", + "attention_mask", + "position_ids", + "rollout_log_probs", + "old_log_probs", + "ref_log_prob", + "rm_scores", + "token_level_scores", + "token_level_rewards", + "advantages", + "returns", + "values", + "log_probs", + "entropy", + "teacher_logprobs", + "teacher_ids", + "routed_experts", +} + + +def _force_tq_nested_readback_enabled() -> bool: + return os.getenv(_TQ_FORCE_NESTED_READBACK_ENV, "").lower() in {"1", "true", "yes", "on"} + + +def force_tq_sequence_fields_nested(data: TensorDict) -> TensorDict: + """Restore nested semantics for sequence fields read from TransferQueue. + + TransferQueue 0.1.6 may return a dense tensor when all samples in a field + have the same shape, even if that field was originally written as a jagged + NestedTensor. No-padding VERL workers expect token sequence fields to stay + nested, so opt-in jobs can normalize known sequence fields at the TQ bridge. + """ + if not isinstance(data, TensorDict): + return data + + for key in _TQ_SEQUENCE_FIELDS_REQUIRING_NESTED: + if key not in data.keys(): + continue + + value = data[key] + if not isinstance(value, torch.Tensor) or value.is_nested or value.dim() < 2 or value.size(0) == 0: + continue + + rows = list(value.unbind(0)) + ragged_idx = 2 if key == "position_ids" and rows[0].dim() == 2 else None + data[key] = tu.nested_tensor_from_tensor_list(rows, ragged_idx=ragged_idx) + + return data + + +def _wrap_tq_sync_batch_get(func): + if getattr(func, "_verl_force_nested_sequence_fields", False): + return func + + @wraps(func) + def wrapper(*args, **kwargs): + return force_tq_sequence_fields_nested(func(*args, **kwargs)) + + wrapper._verl_force_nested_sequence_fields = True + return wrapper + + +def _wrap_tq_async_batch_get(func): + if getattr(func, "_verl_force_nested_sequence_fields", False): + return func + + @wraps(func) + async def wrapper(*args, **kwargs): + return force_tq_sequence_fields_nested(await func(*args, **kwargs)) + + wrapper._verl_force_nested_sequence_fields = True + return wrapper + + +def _install_tq_nested_readback_wrappers() -> None: + if not _TRANSFER_QUEUE_AVAILABLE or not _force_tq_nested_readback_enabled(): + return + + for name in ("kv_batch_get", "kv_batch_get_by_meta"): + func = getattr(tq, name, None) + if func is not None: + setattr(tq, name, _wrap_tq_sync_batch_get(func)) + + for name in ("async_kv_batch_get", "async_kv_batch_get_by_meta"): + func = getattr(tq, name, None) + if func is not None: + setattr(tq, name, _wrap_tq_async_batch_get(func)) + + +_install_tq_nested_readback_wrappers() # TODO (TQ): verl will make all actor async, so this can be cleanup later. def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any: @@ -119,6 +219,8 @@ async def _async_meta_to_realdata(meta: BatchMeta | KVBatchMeta) -> TensorDict: tq_client = tq.get_client() tensordict = await tq_client.async_get_data(meta) + if _force_tq_nested_readback_enabled(): + force_tq_sequence_fields_nested(tensordict) for key, val in meta_info.items(): if isinstance(val, (NonTensorData | NonTensorStack)):