From 319ebf7fabf92a424d298d877ca821ee879f27f5 Mon Sep 17 00:00:00 2001 From: adityasoni9998 Date: Wed, 25 Mar 2026 19:25:18 -0400 Subject: [PATCH] apptainer support for RL on codescout models --- platoon/train/areal/patches.py | 8 + platoon/train/areal/rl.py | 6 +- platoon/utils/openhands_utils.py | 121 +++++++--- plugins/codescout/README.md | 54 +++++ .../codescout/platoon/codescout/__init__.py | 0 .../codescout/custom_tools/__init__.py | 0 .../custom_tools/localization_finish.py | 210 +++++++++++++++++ plugins/codescout/platoon/codescout/env.py | 181 ++++++++++++++ .../platoon/codescout/prompts/user_prompt.j2 | 14 ++ .../codescout/platoon/codescout/rollout.py | 220 ++++++++++++++++++ plugins/codescout/platoon/codescout/tasks.py | 49 ++++ plugins/codescout/platoon/codescout/train.py | 37 +++ .../platoon/codescout/train_codescout.yaml | 152 ++++++++++++ plugins/codescout/pyproject.toml | 70 ++++++ plugins/openhands/platoon/openhands/env.py | 76 ++++-- 15 files changed, 1146 insertions(+), 52 deletions(-) create mode 100644 plugins/codescout/README.md create mode 100644 plugins/codescout/platoon/codescout/__init__.py create mode 100644 plugins/codescout/platoon/codescout/custom_tools/__init__.py create mode 100644 plugins/codescout/platoon/codescout/custom_tools/localization_finish.py create mode 100644 plugins/codescout/platoon/codescout/env.py create mode 100644 plugins/codescout/platoon/codescout/prompts/user_prompt.j2 create mode 100644 plugins/codescout/platoon/codescout/rollout.py create mode 100644 plugins/codescout/platoon/codescout/tasks.py create mode 100644 plugins/codescout/platoon/codescout/train.py create mode 100644 plugins/codescout/platoon/codescout/train_codescout.yaml create mode 100644 plugins/codescout/pyproject.toml diff --git a/platoon/train/areal/patches.py b/platoon/train/areal/patches.py index 370aafa..bf6d535 100644 --- a/platoon/train/areal/patches.py +++ b/platoon/train/areal/patches.py @@ -205,6 +205,14 @@ async def patched_create( # Convert messages to prompt format tools_val = tools if not is_omitted(tools) else None if self.chat_template_type == "hf": + for message in messages_list: + if isinstance(message["content"], list): + new_content = "".join( + item.get("text", "") + for item in message["content"] + if isinstance(item, dict) and item.get("type") == "text" + ) + message["content"] = new_content prompt_token_ids = self.tokenizer.apply_chat_template( messages_list, tools=tools_val, diff --git a/platoon/train/areal/rl.py b/platoon/train/areal/rl.py index ff6c4ba..e5ddaeb 100644 --- a/platoon/train/areal/rl.py +++ b/platoon/train/areal/rl.py @@ -190,13 +190,15 @@ def __init__( self.ref.initialize(None, self.ft_spec) # Setup proxy servers - self.llm_client = ArealOpenAI(engine=self.rollout, tokenizer=self.tokenizer) + # TODO: How to fix the hard-coded tool_call_parser here? + self.llm_client = ArealOpenAI(engine=self.rollout, tokenizer=self.tokenizer, tool_call_parser="qwen25") free_port = find_free_ports(1)[0] self.proxy_server = ProxyServer(free_port, client=self.llm_client) self.proxy_server.start(wait_until_ready=True) # Eval proxy uses the eval rollout engine; training rollout is paused during eval. - self.eval_llm_client = ArealOpenAI(engine=self.eval_rollout, tokenizer=self.tokenizer) + # TODO: How to fix the hard-coded tool_call_parser here? + self.eval_llm_client = ArealOpenAI(engine=self.eval_rollout, tokenizer=self.tokenizer, tool_call_parser="qwen25") eval_free_port = find_free_ports(1)[0] self.eval_proxy_server = ProxyServer(eval_free_port, client=self.eval_llm_client) self.eval_proxy_server.start(wait_until_ready=True) diff --git a/platoon/utils/openhands_utils.py b/platoon/utils/openhands_utils.py index 3887259..2f97b80 100644 --- a/platoon/utils/openhands_utils.py +++ b/platoon/utils/openhands_utils.py @@ -1,60 +1,108 @@ -from openhands.sdk.conversation.state import ConversationExecutionStatus + +from typing import Sequence from openhands.sdk.event import ActionEvent, AgentErrorEvent, Event, EventID, MessageEvent - +from openhands.sdk.tool.builtins.finish import FinishAction +from openhands.sdk.event.conversation_error import ConversationErrorEvent +from openhands.sdk.conversation import ConversationExecutionStatus from platoon.openhands.types import OpenHandsObservation +from collections import defaultdict -def is_action(event: Event) -> bool: - return isinstance(event, ActionEvent) or (isinstance(event, MessageEvent) and event.source == "agent") - +def _is_terminal_status(conversation_state) -> bool: + """Return True if the conversation execution status is a terminal state.""" + return conversation_state.execution_status in ( + ConversationExecutionStatus.FINISHED, + ConversationExecutionStatus.STUCK, + ConversationExecutionStatus.ERROR, + ) -# TODO: Simplify by looking at changes in llm_response_id. When it changes, consider it a new action. -def get_actions_for_last_obs(observation: OpenHandsObservation, require_same_llm_call_id: bool = False) -> list[Event]: - """Collect Event(s) we consider as actions that immediately follow a past ObservationEvent and are - fully observed by a subsequent ObservationBaseEvent referencing them. - """ +def is_action(event: Event) -> bool: + return isinstance(event, ActionEvent) \ + or (isinstance(event, MessageEvent) and event.source == "agent") + +def group_actions(events: Sequence[Event]): + """Build a map of llm_response_id -> list of ActionEvent IDs.""" + batches: dict[EventID, list[EventID]] = defaultdict(list) + action_id_to_response_id: dict[EventID, EventID] = {} + tool_call_id_to_action_id = {} + action_id_to_tool_call_id = {} + + for event in events: + if isinstance(event, ActionEvent) or (isinstance(event, MessageEvent) and event.source == "agent"): + llm_response_id = event.llm_response_id + batches[llm_response_id].append(event.id) + action_id_to_response_id[event.id] = llm_response_id + if isinstance(event, ActionEvent) and event.tool_call_id is not None: + tool_call_id_to_action_id[event.tool_call_id] = event.id + action_id_to_tool_call_id[event.id] = event.tool_call_id + + return batches, action_id_to_response_id, tool_call_id_to_action_id, action_id_to_tool_call_id + +def get_actions_for_last_obs(observation: OpenHandsObservation, require_same_llm_call_id: bool = True) -> list[Event]: + """Collect all Actions between the last observation.last_step_observation_id and the most recent observation, ensuring that all these Actions have a corresponding observation except for messages and finish actions from agent.""" events = observation.conversation_state.events new_actions: list[Event] = list() - new_actions_candidates: list[Event] = list() seen_action_ids: set[EventID] = set() at_least_one_future_obs_seen = False at_least_one_future_error_event_seen = False + batches, action_id_to_response_id, tool_call_id_to_action_id, action_id_to_tool_call_id = group_actions(events) for event in reversed(events): + # Only consider events after the last observed event if event.id == observation.last_step_observation_id: break + if not is_action(event): - new_actions.clear() + new_actions.clear() # clear all accumulated actions till now if a non-action event (observation) happened before them at_least_one_future_obs_seen = True if hasattr(event, "action_id"): - seen_action_ids.add(event.action_id) - if isinstance(event, AgentErrorEvent): + seen_action_ids.add(event.action_id) # event.action_id is always present for observation events + + if isinstance(event, AgentErrorEvent) and event.tool_call_id is not None and event.tool_call_id in tool_call_id_to_action_id: + # If we see an agent error event that references a tool call id, we should consider the corresponding action as having a future observation, since agent error events are a type of observation event that LLM would see and react to and AgentErrorEvents don't terminate agent loop. + seen_action_ids.add(tool_call_id_to_action_id[event.tool_call_id]) + + if isinstance(event, ConversationErrorEvent): # this event will terminate the agent loop at_least_one_future_error_event_seen = True continue else: new_actions.append(event) - new_actions_candidates.append(event) + if isinstance(event, MessageEvent) and event.source == "agent": + seen_action_ids.add(event.id) + at_least_one_future_obs_seen = True + elif isinstance(event, ActionEvent) and event.source == "agent" and ( + isinstance(event.action, FinishAction) + or _is_terminal_status(observation.conversation_state) + ): + # The agent submitted a terminal action (built-in FinishAction or a custom tool that set execution_status to FINISHED/STUCK/ERROR, + # e.g. LocalizationFinishAction). Treat this the same as a message action: mark it as "seen" so the downstream validation logic + # doesn't clear it for lacking a corresponding observation. + seen_action_ids.add(event.id) + at_least_one_future_obs_seen = True + + if len(new_actions) == 0: + return new_actions + last_event_seen = new_actions[0].id if new_actions else None - last_event_seen = new_actions[-1].id if new_actions else None if not is_finished(observation, last_event_seen=last_event_seen) and not at_least_one_future_error_event_seen: for action in new_actions: - if isinstance(action, ActionEvent) and action.id not in seen_action_ids: + if action.id not in seen_action_ids: + print(f"Clearing new_actions due to action event that has not been observed in a future observation: {action.id} {action.kind}", flush=True) new_actions.clear() break if not at_least_one_future_obs_seen: new_actions.clear() - + if require_same_llm_call_id and new_actions: llm_call_id = new_actions[0].llm_response_id if any(action.llm_response_id != llm_call_id for action in new_actions): - raise ValueError( - "Detected at least two actions in a step with differing llm_response_id. " - "This is unexpected and can lead to undefined behavior." - ) + raise ValueError("Detected at least two actions in a step with differing llm_response_id. " + "This is unexpected and can lead to undefined behavior.") + if len(new_actions) != len(batches[llm_call_id]): + print("Warning: The number of new actions detected does not match the number of actions in the batch for the corresponding llm_response_id. This could indicate that some actions are not being properly observed or that there are unexpected events in the conversation history.", flush=True) return list(reversed(new_actions)) - def get_obs_for_last_action(observation: OpenHandsObservation) -> list[Event]: """Collect event(s) that immediately follow a past ActionEvent and are fully observed by a subsequent ObservationBaseEvent referencing them. @@ -73,9 +121,15 @@ def get_obs_for_last_action(observation: OpenHandsObservation) -> list[Event]: else: new_obs.append(event) + if len(new_obs) == 0: + return new_obs + + # NOTE: The primary objective of oh_conversation_finished is to not clear the observations for the last action if the conversation has already entered exited. This is done to allow processing the final observation events that will not have any subsequent action events. We don't check termination via is_finished() as the problem is that platoon episode has not caught up simply because the last event in openhands event stream has not yet been stored in last_step_observation_id or last_step_action_id fields in OpenHandsObservation state, even though it is the final event in the conversation. This causes the logic in is_finished() to return False and consequently clear the observations for the final action. As a result, the loop will get stuck and timeout since no future actions will ever be observed. + conversation_state = observation.conversation_state + oh_conversation_finished = _is_terminal_status(conversation_state) + # If not at least one future action seen and if this obs is not the final one, empty the list. - last_event_seen = new_obs[-1].id if new_obs else None - if not at_least_one_future_action_seen and not is_finished(observation, last_event_seen=last_event_seen): + if not at_least_one_future_action_seen and not oh_conversation_finished: new_obs.clear() return list(reversed(new_obs)) @@ -83,15 +137,10 @@ def get_obs_for_last_action(observation: OpenHandsObservation) -> list[Event]: def is_finished(observation: OpenHandsObservation, last_event_seen: EventID | None = None) -> bool: conversation_state = observation.conversation_state - oh_conversation_finished = ( - conversation_state.agent_status == ConversationExecutionStatus.FINISHED - or conversation_state.agent_status == ConversationExecutionStatus.STUCK - or conversation_state.agent_status == ConversationExecutionStatus.ERROR - ) + oh_conversation_finished = _is_terminal_status(conversation_state) last_event_id = conversation_state.events[-1].id - platoon_episode_caught_up = last_event_id in ( - observation.last_step_action_id, - observation.last_step_observation_id, - last_event_seen, - ) - return oh_conversation_finished and platoon_episode_caught_up + assert last_event_id is not None, "Last event in conversation must have a non-None ID" + valid_ids = [event_id for event_id in [observation.last_step_action_id, observation.last_step_observation_id, last_event_seen] if event_id is not None] + platoon_episode_caught_up = last_event_id in valid_ids + + return oh_conversation_finished and platoon_episode_caught_up \ No newline at end of file diff --git a/plugins/codescout/README.md b/plugins/codescout/README.md new file mode 100644 index 0000000..84df8f0 --- /dev/null +++ b/plugins/codescout/README.md @@ -0,0 +1,54 @@ +# CodeScout Training in Platoon using Apptainer Runtime + +This README explains how to train LLMs using RL recipe the [CodeScout paper](https://arxiv.org/abs/2603.17829) when using the Apptainer sandbox. + +> Assumption: commands are run on Linux, from the repo root unless stated otherwise. ripgrep must be installed on the system. + + +![CodeScout main figure (verified file-level)](https://raw.githubusercontent.com/OpenHands/codescout/974238b1d22308fd9cf0c79d3544697f4206ec2c/docs/verified_file_main.png) + +![CodeScout main figure (verified function-level)](https://raw.githubusercontent.com/OpenHands/codescout/974238b1d22308fd9cf0c79d3544697f4206ec2c/docs/verified_function_main.png) + +![CodeScout system diagram](https://raw.githubusercontent.com/OpenHands/codescout/974238b1d22308fd9cf0c79d3544697f4206ec2c/docs/recipe.png) + +--- + +## Instructions to Train CodeScout Models + +### Environment setup + +Execute the following commands from the repository root: + +```bash +mkdir -p /tmp/apptainer_cache +mkdir -p /tmp/apptainer_tmp +mkdir -p /tmp/areal/experiments +mkdir -p /tmp/areal/name_resolve +uv sync --extra areal --extra wandb +source .venv/bin/activate +uv pip install -e plugins/codescout +``` + +--- + +### Apptainer and logging environment variables + +Set these enviroment variables before launching training: + +```bash +export APPTAINER_CACHEDIR=/tmp/apptainer_cache +export APPTAINER_TMPDIR=/tmp/apptainer_tmp +export OPENHANDS_SUPPRESS_BANNER=1 +export WANDB_API_KEY="" +``` + +### Train CodeScout Models + +We provide an example config in [train_codescout.yaml](./platoon/codescout/train_codescout.yaml) which trains CodeScout-1.7B using RL from an SFT'ed checkpoint (CodeScout-1.7B-RFT) which can be modified to use a different model, GPU setup, and other training hyper-parameters. + +```bash +cd plugins/codescout +python3 -m areal.launcher.local \ + platoon/codescout/train.py \ + --config platoon/codescout/train_codescout.yaml +``` \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/__init__.py b/plugins/codescout/platoon/codescout/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/codescout/platoon/codescout/custom_tools/__init__.py b/plugins/codescout/platoon/codescout/custom_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/codescout/platoon/codescout/custom_tools/localization_finish.py b/plugins/codescout/platoon/codescout/custom_tools/localization_finish.py new file mode 100644 index 0000000..5856725 --- /dev/null +++ b/plugins/codescout/platoon/codescout/custom_tools/localization_finish.py @@ -0,0 +1,210 @@ +"""Custom finish tool for code localization tasks. + +This tool allows the agent to submit localization results in a structured format where: +- File path is required +- Class name is optional +- Function name is optional +""" + +import sys + +# This module lives at two importable paths simultaneously: +# - "platoon.codescout.custom_tools.localization_finish" (via the platoon package) +# - "custom_tools.localization_finish" (via the .pth entry for plugins/codescout) +# +# Python treats these as two separate modules and will load the file twice, +# producing two distinct class objects that both claim __module__ = +# "custom_tools.localization_finish" after our __module__ override. +# Pydantic then sees two Action/Observation subclasses with the same name +# and raises "Duplicate class definition". +# +# Fix: whichever path loads us first, immediately register both keys in +# sys.modules pointing to the SAME module object, so the second import +# finds the cached version instead of re-executing the file. +_SHORT = "custom_tools.localization_finish" +_LONG = "platoon.codescout.custom_tools.localization_finish" +_PLUGINS = "plugins.codescout.platoon.codescout.custom_tools.localization_finish" +_SHORT_PKG = "custom_tools" +_LONG_PKG = "platoon.codescout.custom_tools" + +_this_module = sys.modules[__name__] +if _SHORT not in sys.modules: + sys.modules[_SHORT] = _this_module +if _LONG not in sys.modules: + sys.modules[_LONG] = _this_module +if _PLUGINS not in sys.modules: + sys.modules[_PLUGINS] = _this_module +if _SHORT_PKG not in sys.modules and _LONG_PKG in sys.modules: + sys.modules[_SHORT_PKG] = sys.modules[_LONG_PKG] + +import json +from typing import TYPE_CHECKING +from collections.abc import Sequence + +from pydantic import BaseModel, Field, computed_field +from rich.text import Text + +from openhands.sdk import ( + Action, + Observation, + ToolDefinition +) +from openhands.sdk.tool import ToolExecutor, ToolAnnotations, register_tool + +from openhands.sdk.conversation.state import ConversationExecutionStatus + +if TYPE_CHECKING: + from openhands.sdk.conversation.base import BaseConversation + +class CodeLocation(BaseModel): + """A single code location with optional class and function.""" + + file: str = Field(description="Path to the file (required)") + class_name: str | None = Field(default=None, description="Class name (optional)") + function_name: str | None = Field(default=None, description="Function/method name (optional)") + +class LocalizationFinishAction(Action): + """Action for submitting final localization results.""" + + locations: list[CodeLocation] = Field( + description="""List of code locations to modify. Each location in this list must have: +- file: Path to the file relative to the repository root (required) +- class_name: Class name (optional, omit for changes to imports, global variables, and global functions) +- function_name: Function/method name (optional, omit for changes that edit parts of a file outside of any particular function) +""" + ) + + @property + def visualize(self) -> Text: + """Return Rich Text representation of this action.""" + content = Text() + content.append("Submitting localization results:\n", style="bold blue") + content.append(f"Found {len(self.locations)} location(s):\n", style="green") + for i, loc in enumerate(self.locations, 1): + content.append(f" {i}. {loc.file}", style="cyan") + if loc.class_name: + content.append(f" → {loc.class_name}", style="yellow") + if loc.function_name: + content.append(f".{loc.function_name}", style="magenta") + content.append("\n") + return content + +# Override __module__ so the client-side pydantic discriminator registry uses the +# same module path as the server container ("custom_tools.localization_finish"). +# This prevents a "Duplicate class definition" error when the server sends events +# back and the client tries to deserialize them — both sides see the same name. +LocalizationFinishAction.__module__ = "custom_tools.localization_finish" + + +class LocalizationFinishObservation(Observation): + """Observation returned after submitting localization results. No observation is needed since the agent will exit after this action.""" + + @property + def visualize(self) -> Text: + """Return an empty Text representation since the message is in the action.""" + return Text() + +LocalizationFinishObservation.__module__ = "custom_tools.localization_finish" + +def locations_to_dict_list(locations: list[CodeLocation]) -> list[dict]: + """Convert CodeLocation objects to dictionary format. + + Args: + locations: List of CodeLocation objects + + Returns: + List of dictionaries with 'file', 'class_name', 'function_name' keys + """ + return [ + { + "file": loc.file, + "class_name": loc.class_name, + "function_name": loc.function_name, + } + for loc in locations + ] + +class LocalizationFinishExecutor(ToolExecutor): + def __call__( + self, + action: LocalizationFinishAction, + conversation: "BaseConversation | None" = None, # noqa: ARG002 + ) -> LocalizationFinishObservation: + try: + loc_dict = locations_to_dict_list(action.locations) + text = json.dumps(loc_dict, indent=2) + conversation.state.execution_status = ConversationExecutionStatus.FINISHED + return LocalizationFinishObservation.from_text(text=text) + except Exception as _: + return LocalizationFinishObservation.from_text(text="") + +TOOL_DESCRIPTION = """Submit your final code localization results. + +Use this tool when you have identified all relevant files, classes, and functions that need to be modified to address the issue described in the problem statement. + +Provide a structured list of locations. Each location must have: +- file: Path to the file relative to the root of the repository (required) +- class_name: Class name (optional) +- function_name: Function/method name (optional) + +You must submit a list of locations that require modification and for each location you must follow the below rules in your output: +1. If the required modifications belong to a specific function that belongs to a class, provide the file path, class name, and function name. +2. If the required modification belongs to a function that is not part of any class, provide the file path and function name. +3. If the required modification does not belong to any specific class or a function (e.g. global variables, imports, new class, new global function etc.), it is sufficient to provide only the file path. +4. If the required modification belongs to a class (e.g. adding a new method to a class, changing the class inheritance), provide the file path and class name. If you are modifying the __init__ method of a class, you should provide the function name as well. + +IMPORTANT: +1. If multiple different edits need to be edited in the same file, you should create separate entries for each edit, specifying the same file path but different class/function names as applicable. Each entry should compulsorily include the file path. +2. Do NOT include duplicate entries in your output for which the file, class, and function names are all identical. +3. Ensure that the file paths are accurate and relative to the root of the repository without any leading "./" or "/". All locations must be valid and exist in the codebase and this applies to class and function names as well. +4. Aim for high precision (all returned locations are relevant) and high recall (no relevant locations missed). +5. The agent will terminate its execution after you call this tool. +""" + +class LocalizationFinishTool(ToolDefinition[LocalizationFinishAction, LocalizationFinishObservation]): + """Tool for submitting final localization results.""" + + """Tool for submitting final code localization results.""" + + @classmethod + def create( + cls, + conv_state, # noqa: ARG003 + **params + ) -> Sequence["LocalizationFinishTool"]: + """Create LocalizationFinishTool instance. + + Args: + conv_state: Conversation state (provides workspace info) + workspace_dir: Optional workspace directory override + **params: Additional parameters + + Returns: + A sequence containing a single LocalizationFinishTool instance. + """ + if params: + raise ValueError("LocalizationFinishTool doesn't accept parameters") + + return [ + cls( + action_type=LocalizationFinishAction, + observation_type=LocalizationFinishObservation, + description=TOOL_DESCRIPTION, + executor=LocalizationFinishExecutor(), + annotations=ToolAnnotations( + title="localization_finish", + readOnlyHint=True, + destructiveHint=False, + idempotentHint=True, + openWorldHint=False, + ), + ) + ] + +# Override __module__ on all classes so both client and server use the same +# module path ("custom_tools.localization_finish") in every registry: +# - the pydantic discriminator registry (prevents "Duplicate class definition") +# - the tool registry _MODULE_QUALNAMES (tells the server what to import) +LocalizationFinishTool.__module__ = "custom_tools.localization_finish" + +register_tool("LocalizationFinishTool", LocalizationFinishTool) \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/env.py b/plugins/codescout/platoon/codescout/env.py new file mode 100644 index 0000000..8ee7bf0 --- /dev/null +++ b/plugins/codescout/platoon/codescout/env.py @@ -0,0 +1,181 @@ +from typing import List, Tuple +from platoon.utils.openhands_utils import is_finished +from platoon.openhands.env import OpenHandsEnv +from openhands.sdk.event import ActionEvent +from platoon.codescout.custom_tools.localization_finish import LocalizationFinishAction + +def get_structured_locations(events): + """Extract structured locations from LocalizationFinishAction in events. + Args: + events: List of conversation events to search through. + Returns: + List of location dicts with 'file', 'class', 'function' keys, or None if not found. + """ + # Find the last LocalizationFinishAction + cnt = [1 for event in events if isinstance(event, ActionEvent) and event.source == "agent" and isinstance(event.action, LocalizationFinishAction)] + cnt = sum(cnt) + if cnt != 1: # the localization finish tool must be called exactly once. + return None + for event in reversed(events): + if ( + isinstance(event, ActionEvent) + and event.source == "agent" + and isinstance(event.action, LocalizationFinishAction) + ): + # Extract structured locations from the action + locations = [] + for loc in event.action.locations: + locations.append({ + "file": loc.file, + "class_name": loc.class_name, + "function_name": loc.function_name, + }) + return locations + return None + +def parse_structured_outputs(structured_locations: List[dict]) -> Tuple[List[str], List[str], List[str]]: + """ + Process structured location outputs and extract files, modules, and entities. + + Args: + structured_locations: List of dicts with 'file', 'class_name', 'function_name' keys + Returns: + Tuple of (all_found_files, all_found_modules, all_found_entities) where each is a list of strs + + Example structured input format: + [ + {'file': 'path/to/file1.py', 'class_name': 'MyClass', 'function_name': 'my_method'}, + {'file': 'path/to/file2.py', 'class_name': None, 'function_name': 'standalone_function'}, + {'file': 'path/to/file1.py', 'class_name': None, 'function_name': 'global_function'}, + {'file': 'path/to/file2.py', 'class_name': 'AnotherClass', 'function_name': None}, + {'file': 'path/to/file3.py', 'class_name': None, 'function_name': None} + ] + + Example output: + [['path/to/file1.py', 'path/to/file2.py', 'path/to/file3.py'], ['path/to/file1.py:MyClass', 'path/to/file2.py:AnotherClass', 'path/to/file1.py:global_function', 'path/to/file2.py:standalone_function'], ['path/to/file1.py:MyClass.my_method', 'path/to/file2.py:standalone_function', 'path/to/file1.py:global_function', 'path/to/file2.py:AnotherClass']] + """ + + all_found_files = [] + all_found_modules = [] + all_found_entities = [] + + found_empty_filename = False + + for location in structured_locations: + file_path = location.get("file", None) + class_name = location.get("class_name", None) + function_name = location.get("function_name", None) + + # NOTE: Ideally the case of file_path being None should raise an error from the agent-sdk but adding here for safety + if file_path is None or file_path.strip() == "": + found_empty_filename = True + break + + all_found_files.append(file_path) + + module = None + if class_name: + module = f"{file_path}:{class_name}" + elif function_name: + module = f"{file_path}:{function_name}" + + if module: + all_found_modules.append(module) + + entity = None + if class_name and function_name: + entity = f"{file_path}:{class_name}.{function_name}" + elif function_name: + entity = f"{file_path}:{function_name}" + + if entity: + all_found_entities.append(entity) + if found_empty_filename: + return [], [], [] + all_found_files = list(set(all_found_files)) + all_found_modules = list(set(all_found_modules)) + all_found_entities = list(set(all_found_entities)) + return all_found_files, all_found_modules, all_found_entities + +def compute_file_f1_score(predicted_files, true_files, beta=1.0): + pred, true = set(predicted_files), set(true_files) + if not true: + return 0.0 # return 0 reward if ground truth is empty + tp = len(pred & true) + precision = tp / len(pred) if pred else 0.0 + recall = tp / len(true) if true else 0.0 + return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) if (precision + recall) > 0 else 0.0 + +def multilevel_localization_f1_reward( + instance: dict, + structured_locations: list[dict] | None = None, + file_level_weight: float=1.0, + module_level_weight: float=1.0, + entity_level_weight: float=1.0, +): + + if structured_locations is None: + return 0, { + "multilevel_localization_f1_reward": 0, + "file_reward": 0, + "module_reward": 0, + "entity_reward": 0, + } + + gt_files = [] + gt_modules = [] + gt_entities = [] + reward = 0 + + for change in instance.get("file_changes", []): + if "file" in change: + gt_files.append(change["file"]) + if "changes" in change: + edited_modules = change["changes"].get("edited_modules", []) + edited_modules = [] if edited_modules is None else edited_modules + for module in edited_modules: + gt_modules.append(module) + + edited_entities = change["changes"].get("edited_entities", []) + edited_entities = [] if edited_entities is None else edited_entities + for entity in edited_entities: + gt_entities.append(entity) + gt_files = set(gt_files) + gt_modules = set(gt_modules) + gt_entities = set(gt_entities) + + if structured_locations is not None: + predicted_files, predicted_modules, predicted_entities = parse_structured_outputs(structured_locations) + else: + predicted_files, predicted_modules, predicted_entities = get_simple_results_from_raw_outputs(final_message) + + file_f1_score = compute_file_f1_score(predicted_files, gt_files) + module_f1_score = compute_file_f1_score(predicted_modules, gt_modules) + entity_f1_score = compute_file_f1_score(predicted_entities, gt_entities) + + reward = ( + file_f1_score * file_level_weight + + module_f1_score * module_level_weight + + entity_f1_score * entity_level_weight + ) + + return reward, { + "multilevel_localization_f1_reward": reward, + "file_reward": file_f1_score, + "module_reward": module_f1_score, + "entity_reward": entity_f1_score, + } + +class CodeScoutEnv(OpenHandsEnv): + async def evaluate(self) -> tuple[float, dict]: + if not is_finished(await self.observe()): + return 0, {} + + structured_locations = get_structured_locations(self._conversation.state.events) + + if structured_locations is None: + return 0, {} + + instance: dict = self.task.misc + reward, metadata = multilevel_localization_f1_reward(instance, structured_locations) + return reward, metadata \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/prompts/user_prompt.j2 b/plugins/codescout/platoon/codescout/prompts/user_prompt.j2 new file mode 100644 index 0000000..b898c29 --- /dev/null +++ b/plugins/codescout/platoon/codescout/prompts/user_prompt.j2 @@ -0,0 +1,14 @@ +I have access to a python code repository in the directory {{ working_dir }} . Consider the following issue description: + + +{{ instance.problem_statement }} + + +Act as a code search agent and localize the specific files, classes or functions of code that need modification to resolve the issue in . + +NOTE: You do not need to solve the issue, all you need to do is localize relevant code from the repository. Your output will be used to guide another agent to solve the issue. + +IMPORTANT: Your output MUST follow the below rules: +1. The final output must be a tool call to the "localization_finish" tool containing relevant code locations. +2. The locations of the file path must be RELATIVE to the {{ working_dir }} directory WITHOUT any leading "./" in the output. +3. Only include those locations in your output that need modification to resolve the issue in . Do NOT include any locations that do not need modification. \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/rollout.py b/plugins/codescout/platoon/codescout/rollout.py new file mode 100644 index 0000000..35787ae --- /dev/null +++ b/plugins/codescout/platoon/codescout/rollout.py @@ -0,0 +1,220 @@ +import os +from jinja2 import Environment, FileSystemLoader +import asyncio +from platoon.envs.base import Task +from platoon.codescout.env import CodeScoutEnv +from pathlib import Path +from openhands.sdk import LLM, get_logger, Agent, Tool +from openhands.workspace import ApptainerWorkspace +from platoon.episode.trajectory import TrajectoryCollection +from platoon.config_defs import RolloutConfig +from platoon.episode.loop import run_episode +from platoon.episode.context import current_trajectory_collection +from platoon.visualization.event_sinks import JsonlFileSink +from platoon.codescout.tasks import EVAL_AGENT_SERVER_IMAGE, USER_PROMPT_FILENAME, APPTAINER_CACHE_DIR +from platoon.openhands.agent import OpenHandsAgent +import platform +import uuid +from openhands.tools.terminal import TerminalTool +from platoon.codescout.custom_tools.localization_finish import LocalizationFinishTool # noqa: F401 - registers the tool with the correct module qualname + +logger = get_logger(__name__) + +# NOTE: ApptainerWorkspace._wait_for_health has a hard-coded default of 120s. +# If that is too short when the SIF cache is cold or the agent server is slow to start. Patch it to default to 600s instead using below monkey patch. +# _orig_wait_for_health = ApptainerWorkspace._wait_for_health +# def _patched_wait_for_health(self, timeout: float = 600.0) -> None: +# return _orig_wait_for_health(self, timeout=timeout) +# ApptainerWorkspace._wait_for_health = _patched_wait_for_health # type: ignore[method-assign] + +def detect_platform(): + """Detects the correct platform string.""" + machine = platform.machine().lower() + if "arm" in machine or "aarch64" in machine: + return "linux/arm64" + return "linux/amd64" + +def prepare_workspace(instance: dict): + uuid_str = str(uuid.uuid4())[:8] + workspace = Path(f"/tmp/testbed/{uuid_str}/") + instance_id: str = instance["instance_id"] + repo_name: str = instance["repo"] + patch: str = instance["patch"] + + instance_dir_name = f"{repo_name.replace('/', '_')}_{instance_id}" + instance_path = workspace / instance_dir_name + + os.makedirs(APPTAINER_CACHE_DIR, exist_ok=True) + + # use the openhands agent server image and then setup env manually + workspace = ApptainerWorkspace( + server_image=EVAL_AGENT_SERVER_IMAGE, + working_dir=str(instance_path), + platform=detect_platform(), + cache_dir=os.environ.get("APPTAINER_CACHEDIR", APPTAINER_CACHE_DIR), + detach_logs=True + ) + + def _run(cmd: str, timeout: float = 120.0) -> None: + """Run a command inside the workspace, raising on failure or timeout. + + httpx.ReadTimeout bubbles out of execute_command with the unhelpful + message 'timed out'. This wrapper catches *any* exception and + re-raises with the command text so we can identify the culprit. + """ + try: + result = workspace.execute_command(cmd, timeout=timeout) + except Exception as exc: + raise RuntimeError(f"Command raised {type(exc).__name__}: {exc}\n cmd: {cmd}") from exc + if result.exit_code != 0: + raise RuntimeError( + f"Command failed (exit {result.exit_code}): {result.stderr}\n cmd: {cmd}" + ) + + try: + _run(f"git clone https://github.com/{repo_name}.git {str(instance_path)}", timeout=120.0) + _run(f"cd {str(instance_path)} && git apply <<'EOF'\n{patch}\nEOF", timeout=120.0) + except Exception as e: + raise RuntimeError(f"Error preparing workspace for instance {instance_id}: {e}") + return True, instance_path, workspace + +def get_instruction( + instance: dict, + prompt_path: str, + workspace_path: str, +) -> str: + """Generate instruction for the agent.""" + # Set up Jinja2 environment + prompts_dir = os.path.dirname(prompt_path) + template_name = os.path.basename(prompt_path) + env = Environment(loader=FileSystemLoader(prompts_dir)) + template = env.get_template(template_name) + + # Prepare context for rendering + context = { + "instance": instance, + "working_dir": workspace_path, + } + + # Render the instruction + instruction = template.render(context) + return instruction + +def prepare_llm(config: RolloutConfig) -> LLM: + model_name = config.model_name + temperature = 1.0 + if not model_name.startswith("openai/") and not model_name.startswith("litellm_proxy/"): + model_name = "openai/" + model_name + + llm=LLM( + usage_id="agent", + model=model_name, + base_url=config.model_endpoint, + api_key="sk-xxx", + temperature=temperature, + litellm_extra_body={ + "include_stop_str_in_output": False, + "chat_template_kwargs": { + # "add_generation_prompt": True, #NOTE: setting this to true raises errors + "enable_thinking": False + } + } + ) + return llm + +async def cleanup_resources(agent, env): + if env is not None: + await env.close() + env = None + +async def run_rollout(task: Task, config: RolloutConfig) -> dict | TrajectoryCollection: + agent = env = agent_wrapper_platoon = None + try: + if config.verbose: + print(f"[run_rollout] Process {os.getpid()}: Starting rollout for task {task.id}", flush=True) + instance: dict = task.misc + try: + loop = asyncio.get_event_loop() + # Run in a separate thread to avoid blocking the event loop. + status, working_dir, workspace = await loop.run_in_executor( + None, # Uses default ThreadPoolExecutor + prepare_workspace, + instance + ) + except Exception as e: + raise RuntimeError( + f"Workspace setup failed for task {task.id}: {e}" + ) + if not status or working_dir is None: + raise RuntimeError(f"Workspace setup failed for task {task.id}") + + user_prompt_filename = USER_PROMPT_FILENAME + prompt_dir = (Path(__file__).parent / "prompts").resolve() + user_prompt_path = prompt_dir / user_prompt_filename + assert user_prompt_path.exists(), f"User prompt path {user_prompt_path} not found" + input_message = get_instruction(instance, str(user_prompt_path), str(working_dir)) + + task.goal = input_message + task.max_steps = config.max_steps if config.max_steps is not None else 6 + + llm: LLM = prepare_llm(config) + agent: Agent = Agent( + llm=llm, + tools=[Tool(name=TerminalTool.name), Tool(name="LocalizationFinishTool")], + system_prompt_filename="/app/prompts_codescout/system_prompt.j2", + include_default_tools=[] + ) + agent_wrapper_platoon: OpenHandsAgent = OpenHandsAgent() + env: CodeScoutEnv = CodeScoutEnv(task=task, agent=agent, workspace=workspace) + + traj_collection = TrajectoryCollection() + current_trajectory_collection.set(traj_collection) + + events_path = os.path.join( + config.output_dir, + "events", + f"events_{task.id}_{traj_collection.id}.jsonl" + ) + + traj_collection.register_event_handlers( + JsonlFileSink( + events_path, + collection_id=traj_collection.id, + process_id=os.getpid() + ) + ) + + rollout_task = asyncio.create_task(run_episode(agent_wrapper_platoon, env, timeout=300)) + try: + # Apply a hard timeout to the entire rollout, not just individual steps + _ = await asyncio.wait_for(rollout_task, timeout=330) + except asyncio.TimeoutError: + if config.verbose: + print(f"Process {os.getpid()}: Rollout timed out for task {task.id}", flush=True) + raise + except Exception as e: + if config.verbose: + print(f"Process {os.getpid()}: Rollout failed for task {task.id}: {str(e)}", flush=True) + raise + + try: + await asyncio.wait_for(cleanup_resources(agent_wrapper_platoon, env), timeout=60) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as _: + pass + if config.return_dict: + return current_trajectory_collection.get().to_dict() + else: + return current_trajectory_collection.get() + except Exception as e: + if config.verbose: + print(f"Error running rollout for task {task.id}: {e}", flush=True) + try: + await asyncio.wait_for(cleanup_resources(agent_wrapper_platoon, env), timeout=60) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as _: + pass + raise + finally: + try: + await asyncio.wait_for(cleanup_resources(agent_wrapper_platoon, env), timeout=60) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as _: + pass \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/tasks.py b/plugins/codescout/platoon/codescout/tasks.py new file mode 100644 index 0000000..1fb6164 --- /dev/null +++ b/plugins/codescout/platoon/codescout/tasks.py @@ -0,0 +1,49 @@ +from platoon.envs.base import Task +from typing import Dict, Optional +import numpy as np +from datasets import load_dataset + +EVAL_AGENT_SERVER_IMAGE = "docker.io/adityasoni8/eval-agent-server:5f106d0-custom-base-image_tag_latest-source" +USER_PROMPT_FILENAME = "user_prompt.j2" +APPTAINER_CACHE_DIR = "/tmp/apptainer_cache" + +data_loaded: bool = False +train_data_map: Optional[Dict[str, Task]] = {} +val_data_map: Optional[Dict[str, Task]] = {} + +def create_task_from_instance(x: dict) -> Task: + task = Task( + id=x['instance_id'], + misc=x, + ) + return task + +def load_data(): + global data_loaded, train_data_map, val_data_map + if data_loaded: + return train_data_map, val_data_map + + dataset = load_dataset("adityasoni17/SWE-smith-py-code-search", split='train').to_pandas() + np.random.seed(42) + split_indices = np.random.rand(len(dataset)) < 0.9 + train_df = dataset.iloc[split_indices] + val_df = dataset.iloc[~split_indices] + for _, row in train_df.iterrows(): + if len(row["problem_statement"]) > 0: + train_data_map[row['instance_id']] = create_task_from_instance(row.to_dict()) + for _, row in val_df.iterrows(): + if len(row["problem_statement"]) > 0: + val_data_map[row['instance_id']] = create_task_from_instance(row.to_dict()) + data_loaded = True + print(f"Loaded {len(train_data_map)} training instances and {len(val_data_map)} validation instances.", flush=True) + return train_data_map, val_data_map + +def get_task(task_id: str) -> Task: + load_data() + global train_data_map, val_data_map + if task_id in train_data_map: + return train_data_map[task_id] + elif task_id in val_data_map: + return val_data_map[task_id] + else: + raise ValueError(f"Task ID {task_id} not found in training or validation data.") \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/train.py b/plugins/codescout/platoon/codescout/train.py new file mode 100644 index 0000000..4d3248b --- /dev/null +++ b/plugins/codescout/platoon/codescout/train.py @@ -0,0 +1,37 @@ +import sys +import logging +from datasets import Dataset +from areal.api.cli_args import load_expr_config +logging.basicConfig(level=logging.INFO) # Quiet by default +logging.getLogger("platoon.train.areal.workflows").setLevel(logging.INFO) +logging.getLogger("httpx").setLevel(logging.WARNING) # Silence httpx spam + +from platoon.codescout.tasks import get_task, load_data +from platoon.codescout.rollout import run_rollout +from platoon.train.areal import PlatoonArealRLTrainer, PlatoonArealRLTrainerConfig +from platoon.train.areal.workflows import StepWiseArealWorkflow + +def main(args): + config, _ = load_expr_config(args, PlatoonArealRLTrainerConfig) + config: PlatoonArealRLTrainerConfig = config + + train_datamap, val_datamap = load_data() + train_dataset = Dataset.from_list([{ "task_id": x } for x in train_datamap.keys()]) + val_dataset = Dataset.from_list([{ "task_id": x } for x in val_datamap.keys()]) + + with PlatoonArealRLTrainer( + config=config, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) as trainer: + proxy_server = trainer.proxy_server + workflow = StepWiseArealWorkflow(run_rollout, get_task, config.workflow_config, proxy_server, 'train_rollout', trainer.actor.device) + eval_workflow = StepWiseArealWorkflow(run_rollout, get_task, config.workflow_config, proxy_server, 'eval_rollout', trainer.actor.device) + + trainer.train( + workflow=workflow, + eval_workflow=eval_workflow, + ) + +if __name__ == "__main__": + main(sys.argv[1:]) \ No newline at end of file diff --git a/plugins/codescout/platoon/codescout/train_codescout.yaml b/plugins/codescout/platoon/codescout/train_codescout.yaml new file mode 100644 index 0000000..5c7c874 --- /dev/null +++ b/plugins/codescout/platoon/codescout/train_codescout.yaml @@ -0,0 +1,152 @@ +experiment_name: codescout-grpo +trial_name: trial8 + +cluster: + n_nodes: 1 + n_gpus_per_node: 6 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: sglang:d2p1t1+d4p1t1 +seed: 42 +# enable_offload: false +total_train_epochs: 1 +tokenizer_path: ${actor.path} +gconfig: + n_samples: 8 + min_new_tokens: 0 + max_new_tokens: 8192 + greedy: false + temperature: 1.0 + +workflow_config: + rollout_config: + model_name: ${actor.path} + max_steps: 6 + output_dir: /tmp/areal/experiments/codescout-grpo-trial8 + verbose: True + train: true + timeout: 600 + group_size: 8 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 2 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 3 + enable_rollout_tracing: false + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: adityasoni17/Qwen3-1.7B-RFT-500 + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 32768 + optimizer: + type: adam + lr: 1e-6 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.0003 + eps_clip_higher: 0.0004 + temperature: ${gconfig.temperature} + reward_scaling: 1.0 + reward_bias: 0.0 + kl_ctl: 0.0 + kl_estimator: k1 + importance_sampling_level: sequence + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: false + behav_imp_weight_cap: null + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 32768 + optimizer: null + +train_dataset: + type: rl + batch_size: 8 + path: "" + +valid_dataset: + batch_size: 8 + type: rl + path: "" + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: 10 + freq_secs: null + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: null + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: online + project: swe-grpo-platoon + +recover: + mode: auto + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 5 + freq_steps: null + freq_secs: 3600 + +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +launcher: + inference_server_cpus_per_gpu: 1 + inference_server_mem_per_gpu: 38000 + trainer_cpus_per_gpu: 1 + trainer_mem_per_gpu: 45000 \ No newline at end of file diff --git a/plugins/codescout/pyproject.toml b/plugins/codescout/pyproject.toml new file mode 100644 index 0000000..2b4b409 --- /dev/null +++ b/plugins/codescout/pyproject.toml @@ -0,0 +1,70 @@ +[project] +name = "platoon-codescout" +version = "0.1.0" +description = "Platoon plugin for training OSS models with RL for repo-level code localization." +requires-python = "~=3.12.0" +authors = [ + {name = "Aditya Bharat Soni", email = "adityabs@cs.cmu.edu"} +] +dependencies = [ + "platoon >= 0.1.0", + "platoon-openhands >= 0.1.0", + "openhands-sdk", + "openhands-tools", + "openhands-workspace", + "openhands-agent-server" +] +[project.optional-dependencies] +# Training backends - install one of these for training +tinker = [ + "platoon[tinker]", +] +# NOTE: areal backend requires uv for installation (not available on PyPI) +areal = [ + "platoon[areal]", +] +# Logging integrations +wandb = [ + "platoon[wandb]", +] +# uv-specific configuration +[tool.uv] +no-build-isolation-package = ['flash-attn'] +# tinker and areal backends are mutually exclusive +conflicts = [ + [ + { extra = "tinker" }, + { extra = "areal" }, + ], +] +override-dependencies = [ + "fastapi[standard]>=0.115.0", + "openai==1.99.6", + "xgrammar==0.1.24", + "outlines-core==0.1.26", + "pyarrow==20.0.0", + "huggingface_hub==0.34", + "datasets==4.3.0", + "networkx==3.3.0" # This can be removed if ai-rubric pins 3.3.0 or areal relaxes the pin. +] +[tool.uv.sources] +platoon = { path = "../..", editable = true } +platoon-openhands = { path = "../openhands", editable = true } +openhands-sdk = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "737eae2b2384f04fe5853639294931284ab1f283", subdirectory = "openhands-sdk" } +openhands-tools = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "737eae2b2384f04fe5853639294931284ab1f283", subdirectory = "openhands-tools" } +openhands-workspace = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "737eae2b2384f04fe5853639294931284ab1f283", subdirectory = "openhands-workspace" } +openhands-agent-server = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "737eae2b2384f04fe5853639294931284ab1f283", subdirectory = "openhands-agent-server" } + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[tool.hatch.build.targets.wheel] +packages = ["platoon"] diff --git a/plugins/openhands/platoon/openhands/env.py b/plugins/openhands/platoon/openhands/env.py index 326b2aa..55a4123 100644 --- a/plugins/openhands/platoon/openhands/env.py +++ b/plugins/openhands/platoon/openhands/env.py @@ -3,13 +3,10 @@ import asyncio import threading from copy import deepcopy - -from openhands.sdk.agent.base import AgentBase -from openhands.sdk.conversation import get_agent_final_response -from openhands.sdk.conversation.base import BaseConversation -from openhands.sdk.conversation.conversation import Conversation -from openhands.sdk.conversation.state import ConversationExecutionStatus -from openhands.sdk.workspace.base import BaseWorkspace +import concurrent.futures +from openhands.sdk.agent import AgentBase +from openhands.sdk.conversation import get_agent_final_response, BaseConversation, Conversation, ConversationExecutionStatus, RemoteConversation +from openhands.sdk.workspace import BaseWorkspace from platoon.envs.base import Task from platoon.episode.context import ( current_trajectory, @@ -19,7 +16,7 @@ ) from platoon.utils.openhands_utils import get_obs_for_last_action, is_finished -from .types import OpenHandsAction, OpenHandsObservation, OpenHandsTrajectoryStep +from platoon.openhands.types import OpenHandsAction, OpenHandsObservation, OpenHandsTrajectoryStep class OpenHandsEnv: @@ -30,7 +27,8 @@ def __init__(self, task: Task, agent: AgentBase, workspace: str | BaseWorkspace) workspace = str(workspace) self._workspace = workspace self._conversation = None - + self._run_thread: threading.Thread | None = None + async def reset(self) -> OpenHandsObservation: self._conversation: BaseConversation = Conversation( agent=self._agent, @@ -38,10 +36,13 @@ async def reset(self) -> OpenHandsObservation: visualizer=None, max_iteration_per_run=self._task.max_steps, ) + if isinstance(self._conversation, RemoteConversation): + self._conversation.delete_on_close = True self._state = OpenHandsObservation(task=self._task, conversation_state=self._conversation.state) self._conversation.send_message(self._task.goal) # NOTE: Run the conversation in a separate thread to avoid blocking the main thread. - threading.Thread(target=self._conversation.run, daemon=True).start() + self._run_thread = threading.Thread(target=self._conversation.run, kwargs={'timeout': 300}, daemon=True) + self._run_thread.start() traj_collection = current_trajectory_collection.get() traj = current_trajectory.get() @@ -83,11 +84,17 @@ async def step(self, action: OpenHandsAction) -> OpenHandsObservation: if is_finished(self._state): self._state.finished = True - finish_message.set(get_agent_final_response(self._conversation.state.events)) + agent_final_msg: str | None = get_agent_final_response(self._conversation.state.events) + if agent_final_msg is None or agent_final_msg.strip() == "": + agent_final_msg = "No final response from agent." + finish_message.set(agent_final_msg) self._state.misc["finish_message"] = finish_message.get() - if self._state.conversation_state.agent_status == ConversationExecutionStatus.STUCK: + if self._state.conversation_state.execution_status == ConversationExecutionStatus.STUCK: error_message.set("Agent got stuck") self._state.misc["error_message"] = error_message.get() + elif self._state.conversation_state.execution_status == ConversationExecutionStatus.ERROR: + error_message.set("Agent encountered an error") + self._state.misc["error_message"] = error_message.get() traj_collection = current_trajectory_collection.get() traj = current_trajectory.get() @@ -98,8 +105,49 @@ async def step(self, action: OpenHandsAction) -> OpenHandsObservation: async def close(self) -> None: if self._conversation is not None: - self._conversation.close() - self._conversation = None + conversation = self._conversation + self._conversation = None + # Fire-and-forget: submit close() to a thread pool so the DELETE + # request completes even if this coroutine is cancelled by + # asyncio.wait_for() or CancelledError from the parent task. + # We use a standalone executor submit (not awaited) so cancellation + # of this coroutine cannot prevent the HTTP DELETE from being sent. + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = executor.submit(self._close_conversation_sync, conversation, self._workspace) + try: + # Give it a reasonable amount of time, but don't block forever + await asyncio.wait_for( + asyncio.wrap_future(future), + timeout=120 + ) + except (asyncio.TimeoutError, asyncio.CancelledError, Exception) as e: + # Even if we're cancelled or timed out, the thread-pool task + # will still finish in the background (the DELETE gets sent). + print(f"env.close() interrupted ({type(e).__name__}: {e}), " + f"cleanup thread will finish in background", flush=True) + finally: + # Don't call executor.shutdown(wait=True) which would block; + # let the daemon thread finish on its own. + executor.shutdown(wait=False) + # Wait briefly for the run-polling thread to notice the conversation + # was deleted and exit on its own. + if self._run_thread is not None: + self._run_thread.join(timeout=5) + if self._run_thread.is_alive(): + print("Warning: conversation run thread still alive after close()", flush=True) + self._run_thread = None + + @staticmethod + def _close_conversation_sync(conversation: BaseConversation, workspace=None) -> None: + """Synchronous helper that calls conversation.close() in a background thread. + This runs outside the asyncio event loop so it cannot be cancelled by + CancelledError. The DELETE request will always be sent.""" + try: + conversation.close() + if workspace is not None and not isinstance(workspace, str): + workspace.cleanup() + except Exception as e: + print(f"Error in background conversation.close(): {e}", flush=True) # TODO: Consider adding a return_copy option here. async def observe(self) -> OpenHandsObservation: