From bf60e9d3a4cadacc23abf51157cb35365fae4854 Mon Sep 17 00:00:00 2001 From: ashmitkx <66110457+ashmitkx@users.noreply.github.com> Date: Thu, 25 Jun 2026 06:54:57 +0000 Subject: [PATCH 1/2] add base vita files for overlay --- .../VitaBench/src/vita/agent/llm_agent.py | 225 ++++ .../VitaBench/src/vita/cli.py | 265 ++++ .../src/vita/data_model/simulation.py | 433 ++++++ .../VitaBench/src/vita/domains/ota/tools.py | 1161 +++++++++++++++++ .../src/vita/domains/ota/tools_schema.py | 780 +++++++++++ .../src/vita/evaluator/evaluator_traj.py | 737 +++++++++++ .../src/vita/orchestrator/orchestrator.py | 440 +++++++ .../src/vita/prompts/agent_system_prompt.yaml | 28 + .../prompts/solo_agent_system_prompt.yaml | 34 + .../VitaBench/src/vita/run.py | 920 +++++++++++++ .../VitaBench/src/vita/user/user_simulator.py | 271 ++++ .../VitaBench/src/vita/utils/utils.py | 240 ++++ 12 files changed, 5534 insertions(+) create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/agent/llm_agent.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/cli.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools_schema.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/evaluator/evaluator_traj.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/agent_system_prompt.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/run.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/user/user_simulator.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/utils/utils.py diff --git 
class LLMAgentState(BaseModel):
    """The state of the agent."""

    # System prompt message(s); placed ahead of `messages` in every LLM request.
    system_messages: list[SystemMessage]
    # Conversation history accumulated so far (inputs and assistant replies).
    messages: list[APICompatibleMessage]


class LLMAgent(LocalAgent[LLMAgentState]):
    """
    An LLM agent that can be used to solve a task.
    """

    def __init__(
        self,
        tools: List[Tool],
        domain_policy: str,
        llm: Optional[str] = None,
        llm_args: Optional[dict] = None,
        time=None,
        enable_think: bool = False,
        language: str = None,
    ):
        """
        Initialize the LLMAgent.

        Args:
            tools: Tools handed to ``generate`` on every turn.
            domain_policy: Prompt template; it is formatted with a ``time``
                keyword to build the system prompt.
            llm: Model name passed to ``generate``.
            llm_args: Extra keyword arguments for ``generate``. Deep-copied so
                that later changes (e.g. ``set_seed``) do not leak to the caller.
            time: Optional time string. When given, the result of
                ``get_weekday(time, language)`` is appended to it. When
                omitted, the system prompt falls back to ``get_now``.
            enable_think: Forwarded to ``generate`` as ``enable_think``.
            language: Forwarded to ``get_weekday``.
        """
        super().__init__(tools=tools, domain_policy=domain_policy)
        self.llm = llm
        self.llm_args = deepcopy(llm_args) if llm_args is not None else {}
        # `time` defaults to None; concatenating unconditionally raised
        # TypeError and made the get_now() fallback in `system_prompt` dead code.
        if time is None:
            self.time = None
        else:
            self.time = time + " " + get_weekday(time, language)
        self.enable_think = enable_think

    @property
    def system_prompt(self) -> str:
        """Return the domain policy with its ``time`` field filled in.

        Uses the time fixed at construction when there is one, otherwise the
        current time from ``get_now``.
        """
        if self.time is not None:
            current_time = self.time
        else:
            current_time = get_now("%Y-%m-%d %H:%M:%S")
        return self.domain_policy.format(time=current_time)

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> LLMAgentState:
        """Get the initial state of the agent.

        Args:
            message_history: The message history of the conversation.

        Returns:
            The initial state of the agent.
        """
        if message_history is None:
            message_history = []
        assert all(is_valid_agent_history_message(m) for m in message_history), (
            "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent."
        )

        # NOTE: the list is stored as-is (not copied), so messages appended to
        # the state are also visible through the caller's list.
        return LLMAgentState(
            system_messages=[SystemMessage(role="system", content=self.system_prompt)],
            messages=message_history,
        )

    def generate_next_message(
        self, message: ValidAgentInputMessage, state: LLMAgentState
    ) -> tuple[AssistantMessage, LLMAgentState]:
        """
        Respond to a user or tool message.

        The incoming message is appended to ``state.messages`` (a
        ``MultiToolMessage`` is expanded into its individual tool messages),
        the LLM is queried with system messages + history, and the reply is
        appended to the history as well.

        Returns:
            The assistant reply and the (mutated) state.
        """
        if isinstance(message, MultiToolMessage):
            state.messages.extend(message.tool_messages)
        else:
            state.messages.append(message)

        messages = state.system_messages + state.messages

        assistant_message = generate(
            model=self.llm,
            tools=self.tools,
            messages=messages,
            enable_think=self.enable_think,
            **self.llm_args,
        )
        state.messages.append(assistant_message)

        return assistant_message, state

    def set_seed(self, seed: int):
        """Set the seed for the LLM.

        Raises:
            ValueError: If no LLM is configured.
        """
        if self.llm is None:
            raise ValueError("LLM is not set")
        cur_seed = self.llm_args.get("seed", None)
        if cur_seed is not None:
            logger.warning(f"Seed is already set to {cur_seed}, resetting it to {seed}")
        self.llm_args["seed"] = seed
+ """ + super().__init__(tools=tools, domain_policy=domain_policy) + self.llm = llm + self.llm_args = deepcopy(llm_args) if llm_args is not None else {} + self.time = time + " " + get_weekday(time, language) + self.enable_think = enable_think + + @property + def system_prompt(self) -> str: + prompts = get_prompts() + if self.time is not None: + return prompts.solo_agent_system_prompt.format( + time=self.time + ) + return prompts.solo_agent_system_prompt.format( + time=get_now("%Y-%m-%d %H:%M:%S") + ) + + @classmethod + def is_stop(cls, message: AssistantMessage) -> bool: + """Check if the message is a stop message.""" + if message.content is None: + return False + return cls.STOP_TOKEN in message.content + + def get_init_state( + self, message_history: Optional[list[Message]] = None + ) -> LLMAgentState: + """Get the initial state of the agent. + + Args: + message_history: The message history of the conversation. + + Returns: + The initial state of the agent. + """ + if message_history is None: + message_history = [] + assert all(is_valid_agent_history_message(m) for m in message_history), ( + "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + ) + return LLMAgentState( + system_messages=[SystemMessage(role="system", content=self.system_prompt)], + messages=message_history, + ) + + def generate_next_message( + self, message: Optional[ValidAgentInputMessage], state: LLMAgentState + ) -> tuple[AssistantMessage, LLMAgentState]: + """ + Respond to a user or tool message. 
def parse_llm_args(value: str) -> dict:
    """Convert a ``--*-llm-args`` command-line value into a dict.

    argparse hands the raw string to the ``type`` callable. The previous
    ``type=dict`` therefore evaluated ``dict("<string>")``, which fails for any
    string, so the options could never be set from the command line.

    Args:
        value: A JSON object, e.g. ``'{"temperature": 0.0}'``.

    Returns:
        The decoded dict.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not valid JSON or does not
            decode to an object; argparse reports it as a usage error.
    """
    import json  # local import: json is not imported at module level here

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(
            f"expected a JSON object, got {value!r} ({err})"
        ) from err
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(
            f"expected a JSON object, got {type(parsed).__name__}"
        )
    return parsed


def add_run_args(parser):
    """Add run arguments to a parser.

    Args:
        parser: The ``argparse`` parser (or sub-parser) to extend in place.
    """
    # Query the registry once instead of once per `choices=` argument.
    options = get_options()

    parser.add_argument(
        "--domain",
        "-d",
        type=str,
        default="delivery,instore,ota",
        help="The domain to run the simulation on",
    )
    parser.add_argument(
        "--num-trials",
        type=int,
        default=DEFAULT_NUM_TRIALS,
        help=f"The number of times each task is run. Default is {DEFAULT_NUM_TRIALS}.",
    )
    parser.add_argument(
        "--agent",
        type=str,
        default=DEFAULT_AGENT_IMPLEMENTATION,
        choices=options.agents,
        help=f"The agent implementation to use. Default is {DEFAULT_AGENT_IMPLEMENTATION}.",
    )
    parser.add_argument(
        "--agent-llm",
        type=str,
        default=DEFAULT_LLM_AGENT,
        help=f"The LLM to use for the agent. Default is {DEFAULT_LLM_AGENT}.",
    )
    parser.add_argument(
        "--agent-llm-args",
        type=parse_llm_args,
        default={},
        help='The arguments to pass to the LLM for the agent, as a JSON object (e.g. \'{"temperature": 0.0}\').',
    )
    parser.add_argument(
        "--user",
        type=str,
        choices=options.users,
        default=DEFAULT_USER_IMPLEMENTATION,
        help=f"The user implementation to use. Default is {DEFAULT_USER_IMPLEMENTATION}.",
    )
    parser.add_argument(
        "--user-llm",
        type=str,
        default=DEFAULT_LLM_USER,
        help=f"The LLM to use for the user. Default is {DEFAULT_LLM_USER}.",
    )
    parser.add_argument(
        "--user-llm-args",
        type=parse_llm_args,
        default={},
        help='The arguments to pass to the LLM for the user, as a JSON object (e.g. \'{"temperature": 0.0}\').',
    )
    parser.add_argument(
        "--task-set-name",
        type=str,
        default=None,
        choices=options.task_sets,
        help="The task set to run the simulation on. If not provided, will load default task set for the domain.",
    )
    parser.add_argument(
        "--task-ids",
        type=str,
        nargs="+",
        help="(Optional) run only the tasks with the given IDs. If not provided, will run num_tasks tasks.",
    )
    parser.add_argument(
        "--num-tasks",
        type=int,
        default=None,
        help="The number of tasks to run.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=DEFAULT_MAX_STEPS,
        help=f"The maximum number of steps to run the simulation. Default is {DEFAULT_MAX_STEPS}.",
    )
    parser.add_argument(
        "--evaluation-type",
        type=str,
        default=DEFAULT_EVALUATION_TYPE,
        choices=get_args(EvaluationType),
        help="The type of evaluation to use. Choices: trajectory, trajectory_full_traj_rubric, trajectory_sliding_wo_rubric, trajectory_full_traj_wo_rubric.",
    )
    parser.add_argument(
        "--evaluator-llm",
        type=str,
        default=DEFAULT_LLM_EVALUATOR,
        help=f"The LLM to use for evaluation. Default is {DEFAULT_LLM_EVALUATOR}.",
    )
    parser.add_argument(
        "--evaluator-llm-args",
        type=parse_llm_args,
        default={},
        help='The arguments to pass to the LLM for evaluation, as a JSON object (e.g. \'{"temperature": 0.0}\').',
    )
    parser.add_argument(
        "--max-errors",
        type=int,
        default=DEFAULT_MAX_ERRORS,
        help=f"The maximum number of tool errors allowed in a row in the simulation. Default is {DEFAULT_MAX_ERRORS}.",
    )
    # NOTE(review): the file-name placeholders in this help text look like they
    # were lost (".json", "_____.json"); left untouched because the intended
    # naming scheme is not visible here -- confirm against run.py.
    parser.add_argument(
        "--save-to",
        type=str,
        required=False,
        help="The path to save the simulation results. Will be saved to data/simulations/.json. If not provided, will save to _____.json. If the file already exists, it will try to resume the run.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=DEFAULT_MAX_CONCURRENCY,
        help=f"The maximum number of concurrent simulations to run. Default is {DEFAULT_MAX_CONCURRENCY}.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help=f"The seed to use for the simulation. Default is {DEFAULT_SEED}.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default=DEFAULT_LOG_LEVEL,
        help=f"The log level to use for the simulation. Default is {DEFAULT_LOG_LEVEL}.",
    )
    parser.add_argument(
        "--re-evaluate-file",
        type=str,
        help="Path to simulation file for re-evaluation mode. If provided, will re-evaluate the simulations from this file instead of running new ones.",
    )
    parser.add_argument(
        "--csv-output",
        type=str,
        help="Path to CSV file to append results. If provided, will append all simulation results to this CSV file after completion.",
    )
    parser.add_argument(
        "--enable-think",
        action="store_true",
        help="Enable think mode for the agent. Default is False.",
    )
    parser.add_argument(
        "--language",
        type=str,
        choices=["chinese", "english"],
        default=DEFAULT_LANGUAGE,
        help=f"The language to use for prompts and tasks. Choices: chinese, english. Default is {DEFAULT_LANGUAGE}.",
    )
    parser.add_argument(
        "--re-run",
        action="store_true",
        help="Re-run tasks specified by --task-ids. If used with --re-evaluate-file, will re-run specified tasks and then re-evaluate all tasks together.",
    )
False) + ) + ) + ) + + # View command + view_parser = subparsers.add_parser("view", help="View simulation results") + view_parser.add_argument( + "--file", + type=str, + help="Path to the simulation results file to view", + ) + view_parser.add_argument( + "--only-show-failed", + action="store_true", + help="Only show failed tasks.", + ) + view_parser.add_argument( + "--only-show-all-failed", + action="store_true", + help="Only show tasks that failed in all trials.", + ) + view_parser.set_defaults(func=lambda args: run_view_simulations(args)) + + # Domain command + domain_parser = subparsers.add_parser("domain", help="Show domain documentation") + domain_parser.add_argument( + "domain", + type=str, + help="Name of the domain to show documentation for (e.g., 'ota', 'delivery', 'instore')", + ) + + args = parser.parse_args() + if not hasattr(args, "func"): + parser.print_help() + return + + args.func(args) + + +def run_view_simulations(args): + from vita.scripts.view_simulations import main as view_main + + view_main( + sim_file=args.file, + only_show_failed=args.only_show_failed, + only_show_all_failed=args.only_show_all_failed, + ) + +if __name__ == "__main__": + main() diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py b/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py new file mode 100644 index 0000000..b0a5587 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py @@ -0,0 +1,433 @@ +import json +from copy import deepcopy +from enum import Enum +from pathlib import Path +from typing import Optional, Any, Literal, Dict + +import pandas as pd +from pydantic import BaseModel, Field +from typing_extensions import Annotated + +from vita.config import ( + DEFAULT_LLM_AGENT, + DEFAULT_LLM_USER, + DEFAULT_LOG_LEVEL, + DEFAULT_MAX_CONCURRENCY, + DEFAULT_MAX_ERRORS, + DEFAULT_MAX_STEPS, + DEFAULT_EVALUATION_TYPE, + DEFAULT_NUM_TRIALS, + DEFAULT_SAVE_TO, + DEFAULT_SEED, + 
DEFAULT_LLM_EVALUATOR, + DEFAULT_LANGUAGE, + models +) +from vita.data_model.message import Message +from vita.data_model.tasks import RewardType, Task +from vita.environment.environment import EnvironmentInfo +from vita.utils.utils import get_now + + +EvaluationType = Literal[ + "trajectory", + "trajectory_full_traj_rubric", + "trajectory_sliding_wo_rubric", + "trajectory_full_traj_wo_rubric" +] + + +class RunConfig(BaseModel): + domain: Annotated[ + str, + Field( + description="The domain to run the simulation on", + default="ota", + ), + ] + task_set_name: Annotated[ + Optional[str], + Field( + description="The task set to run the simulation on. If not provided, will load default task set for the domain.", + default=None, + ), + ] + task_ids: Annotated[ + Optional[list[str]], + Field( + description="The task IDs to run the simulation on", + default=None, + ), + ] + num_tasks: Annotated[ + Optional[int], + Field( + description="The number of tasks to run the simulation on", + default=None, + ), + ] + is_remote: Annotated[ + bool, + Field( + description="Whether to run the simulation remotely", + default=False, + ), + ] + agent: Annotated[ + str, + Field( + description="The type of agent to run the simulation on", + default="llm_agent", + ), + ] + llm_agent: Annotated[ + str, + Field( + description="The model to use for the agent", + default=DEFAULT_LLM_AGENT, + ), + ] + llm_args_agent: Annotated[ + dict, + Field( + description="The arguments to pass to the LLM for the agent", + default_factory=lambda: deepcopy(models[DEFAULT_LLM_AGENT]), + ), + ] + user: Annotated[ + str, + Field( + description="The type of user to run the simulation on", + default="user_simulator", + ), + ] + llm_user: Annotated[ + str, + Field( + description="The model to use for the user", + default=DEFAULT_LLM_USER, + ), + ] + llm_args_user: Annotated[ + dict, + Field( + description="The arguments to pass to the LLM for the user", + default_factory=lambda: deepcopy(models[DEFAULT_LLM_USER]), 
+ ), + ] + num_trials: Annotated[ + int, + Field( + description="The number of trials to run the simulation on", + default=DEFAULT_NUM_TRIALS, + ), + ] + max_steps: Annotated[ + int, + Field( + description="The maximum number of steps to run the simulation", + default=DEFAULT_MAX_STEPS, + ), + ] + evaluation_type: Annotated[ + EvaluationType, + Field( + description="The type of evaluation to use. Choices: trajectory, trajectory_full_traj_rubric, trajectory_sliding_wo_rubric, trajectory_full_traj_wo_rubric.", + default=DEFAULT_EVALUATION_TYPE, + ), + ] + max_errors: Annotated[ + int, + Field( + description="The maximum number of tool errors allowed in a row in the simulation", + default=DEFAULT_MAX_ERRORS, + ), + ] + save_to: Annotated[ + Optional[str], + Field( + description="The path to json file where to save the simulation results", + default=DEFAULT_SAVE_TO, + ), + ] + max_concurrency: Annotated[ + int, + Field( + description="The maximum number of concurrent simulations to run", + default=DEFAULT_MAX_CONCURRENCY, + ), + ] + seed: Annotated[ + Optional[int], + Field( + description="The seed to use for the simulation", + default=DEFAULT_SEED, + ), + ] + log_level: Annotated[ + Optional[str], + Field( + description="The log level to use for the simulation", + default=DEFAULT_LOG_LEVEL, + ), + ] + re_evaluate_file: Annotated[ + Optional[str], + Field( + description="Path to simulation file for re-evaluation mode", + default=None, + ), + ] + csv_output_file: Annotated[ + Optional[str], + Field( + description="Path to csv output file for result analysis", + default=None, + ), + ] + enable_think: Annotated[ + bool, + Field( + description="Whether to enable think step for the agent", + default=False, + ), + ] + language: Annotated[ + str, + Field( + description="The language to use for prompts and tasks. 
Choices: chinese, english", + default=DEFAULT_LANGUAGE, + ), + ] + llm_evaluator: Annotated[ + str, + Field( + description="The LLM to use for evaluation", + default=DEFAULT_LLM_EVALUATOR, + ), + ] + llm_args_evaluator: Annotated[ + dict, + Field( + description="The arguments to pass to the LLM for evaluation", + default_factory=lambda: deepcopy(models[DEFAULT_LLM_EVALUATOR]), + ), + ] + re_run: Annotated[ + bool, + Field( + description="Whether to re-run tasks specified by task_ids. If used with re_evaluate_file, will re-run specified tasks and then re-evaluate all tasks together.", + default=False, + ), + ] + + def validate(self) -> None: + """ + Validate the run config + """ + # Validate re_run parameter usage + if self.re_run: + if not self.re_evaluate_file: + raise ValueError("--re-run can only be used with --re-evaluate-file") + if not self.task_ids: + raise ValueError("--re-run requires --task-ids to specify which tasks to re-run") + + +class NLRubricCheck(BaseModel): + """ + A natural language assertion. + """ + + nl_rubric: Optional[str] = None + met: bool + justification: str + +class RewardInfo(BaseModel): + """ + The reward received by the agent. + """ + + reward: Annotated[float, Field(description="The reward received by the agent.")] + nl_rubrics: Annotated[ + Optional[list[NLRubricCheck]], + Field(description="The natural language assertions.", default=None), + ] + reward_breakdown: Annotated[ + Optional[dict[RewardType, float]], + Field( + description="The breakdown of the reward.", + default=None, + ), + ] + info: Annotated[ + Optional[dict], + Field(description="Additional information about the reward.", default=None), + ] + window_evaluations: Annotated[ + Optional[list[dict]], + Field(description="Detailed evaluation information for each sliding window.", default=None), + ] + + +class AgentInfo(BaseModel): + """ + Agent information. 
+ """ + + implementation: str = Field(description="The type of agent.") + llm: Optional[str] = Field(description="The LLM used by the agent.", default=None) + llm_args: Optional[dict] = Field( + description="The arguments to pass to the LLM for the agent.", default=None + ) + + +class UserInfo(BaseModel): + """ + User information. + """ + + implementation: str = Field(description="The type of user.") + llm: Optional[str] = Field(description="The LLM used by the user.", default=None) + llm_args: Optional[dict] = Field( + description="The arguments to pass to the LLM for the user.", default=None + ) + global_simulation_guidelines: Optional[str] = Field( + description="The global simulation guidelines for the user.", default=None + ) + + +class Info(BaseModel): + """Information about the simulator.""" + + git_commit: str = Field(description="The git commit hash.") + num_trials: int = Field(description="The number of trials.") + max_steps: int = Field(description="The maximum number of steps.") + max_errors: int = Field(description="The maximum number of errors.") + user_info: UserInfo = Field(description="User information.") + agent_info: AgentInfo = Field(description="Agent information.") + environment_info: EnvironmentInfo = Field(description="Environment information.") + seed: Optional[int] = Field( + description="The seed used for the simulation.", default=None + ) + + +class TerminationReason(str, Enum): + USER_STOP = "user_stop" + AGENT_STOP = "agent_stop" + MAX_STEPS = "max_steps" + TOO_MANY_ERRORS = "too_many_errors" + INVALID_AGENT_MESSAGE = "invalid_agent_message" + + +class SimulationRun(BaseModel): + """ + Simulation run for the given task. 
class Results(BaseModel):
    """
    Run results
    """

    timestamp: Optional[str] = Field(
        description="The timestamp of the simulation.", default_factory=get_now
    )
    info: Info = Field(description="Information.")
    tasks: list[Task] = Field(description="The list of tasks.")
    simulations: list[SimulationRun] = Field(description="The list of simulations.")

    @classmethod
    def load(cls, path: Path) -> "Results":
        """
        Load results from a JSON file written by ``save``.
        """
        # Read with the same encoding `save` writes; relying on the platform
        # default breaks non-ASCII content on non-UTF-8 locales.
        with open(path, "r", encoding="utf-8") as f:
            return cls.model_validate_json(f.read())

    def save(self, path: Path) -> None:
        """
        Save the results to a file.
        """
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json(indent=4))

    def to_df(self) -> pd.DataFrame:
        """
        Convert a Results object to a pandas DataFrame.

        One row is produced per simulation. Run-level metadata from ``info``
        is repeated on every row under ``info_*`` columns. ``reward`` is
        ``None`` for simulations that carry no ``reward_info``.
        """

        def clean_value_for_dataframe(value):
            """Clean a value to ensure it's compatible with pandas DataFrame"""
            if value is None:
                return None
            try:
                # Try to serialize to check if it's compatible
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                # Not serializable: prefer a list view, else fall back to text.
                if hasattr(value, "tolist"):
                    try:
                        return value.tolist()
                    except Exception:
                        return str(value)
                return str(value)

        info = self.info
        rows = []
        for sim in self.simulations:
            # reward_info is Optional (default None); avoid AttributeError on
            # simulations that have no reward attached.
            reward = sim.reward_info.reward if sim.reward_info is not None else None
            rows.append(
                {
                    "simulation_id": sim.id,
                    "task_id": sim.task_id,
                    "trial": sim.trial,
                    "seed": sim.seed,
                    "reward": reward,
                    "agent_cost": sim.agent_cost,
                    "user_cost": sim.user_cost,
                    "termination_reason": sim.termination_reason,
                    "duration": sim.duration,
                    "num_messages": len(sim.messages),
                    "info_git_commit": info.git_commit,
                    "info_seed": info.seed,
                    "info_num_trials": info.num_trials,
                    "info_max_steps": info.max_steps,
                    "info_max_errors": info.max_errors,
                    "info_domain": info.environment_info.domain_name,
                    "info_user_implementation": info.user_info.implementation,
                    "info_user_llm": info.user_info.llm,
                    "info_user_llm_args": clean_value_for_dataframe(info.user_info.llm_args),
                    "info_agent_implementation": info.agent_info.implementation,
                    "info_agent_llm": info.agent_info.llm,
                    "info_agent_llm_args": clean_value_for_dataframe(info.agent_info.llm_args),
                }
            )
        return pd.DataFrame(rows)
def __init__(self, db: "OTADB") -> None:
    """Create the toolkit on top of the given OTA database."""
    super().__init__(db)


def _check_user(self, user_id: str) -> bool:
    """Check if the user is valid.

    Args:
        user_id: The user id

    Returns:
        bool: True if ``user_id`` equals the database's ``user_id``, False otherwise
    """
    return user_id == self.db.user_id


def _get_hotel_tags(self, hotel_id: str) -> str:
    """Get the hotel tags for a specific hotel.

    Args:
        hotel_id: The hotel id

    Returns:
        str: hotel_name + ',' + tags (tags themselves comma-separated);
        just the hotel name when the hotel has no tags.

    Raises:
        KeyError: If ``hotel_id`` is not in the database.
    """
    hotel = self.db.hotels[hotel_id]
    # The name used to be glued onto the first tag ("Hiltonpool,spa"); join it
    # with the same separator so the result matches the documented format.
    return ",".join([hotel.hotel_name, *hotel.tags])


def _get_attraction_tags(self, attraction_id: str) -> str:
    """Get the attraction tags for a specific attraction.

    Args:
        attraction_id: The attraction id

    Returns:
        str: attraction_name + ',' + description + ',' + location address

    Raises:
        KeyError: If ``attraction_id`` is not in the database.
    """
    attraction = self.db.attractions[attraction_id]
    return f"{attraction.attraction_name},{attraction.description},{attraction.location.address}"


def _add_ota_order(self, order: "Order") -> str:
    """Add an order to the database's ``orders`` mapping.

    Args:
        order: The order to add

    Returns:
        "done" if successful, "Order already exists" if an order with the
        same ``order_id`` is already stored (the stored order is kept).
    """
    if order.order_id in self.db.orders:
        return "Order already exists"

    self.db.orders[order.order_id] = order

    return "done"
def _get_ota_order(
    self, order_id: Optional[str] = None, scene: Optional[str] = None
) -> Union["Order", List["Order"]]:
    """Look up orders in the database.

    ``scene`` is checked first: when it is truthy, ``order_id`` is ignored.

    Args:
        order_id: The order id
        scene: The scene of the order (compared against ``order_type``)

    Returns:
        A list of orders whose ``order_type`` equals ``scene`` when ``scene``
        is given; otherwise the single order for ``order_id``; otherwise a
        list of every hotel, attraction, flight and train order.

    Raises:
        ValueError: If ``order_id`` is given and the order is not found.
    """
    stored = self.db.orders

    if scene:
        return [item for item in stored.values() if item.order_type == scene]

    if not order_id:
        return [
            item
            for item in stored.values()
            if item.order_type in ["hotel", "attraction", "flight", "train"]
        ]

    if order_id in stored:
        return stored[order_id]
    raise ValueError(f"Order {order_id} not found")


def _get_hotel(self, hotel_id: Optional[str] = None) -> Union["Hotel", List["Hotel"]]:
    """Fetch one hotel, or all hotels when no id is given.

    Args:
        hotel_id: The hotel id, such as '6086499569'.

    Returns:
        The matching hotel, or a list of every hotel if ``hotel_id`` is None.

    Raises:
        ValueError: If ``hotel_id`` is not in the database.
    """
    catalog = self.db.hotels
    if hotel_id is None:
        return list(catalog.values())
    if hotel_id in catalog:
        return catalog[hotel_id]
    raise ValueError(f"hotel {hotel_id} not found")


def _get_attraction(
    self, attraction_id: Optional[str] = None
) -> Union["Attraction", List["Attraction"]]:
    """Fetch one attraction, or all attractions when no id is given.

    Args:
        attraction_id: The attraction id.

    Returns:
        The matching attraction, or a list of every attraction if
        ``attraction_id`` is None.

    Raises:
        ValueError: If ``attraction_id`` is not in the database.
    """
    catalog = self.db.attractions
    if attraction_id is None:
        return list(catalog.values())
    if attraction_id in catalog:
        return catalog[attraction_id]
    raise ValueError(f"attraction {attraction_id} not found")
@is_tool(tool_type=ToolType.READ)
def hotel_search_recommend(self,
                           city_name: str,
                           key_words: Optional[List[str]] = None) -> str:
    """Recommend hotels in a city, optionally ranked by keywords.

    Hotels are kept when ``fuzzy_match(city_name, hotel.location.address)``
    is truthy. With keywords, the kept hotels are ranked by ``rerank``
    against each hotel's tag text and the first 50 results are returned.
    Without keywords (``None``, or a list whose joined text is blank) the
    first 50 city matches are returned in database order, so the optional
    parameter really is optional.

    Args:
        city_name: Non-empty city name matched against hotel addresses.
        key_words: Optional list of non-empty keyword strings; they are
            concatenated into one query string for ranking.

    Returns:
        Newline-joined ``str()`` of the selected hotels, a "No hotels found"
        message, or "Error searching hotels: ..." if anything raised while
        searching.

    Raises:
        AssertionError: If ``city_name`` is empty or not a string, or
            ``key_words`` is given but is not a list of non-empty strings.
    """
    assert city_name, "City name cannot be empty"
    assert isinstance(city_name, str), "City name must be a string"

    if key_words is not None:
        assert isinstance(key_words, list), "Key words must be a list"
        assert all(isinstance(kw, str) and kw.strip() for kw in key_words), "All key words must be non-empty strings"

    top_k = 50
    try:
        city_hotels = [
            hotel for hotel in self._get_hotel()
            if fuzzy_match(city_name, hotel.location.address)
        ]
        if not city_hotels:
            return "No hotels found matching the criteria."

        keywords_str = "".join(key_words or [])
        if not keywords_str.strip():
            # No keywords to rank by: previously this path always failed
            # with "Keywords cannot be empty" despite the optional default.
            return "\n".join(str(hotel) for hotel in city_hotels[:top_k])

        hotel_tag_dict = {
            hotel.hotel_id: self._get_hotel_tags(hotel.hotel_id)
            for hotel in city_hotels
        }
        # Each ranked entry is indexed with [0] to obtain the hotel id.
        id_candidates_sorted = rerank(keywords_str, hotel_tag_dict)
        selected_ids = [candidate[0] for candidate in id_candidates_sorted[:top_k]]
        if not selected_ids:
            return "No hotels found matching the keywords"

        return "\n".join(str(self._get_hotel(hotel_id)) for hotel_id in selected_ids)
    except Exception as e:
        return f"Error searching hotels: {e}"
+ + attraction_tag_dict = {} + for attraction in target_attractions: + attraction_tag_dict[attraction.attraction_id] = self._get_attraction_tags(attraction.attraction_id) + + keywords_str = "".join(key_words) + assert keywords_str and keywords_str.strip(), "Keywords cannot be empty" + id_candidates_sorted = rerank(keywords_str, attraction_tag_dict) + selected_ids = [ic[0] for ic in id_candidates_sorted[:top_k]] + + if not selected_ids: + return "No attractions found matching the keywords" + + selected_attractions = [str(self._get_attraction(attraction_id)) for attraction_id in selected_ids] + selected_attractions_repr = "\n".join(selected_attractions) + return selected_attractions_repr + except Exception as e: + return f"Error searching attractions: {e}" + + @is_tool(tool_type=ToolType.READ) + def flight_search_recommend(self, departure: str, destination: str) -> str: + assert departure, "Departure city cannot be empty" + assert destination, "Destination city cannot be empty" + assert isinstance(departure, str), "Departure city must be a string" + assert isinstance(destination, str), "Destination city must be a string" + + try: + target_flights = [] + for flight in self._get_flight(): + if not fuzzy_match(departure, flight.departure_city): + continue + if not fuzzy_match(destination, flight.arrival_city): + continue + target_flights.append(flight) + + if not target_flights: + return "No flights found matching the criteria. Please check if the departure and destination cities are correct." 
@is_tool(tool_type=ToolType.READ)
def train_ticket_search(self, departure: str, destination: str, date: str) -> str:
    """Search trains between two cities that have a product on a given date.

    A train qualifies when at least one of its products has ``date`` equal
    to the requested date and both ``fuzzy_match(departure,
    train.departure_city)`` and ``fuzzy_match(destination,
    train.arrival_city)`` are truthy. Each qualifying train is listed
    exactly once, however many of its products fall on that date.

    Args:
        departure: Non-empty departure city.
        destination: Non-empty destination city.
        date: Departure date; must satisfy ``check_date_format``
            (the assertion message names the format %Y-%m-%d).

    Returns:
        Newline-joined ``str()`` of the matching trains, or
        "No trains found matching the criteria".

    Raises:
        AssertionError: If an argument is empty or the date format is
            rejected.
    """
    assert departure, "Departure city cannot be empty"
    assert destination, "Destination city cannot be empty"
    assert date, "Departure date cannot be empty"
    assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d"

    target_trains = []
    for train in self._get_train():
        # The date lives on the train's products, not on the train itself.
        # Test it once per train: appending inside a per-product loop would
        # repeat the train for every product sharing that date.
        if not any(product.date == date for product in train.products):
            continue
        if not fuzzy_match(departure, train.departure_city):
            continue
        if not fuzzy_match(destination, train.arrival_city):
            continue
        target_trains.append(train)

    if not target_trains:
        return "No trains found matching the criteria"

    return "\n".join(str(train) for train in target_trains)
@is_tool(tool_type=ToolType.WRITE)
def create_attraction_order(self, attraction_id: str, ticket_id: str, user_id: str, date: str, quantity: int) -> str:
    """Create an unpaid attraction order and reserve its tickets.

    The ticket product is the first product of the attraction whose
    ``date`` and ``product_id`` match the request. Inventory is reduced
    only after the order has been stored, so a rejected order (for example
    a duplicate order id) leaves the ticket count untouched.

    Args:
        attraction_id: Id of the attraction to book.
        ticket_id: Product id of the ticket.
        user_id: Must pass ``self._check_user``.
        date: Visit date; must satisfy ``check_date_format``.
        quantity: Positive integer number of tickets.

    Returns:
        ``repr`` of the created order on success, otherwise a plain-text
        message explaining why no order was created.

    Raises:
        AssertionError: If an argument is empty or invalid, or the user id
            does not match.
    """
    assert attraction_id, "Attraction ID cannot be empty"
    assert ticket_id, "Ticket ID cannot be empty"
    assert user_id, "User ID cannot be empty"
    assert date, "Date cannot be empty"
    assert isinstance(quantity, int), "Quantity must be an integer"
    assert quantity > 0, "Booking quantity must be greater than 0"
    assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d"
    assert self._check_user(user_id), "User ID does not match"

    try:
        attraction = self._get_attraction(attraction_id)
    except ValueError as e:
        return f"Error: {e}"

    target_product = next(
        (
            product for product in attraction.products
            if product.date == date and product.product_id == ticket_id
        ),
        None,
    )
    if target_product is None:
        return f"The attraction {attraction_id} does not have ticket {ticket_id} on date {date}"

    if target_product.quantity < quantity:
        return f"Insufficient ticket inventory for the specified date {date}. Available: {target_product.quantity}, Requested: {quantity}"

    ordered_ticket = AttractionProduct(
        product_id=target_product.product_id,
        price=target_product.price,
        date=date,
        quantity=quantity,
        ticket_type=target_product.ticket_type
    )

    # One timestamp so create_time and update_time of a new order agree.
    now = self.get_now("%Y-%m-%d %H:%M:%S")
    order = Order(
        order_id=self.db.assign_order_id("attraction", user_id),
        order_type="attraction",
        user_id=user_id,
        store_id=attraction_id,
        total_price=ordered_ticket.price * ordered_ticket.quantity,
        update_time=now,
        create_time=now,
        status="unpaid",
        products=[ordered_ticket]
    )

    response = self._add_ota_order(order)
    if response != "done":
        # Nothing was reserved yet, so there is no inventory to give back.
        return f"Failed to create order: {response}"

    # Commit the reservation only once the order really exists.
    target_product.quantity = target_product.quantity - quantity
    return repr(order)
@is_tool(tool_type=ToolType.WRITE)
def create_train_order(self, train_id: str, seat_id: str, user_id: str, date: str, quantity: int) -> str:
    """Create an unpaid train order and reserve its seats.

    The seat product is the first product of the train whose ``date`` and
    ``product_id`` match the request. Inventory is reduced only after the
    order has been stored, so a rejected order (for example a duplicate
    order id) leaves the seat count untouched.

    Args:
        train_id: Id of the train to book.
        seat_id: Product id of the seat.
        user_id: Must pass ``self._check_user``.
        date: Departure date; must satisfy ``check_date_format``.
        quantity: Positive integer number of seats.

    Returns:
        ``repr`` of the created order on success, otherwise a plain-text
        message explaining why no order was created.

    Raises:
        AssertionError: If an argument is empty or invalid, or the user id
            does not match.
    """
    assert train_id, "Train ID cannot be empty"
    assert seat_id, "Seat ID cannot be empty"
    assert user_id, "User ID cannot be empty"
    assert date, "Date cannot be empty"
    assert isinstance(quantity, int), "Quantity must be an integer"
    assert quantity > 0, "Booking quantity must be greater than 0"
    assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d"
    assert self._check_user(user_id), "User ID does not match"

    try:
        train = self._get_train(train_id)
    except ValueError as e:
        return f"Error: {e}"

    target_product = next(
        (
            product for product in train.products
            if product.date == date and product.product_id == seat_id
        ),
        None,
    )
    if target_product is None:
        return f"The train {train_id} does not have seat {seat_id} on date {date}"

    if target_product.quantity < quantity:
        return f"Insufficient seat inventory for the specified date {date}. Available: {target_product.quantity}, Requested: {quantity}"

    ordered_seat = TrainProduct(
        product_id=target_product.product_id,
        price=target_product.price,
        date=date,
        quantity=quantity,
        seat_type=target_product.seat_type
    )

    # One timestamp so create_time and update_time of a new order agree.
    now = self.get_now("%Y-%m-%d %H:%M:%S")
    order = Order(
        order_id=self.db.assign_order_id("train", user_id),
        order_type="train",
        user_id=user_id,
        store_id=train_id,
        total_price=ordered_seat.price * ordered_seat.quantity,
        update_time=now,
        create_time=now,
        status="unpaid",
        products=[ordered_seat]
    )

    response = self._add_ota_order(order)
    if response != "done":
        # Nothing was reserved yet, so there is no inventory to give back.
        return f"Failed to create order: {response}"

    # Commit the reservation only once the order really exists.
    target_product.quantity = target_product.quantity - quantity
    return repr(order)
@is_tool(tool_type=ToolType.WRITE)
def pay_flight_order(self, order_id: str) -> str:
    """Pay an unpaid flight order.

    The order is fetched by id, must be of type "flight" and in status
    "unpaid"; it is then set to "paid", its ``update_time`` refreshed and
    written back through ``_modify_ota_order``.

    Args:
        order_id: Non-empty id of the order to pay.

    Returns:
        "Payment successful", or a message describing why the payment was
        rejected or failed.

    Raises:
        AssertionError: If ``order_id`` is empty.
    """
    assert order_id, "Order ID cannot be empty"

    try:
        flight_order = self._get_ota_order(order_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"

    rejection = None
    if flight_order.order_type != "flight":
        rejection = f"Order {order_id} is not a flight order"
    elif flight_order.status != "unpaid":
        rejection = f"Order status must be unpaid. Current status: {flight_order.status}"
    if rejection is not None:
        return rejection

    flight_order.status = "paid"
    flight_order.update_time = self.get_now("%Y-%m-%d %H:%M:%S")
    outcome = self._modify_ota_order(flight_order)
    return "Payment successful" if outcome == "done" else f"Payment failed: {outcome}"
@is_tool(tool_type=ToolType.READ)
def search_hotel_order(self, user_id: str, date: Optional[str] = None, status: Optional[OTAOrderStatus] = "paid") -> str:
    """List a user's hotel orders, optionally narrowed by date and status.

    Args:
        user_id: Non-empty user id; must pass ``self._check_user``.
        date: Optional date that must satisfy ``check_date_format``. When
            given, an order is kept only if one of its products has that
            date.
        status: Order status to keep (defaults to "paid"); a falsy value
            disables the status filter.

    Returns:
        Newline-joined ``str()`` of the matching orders, a
        "No hotel orders found ..." message, or
        "Error searching hotel orders: ..." if the search raised.

    Raises:
        AssertionError: If ``user_id`` is empty or does not match, or the
            date format is rejected.
    """
    assert user_id, "User ID cannot be empty"
    assert self._check_user(user_id), "User ID does not match"

    if date:
        assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d"

    def _keeps(candidate) -> bool:
        # Decide whether one order survives the user/status/date filters.
        owned = candidate.user_id == user_id
        status_mismatch = bool(status) and candidate.status != status
        if not owned or status_mismatch:
            return False
        if date:
            if not hasattr(candidate, 'products'):
                return False
            on_date = any(
                hasattr(item, 'date') and item.date == date
                for item in candidate.products
            )
            if not on_date:
                return False
        # Only truthy order objects are collected.
        return bool(candidate)

    try:
        matches = [entry for entry in self._get_ota_order(scene="hotel") if _keeps(entry)]
        if matches:
            return "\n".join(str(entry) for entry in matches)

        date_filter = f" on date {date}" if date else ""
        status_filter = f" with status {status}" if status else ""
        return f"No hotel orders found for user {user_id}{date_filter}{status_filter}"
    except Exception as e:
        return f"Error searching hotel orders: {e}"
try: + attraction_orders = [] + for order in self._get_ota_order(scene="attraction"): + order_selected = None + if order.user_id == user_id: + order_selected = order + if status and order.status != status: + order_selected = None + if date and order_selected is not None: + if not hasattr(order_selected, 'products'): + order_selected = None + else: + has_date_product = False + for product in order_selected.products: + if hasattr(product, 'date') and product.date == date: + has_date_product = True + break + if not has_date_product: + order_selected = None + if order_selected: + attraction_orders.append(order_selected) + + if not attraction_orders: + date_filter = f" on date {date}" if date else "" + status_filter = f" with status {status}" if status else "" + return f"No attraction orders found for user {user_id}{date_filter}{status_filter}" + + orders_repr = "\n".join([str(order) for order in attraction_orders]) + return orders_repr + except Exception as e: + return f"Error searching attraction orders: {e}" + + @is_tool(tool_type=ToolType.READ) + def search_flight_order(self, user_id: str, date: Optional[str] = None, + status: Optional[OTAOrderStatus] = "paid") -> str: + assert user_id, "User ID cannot be empty" + assert self._check_user(user_id), "User ID does not match" + + if date: + assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d" + + try: + flight_orders = [] + for order in self._get_ota_order(scene="flight"): + order_selected = None + if order.user_id == user_id: + order_selected = order + if status and order.status != status: + order_selected = None + if date and order_selected is not None: + if not hasattr(order_selected, 'products'): + order_selected = None + else: + has_date_product = False + for product in order_selected.products: + if hasattr(product, 'date') and product.date == date: + has_date_product = True + break + if not has_date_product: + order_selected = None + if order_selected: + 
@is_tool(tool_type=ToolType.READ)
def search_train_order(self, user_id: str, date: Optional[str] = None,
                       status: Optional[OTAOrderStatus] = "paid") -> str:
    """List a user's train orders, optionally narrowed by date and status.

    Args:
        user_id: Non-empty user id; must pass ``self._check_user``.
        date: Optional date that must satisfy ``check_date_format``. When
            given, an order is kept only if one of its products has that
            date.
        status: Order status to keep (defaults to "paid"); a falsy value
            disables the status filter.

    Returns:
        Newline-joined ``str()`` of the matching orders, a
        "No train orders found ..." message, or
        "Error searching train orders: ..." if the search raised.

    Raises:
        AssertionError: If ``user_id`` is empty or does not match, or the
            date format is rejected.
    """
    assert user_id, "User ID cannot be empty"
    assert self._check_user(user_id), "User ID does not match"

    if date:
        assert check_date_format(date), "Date format is incorrect, correct format is %Y-%m-%d"

    try:
        found = []
        for candidate in self._get_ota_order(scene="train"):
            belongs_to_user = candidate.user_id == user_id
            wrong_status = bool(status) and candidate.status != status
            if wrong_status or not belongs_to_user:
                continue

            if date:
                if not hasattr(candidate, 'products'):
                    continue
                # Keep the order only if some product is dated `date`.
                if not any(hasattr(item, 'date') and item.date == date
                           for item in candidate.products):
                    continue

            # Only truthy order objects are collected.
            if candidate:
                found.append(candidate)

        if found:
            return "\n".join(map(str, found))

        date_filter = f" on date {date}" if date else ""
        status_filter = f" with status {status}" if status else ""
        return f"No train orders found for user {user_id}{date_filter}{status_filter}"
    except Exception as e:
        return f"Error searching train orders: {e}"
@is_tool(tool_type=ToolType.READ)
def get_attraction_order_detail(self, order_id: str) -> str:
    """Return the ``repr`` of an attraction order.

    Args:
        order_id: Non-empty id of the order to show.

    Returns:
        ``repr`` of the order, "Error: ..." when the lookup raises
        ``ValueError``, or a message when the order is not an attraction
        order.

    Raises:
        AssertionError: If ``order_id`` is empty.
    """
    assert order_id, "Order ID cannot be empty"

    try:
        found = self._get_ota_order(order_id=order_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"
    else:
        if found.order_type == "attraction":
            return repr(found)
        return f"Order {order_id} is not an attraction order"
paid orders can be modified. Current status: {order.status}" + + if len(order.products) != 1: + return "Only single train ticket order modification is supported" + + old_product = order.products[0] + train_id = order.store_id + + try: + train = self._get_train(train_id) + except ValueError as e: + return f"Error: {e}" + + seat_type = old_product.get("seat_type") if isinstance(old_product, dict) else old_product.seat_type + quantity = old_product.get("quantity") if isinstance(old_product, dict) else old_product.quantity + + new_product = None + for product in train.products: + if product.date == new_date and product.seat_type == seat_type: + new_product = product + break + + if new_product is None: + return f"New date {new_date} does not have {seat_type} type seats" + + if new_product.quantity < quantity: + return f"Insufficient {seat_type} seat inventory for new date {new_date}. Available: {new_product.quantity}, Required: {quantity}" + + for product in train.products: + old_date = old_product.get("date") if isinstance(old_product, dict) else old_product.date + if product.date == old_date and product.seat_type == seat_type: + product.quantity += quantity + break + + new_product.quantity -= quantity + + old_price = old_product.get("price") if isinstance(old_product, dict) else old_product.price + old_total = old_price * quantity + new_total = new_product.price * quantity + diff = new_total - old_total + + if diff > 0: + order.status = "unpaid" + + order.products = [TrainProduct( + product_id=new_product.product_id, + price=new_product.price, + date=new_date, + seat_type=seat_type, + quantity=quantity + )] + order.total_price = new_total + order.update_time = self.get_now("%Y-%m-%d %H:%M:%S") + response = self._modify_ota_order(order) + if response == "done": + if diff > 0: + return f"Modification successful, need to pay additional amount: {diff}." + else: + return f"Modification successful, price difference: {diff}, refunded." 
@is_tool(tool_type=ToolType.WRITE)
def modify_flight_order(self, order_id: str, user_id: str, new_date: str) -> str:
    """Move a paid single-product flight order to another date.

    The replacement product is the first product of the same flight with
    the new date and the same seat type. Inventory for the old date is
    given back, inventory for the new date is taken, and the order's
    product, total price and update time are rewritten. If the new total is
    higher, the order status is set back to "unpaid".

    Args:
        order_id: Non-empty id of the order to change.
        user_id: Must pass ``self._check_user`` and own the order.
        new_date: New departure date; must satisfy ``check_date_format``.

    Returns:
        A success message containing the price difference, or a message
        describing why the order was not modified.

    Raises:
        AssertionError: If an argument is empty or invalid, or the user id
            does not match.
    """
    assert order_id, "Order ID cannot be empty"
    assert user_id, "User ID cannot be empty"
    assert new_date, "New departure date cannot be empty"
    assert check_date_format(new_date), "Date format is incorrect, correct format is %Y-%m-%d"
    assert self._check_user(user_id), "User ID does not match"

    try:
        booking = self._get_ota_order(order_id=order_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"

    if booking.order_type != "flight":
        return f"Order {order_id} is not a flight order"
    if booking.user_id != user_id:
        return f"Order {order_id} does not belong to user {user_id}"
    if booking.status != "paid":
        return f"Only paid orders can be modified. Current status: {booking.status}"
    if len(booking.products) != 1:
        return "Only single flight ticket order modification is supported"

    previous = booking.products[0]

    def _previous_field(field_name):
        # The stored product may be a plain dict or a product object.
        if isinstance(previous, dict):
            return previous.get(field_name)
        return getattr(previous, field_name)

    try:
        flight = self._get_flight(booking.store_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"

    seat_type = _previous_field("seat_type")
    quantity = _previous_field("quantity")

    replacement = next(
        (
            offer for offer in flight.products
            if offer.date == new_date and offer.seat_type == seat_type
        ),
        None,
    )
    if replacement is None:
        return f"New date {new_date} does not have {seat_type} type seats"
    if replacement.quantity < quantity:
        return f"Insufficient {seat_type} seat inventory for new date {new_date}. Available: {replacement.quantity}, Required: {quantity}"

    # Give the seats of the old date back to the first matching product.
    old_date = _previous_field("date")
    for offer in flight.products:
        if offer.date == old_date and offer.seat_type == seat_type:
            offer.quantity += quantity
            break

    replacement.quantity -= quantity

    new_total = replacement.price * quantity
    diff = new_total - _previous_field("price") * quantity
    needs_payment = diff > 0

    if needs_payment:
        # A more expensive date has to be paid again.
        booking.status = "unpaid"

    booking.products = [FlightProduct(
        product_id=replacement.product_id,
        price=replacement.price,
        date=new_date,
        seat_type=seat_type,
        quantity=quantity
    )]
    booking.total_price = new_total
    booking.update_time = self.get_now("%Y-%m-%d %H:%M:%S")

    outcome = self._modify_ota_order(booking)
    if outcome != "done":
        return f"Modification failed: {outcome}"
    if needs_payment:
        return f"Modification successful, need to pay additional amount: {diff}, please pay as soon as possible"
    return f"Modification successful, price difference: {diff}, refunded"
@is_tool(tool_type=ToolType.WRITE)
def cancel_attraction_order(self, order_id: str, user_id: str) -> str:
    """Cancel an attraction order owned by the user.

    The order status becomes "cancelled" and its ``update_time`` is
    refreshed. The reported refund is the order's ``total_price`` when the
    order was "paid", otherwise 0. Product inventory is not modified here.

    Args:
        order_id: Non-empty id of the order to cancel.
        user_id: Must pass ``self._check_user`` and own the order.

    Returns:
        A success message with the refund amount, or a message describing
        why the order was not cancelled.

    Raises:
        AssertionError: If an argument is empty or the user id does not
            match.
    """
    assert order_id, "Order ID cannot be empty"
    assert user_id, "User ID cannot be empty"
    assert self._check_user(user_id), "User ID does not match"

    try:
        booking = self._get_ota_order(order_id=order_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"

    if booking.order_type != "attraction":
        return f"Order {order_id} is not an attraction order"
    if booking.user_id != user_id:
        return f"Order {order_id} does not belong to user {user_id}"
    if booking.status == "cancelled":
        return f"Order {order_id} is already cancelled"

    # Only money that was actually paid is reported as refunded.
    refund = booking.total_price if booking.status == "paid" else 0

    booking.status = "cancelled"
    booking.update_time = self.get_now("%Y-%m-%d %H:%M:%S")
    outcome = self._modify_ota_order(booking)
    if outcome != "done":
        return f"Cancellation failed: {outcome}"
    return f"Cancellation successful, refund amount: {refund}"
@is_tool(tool_type=ToolType.WRITE)
def cancel_train_order(self, order_id: str, user_id: str) -> str:
    """Cancel a train order owned by the user.

    The order status becomes "cancelled" and its ``update_time`` is
    refreshed. The reported refund is the order's ``total_price`` when the
    order was "paid", otherwise 0. Product inventory is not modified here.

    Args:
        order_id: Non-empty id of the order to cancel.
        user_id: Must pass ``self._check_user`` and own the order.

    Returns:
        A success message with the refund amount, or a message describing
        why the order was not cancelled.

    Raises:
        AssertionError: If an argument is empty or the user id does not
            match.
    """
    assert order_id, "Order ID cannot be empty"
    assert user_id, "User ID cannot be empty"
    assert self._check_user(user_id), "User ID does not match"

    try:
        ticket_order = self._get_ota_order(order_id=order_id)
    except ValueError as lookup_error:
        return f"Error: {lookup_error}"

    # Checked in order; the first failing rule decides the reply.
    rejections = (
        (ticket_order.order_type != "train", f"Order {order_id} is not a train order"),
        (ticket_order.user_id != user_id, f"Order {order_id} does not belong to user {user_id}"),
        (ticket_order.status == "cancelled", f"Order {order_id} is already cancelled"),
    )
    for violated, reply in rejections:
        if violated:
            return reply

    was_paid = ticket_order.status == "paid"
    refund = ticket_order.total_price if was_paid else 0

    ticket_order.status = "cancelled"
    ticket_order.update_time = self.get_now("%Y-%m-%d %H:%M:%S")
    outcome = self._modify_ota_order(ticket_order)
    if outcome == "done":
        return f"Cancellation successful, refund amount: {refund}"
    return f"Cancellation failed: {outcome}"
from typing import Dict, Any
from vita.utils.schema_utils import create_tool_schema_manager

# Both tables below share one layout: tool name -> entry with the keys
# "description", "preconditions", "postconditions", "args" (argument name ->
# description), "returns" and "tool_type" ("READ" or "WRITE").
# The two tables are expected to carry the same tool names and argument keys;
# only the human-readable text differs by language.

# Tool descriptions extracted from tools.py - Chinese version
TOOL_DESCRIPTIONS_ZH: Dict[str, Dict[str, Any]] = {
    "get_ota_hotel_info": {
        "description": "获取酒店信息,包含酒店id、名称、评分、星级、地址、标签、房间列表",
        "preconditions": "在酒店查询预订场景,需要获取酒店的详细信息",
        "postconditions": "返回酒店的详细信息,引导用户选择房间并下单",
        "args": {
            "hotel_id": "酒店id"
        },
        "returns": "酒店信息",
        "tool_type": "READ"
    },

    "get_ota_attraction_info": {
        "description": "获取景点信息,包含景点id、名称、地址、描述、评分、开放时间、门票价格、票种列表",
        "preconditions": "在景点旅游场景,需要获取景点的详细信息",
        "postconditions": "返回景点的详细信息,引导用户选择门票并下单",
        "args": {
            "attraction_id": "景点id"
        },
        "returns": "景点信息",
        "tool_type": "READ"
    },

    "get_ota_flight_info": {
        "description": "获取航班信息,包含航班id、航班号、出发城市、到达城市、出发机场位置、到达机场位置、出发时间、到达时间、航班标签、座位类型列表",
        "preconditions": "在机票查询购买场景,需要获取航班的详细信息",
        "postconditions": "返回航班的详细信息,引导用户选择座位并下单",
        "args": {
            "flight_id": "航班id"
        },
        "returns": "航班信息",
        "tool_type": "READ"
    },

    "get_ota_train_info": {
        "description": "获取火车信息,包含火车id、车次、出发城市、到达城市、出发车站位置、到达车站位置、出发时间、到达时间、火车标签、座位类型列表",
        "preconditions": "在火车票查询购买场景,需要获取火车的详细信息",
        "postconditions": "返回火车的详细信息,引导用户选择座位并下单",
        "args": {
            "train_id": "火车id"
        },
        "returns": "火车信息",
        "tool_type": "READ"
    },

    "hotel_search_recommend": {
        "description": "酒旅场景下,基于用户的地点需求和偏好,推荐合适的酒店选项,提供酒店的基础信息,包含酒店id、名称、评分、星级、地址、标签",
        "preconditions": "用户请求预定酒店,给出了酒店相关的关键词或地点",
        "postconditions": "返回符合条件的酒店列表,如需查看酒店详情(房间列表、价格等)需要使用酒店详情查询工具,引导用户选择酒店",
        "args": {
            "city_name": "城市名称",
            "key_words": "搜索关键词(匹配酒店名称、酒店介绍等)"
        },
        "returns": "结构化输出酒店基础信息",
        "tool_type": "READ"
    },

    "attractions_search_recommend": {
        "description": "基于用户的地点需求和偏好,推荐合适的景点选项,提供景点的基础信息,包含景点id、名称、地址、描述、评分、开放时间",
        "preconditions": "用户请求预定景点,给出了景点相关的关键词或地点",
        "postconditions": "返回符合条件的景点列表,如需查看景点详情(门票列表、价格等)需要使用景点详情查询工具,引导用户选择景点",
        "args": {
            "city_name": "城市名称",
            "key_words": "搜索关键词(匹配景点名称、位置、地址、特色等)"
        },
        "returns": "结构化输出景点基础信息",
        "tool_type": "READ"
    },

    "flight_search_recommend": {
        "description": "基于用户的地点需求和偏好,推荐合适的航班选项,提供航班的基础信息,包含航班id、航班号、出发城市、到达城市、出发机场位置、到达机场位置、出发时间、到达时间、航班标签",
        "preconditions": "用户请求预定航班,给出了航班相关的关键词或地点",
        "postconditions": "返回符合条件的航班列表,如需查看航班详情(座位类型列表、价格、日期等)需要使用航班详情查询工具,引导用户选择航班",
        "args": {
            "departure": "出发城市",
            "destination": "目的城市"
        },
        "returns": "结构化输出航班基础信息",
        "tool_type": "READ"
    },

    "train_ticket_search": {
        "description": "基于用户的地点需求和偏好,推荐合适的火车选项,提供火车票的基础信息,包含火车id、车次、出发城市、到达城市、出发车站位置、到达车站位置、出发时间、到达时间、火车标签",
        "preconditions": "用户请求预定火车,给出了火车相关的关键词或地点",
        "postconditions": "返回符合条件的火车票列表,如需查看火车票详情(座位类型列表、价格、日期等)需要使用火车票详情查询工具,引导用户选择火车票",
        "args": {
            "departure": "出发城市",
            "destination": "目的城市",
            "date": "出发日期"
        },
        "returns": "结构化输出火车基础信息",
        "tool_type": "READ"
    },

    "create_hotel_order": {
        "description": "用户预订酒店时,系统根据用户的需求(如酒店名称、入住日期、人数等)生成订单",
        "preconditions": "用户已登录并提供有效的身份标识(user_id),用户提供了有效的酒店名称(hotel_name)和房间类型(room_type),系统有关于目标酒店的信息,并且该酒店在所请求的日期内有房间",
        "postconditions": "生成订单,请用户确认支付",
        # NOTE(review): the English table names the second argument "product_id"
        # while this table uses "room_id" — confirm which one matches the
        # create_hotel_order tool signature and align the other.
        "args": {
            "hotel_id": "酒店ID",
            "room_id": "房间ID",
            "user_id": "用户ID"
        },
        "returns": "创建订单操作的反馈输出",
        "tool_type": "WRITE"
    },

    "create_attraction_order": {
        "description": "用户根据景点和日期购买门票,系统返回门票的相关信息并进行下单",
        "preconditions": "在景点旅游场景,用户请求预定景点,给出了预定相关必要信息",
        "postconditions": "生成订单,请用户确认支付",
        "args": {
            "attraction_id": "景点ID",
            "ticket_id": "门票ID",
            "user_id": "用户ID",
            "date": "参观日期,格式为 %Y-%m-%d",
            "quantity": "数量"
        },
        "returns": "创建订单操作的反馈输出",
        "tool_type": "WRITE"
    },

    "create_flight_order": {
        "description": "用户根据航班号、日期、座位类型、数量购买机票,系统返回机票的相关信息并进行下单",
        "preconditions": "在机票查询购买场景,用户请求预定航班,给出了预定相关必要信息",
        "postconditions": "生成订单,请用户确认支付",
        "args": {
            "flight_id": "航班ID",
            "seat_id": "座位ID",
            "user_id": "用户ID",
            "date": "出发日期",
            "quantity": "数量"
        },
        "returns": "创建订单操作的反馈输出",
        "tool_type": "WRITE"
    },

    "create_train_order": {
        "description": "用户根据车次、日期、座位类型、数量购买火车票,系统返回火车票的相关信息并进行下单",
        "preconditions": "在火车票查询购买场景,用户请求预定火车,给出了预定相关必要信息",
        "postconditions": "生成订单,请用户确认支付",
        "args": {
            "train_id": "火车ID",
            "seat_id": "座位ID",
            "user_id": "用户ID",
            "date": "出发日期",
            "quantity": "数量"
        },
        "returns": "创建订单操作的反馈输出",
        "tool_type": "WRITE"
    },

    "pay_hotel_order": {
        "description": "用户进行酒店订单支付",
        "preconditions": "在酒店查询预订场景,用户请求支付酒店订单,上文确定了订单ID",
        "postconditions": "确认支付并更新订单状态为已支付",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "支付结果信息",
        "tool_type": "WRITE"
    },

    "pay_attraction_order": {
        "description": "用户进行门票订单支付",
        "preconditions": "在景点旅游场景,用户请求支付景点门票订单,上文确定了订单ID",
        "postconditions": "确认支付并更新订单状态为已支付",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "支付结果信息",
        "tool_type": "WRITE"
    },

    "pay_flight_order": {
        "description": "用户进行机票订单支付",
        "preconditions": "在机票查询购买场景,用户请求支付航班订单,上文确定了订单ID",
        "postconditions": "确认支付并更新订单状态为已支付",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "支付结果信息",
        "tool_type": "WRITE"
    },

    "pay_train_order": {
        "description": "用户进行火车票订单支付",
        "preconditions": "在火车票查询购买场景,用户请求支付火车票订单,上文确定了订单ID",
        "postconditions": "确认支付并更新订单状态为已支付",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "支付结果信息",
        "tool_type": "WRITE"
    },

    "search_hotel_order": {
        "description": "根据用户ID,查询用户的酒店订单,返回包含订单ID、订单类型、用户ID、酒店ID、订单总价、下单时间、更新时间和订单状态",
        "preconditions": "用户需求为查询酒店订单",
        "postconditions": "返回订单信息,方便之后进行修改/取消",
        "args": {
            "user_id": "用户ID",
            "date": "日期",
            "status": "订单状态"
        },
        "returns": "指定用户的酒店订单信息",
        "tool_type": "READ"
    },

    "search_attraction_order": {
        "description": "根据用户ID,查询用户的景点门票订单,返回包含订单ID、订单类型、用户ID、景点ID、订单总价、下单时间、更新时间和订单状态",
        "preconditions": "用户需求为查询景点门票订单",
        "postconditions": "返回订单信息,方便之后进行修改/取消",
        "args": {
            "user_id": "用户ID",
            "date": "日期",
            "status": "订单状态"
        },
        "returns": "指定用户的景点门票订单信息",
        "tool_type": "READ"
    },

    "search_flight_order": {
        "description": "根据用户ID,查询用户的机票订单,返回包含订单ID、订单类型、用户ID、航班ID、订单总价、下单时间、更新时间和订单状态",
        "preconditions": "用户需求为查询机票订单",
        "postconditions": "返回订单信息,方便之后进行修改/取消",
        "args": {
            "user_id": "用户ID",
            "date": "日期",
            "status": "订单状态"
        },
        "returns": "指定用户的机票订单信息",
        "tool_type": "READ"
    },

    "search_train_order": {
        "description": "根据用户ID,查询用户的火车票订单,返回包含订单ID、订单类型、用户ID、火车ID、订单总价、下单时间、更新时间和订单状态",
        "preconditions": "用户需求为查询火车票订单",
        "postconditions": "返回订单信息,方便之后进行修改/取消",
        "args": {
            "user_id": "用户ID",
            "date": "日期",
            "status": "订单状态"
        },
        "returns": "指定用户的火车票订单信息",
        "tool_type": "READ"
    },

    "get_hotel_order_detail": {
        "description": "获取酒店订单详情,包含订单ID、订单类型、用户ID、酒店ID、订单总价、下单时间、更新时间、订单状态和订单房间详细信息(房间类型、入住日期、价格、房间ID)",
        "preconditions": "用户请求获取酒店订单详情,上文确定了订单ID",
        "postconditions": "返回订单详情",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "订单详情",
        "tool_type": "READ"
    },

    "get_attraction_order_detail": {
        "description": "获取景点门票订单详情,包含订单ID、订单类型、用户ID、景点ID、订单总价、下单时间、更新时间、订单状态和订单景点详细信息(门票类型、日期、价格、门票ID)",
        "preconditions": "用户请求获取景点门票订单详情,上文确定了订单ID",
        "postconditions": "返回订单详情",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "订单详情",
        "tool_type": "READ"
    },

    "get_flight_order_detail": {
        "description": "获取机票订单详情,包含订单ID、订单类型、用户ID、航班ID、订单总价、下单时间、更新时间、订单状态和订单机票详细信息(座位类型、日期、价格、座位ID)",
        "preconditions": "用户请求获取机票订单详情,上文确定了订单ID",
        "postconditions": "返回订单详情",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "订单详情",
        "tool_type": "READ"
    },

    "get_train_order_detail": {
        "description": "获取火车票订单详情,包含订单ID、订单类型、用户ID、火车ID、订单总价、下单时间、更新时间、订单状态和订单火车票详细信息(座位类型、日期、价格、座位ID)",
        "preconditions": "用户请求获取火车票订单详情,上文确定了订单ID",
        "postconditions": "返回订单详情",
        "args": {
            "order_id": "订单ID"
        },
        "returns": "订单详情",
        "tool_type": "READ"
    },

    "modify_train_order": {
        "description": "修改火车票订单,支持更改出发日期,自动处理补差价或退差价。",
        "preconditions": "在火车票查询购买场景,用户请求修改火车票订单,上文确定了订单ID",
        "postconditions": "修改订单并更新订单状态,若需补差价,订单状态改为unpaid,否则保持原状态,如需补差价,需引导用户支付当笔订单",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID",
            "new_date": "新的出发日期,格式为 %Y-%m-%d"
        },
        "returns": "(修改后的订单内容, 差价,正为需补差价,负为退差价)",
        "tool_type": "WRITE"
    },

    "modify_flight_order": {
        "description": "修改机票订单,支持更改出发日期,自动处理补差价或退差价。",
        "preconditions": "在机票查询购买场景,用户请求修改机票订单,上文确定了订单ID",
        "postconditions": "修改订单并更新订单状态,若需补差价,订单状态改为unpaid,否则保持原状态,如需补差价,需引导用户支付当笔订单",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID",
            "new_date": "新的出发日期,格式为 %Y-%m-%d"
        },
        "returns": "(修改后的订单内容, 差价,正为需补差价,负为退差价)",
        "tool_type": "WRITE"
    },

    "cancel_hotel_order": {
        "description": "用户取消已预订的酒店订单",
        "preconditions": "在酒店查询预订场景,用户请求取消酒店订单,上文确定了订单ID",
        "postconditions": "取消订单并更新订单状态,若需退差价,告知用户即可",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID"
        },
        "returns": "取消订单的退款金额",
        "tool_type": "WRITE"
    },

    "cancel_attraction_order": {
        "description": "用户取消已预订的景点门票订单",
        "preconditions": "历史对话中有订单id或者已经进行过订单查询,用户有权限取消该订单",
        "postconditions": "如果退差价,告知用户即可",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID"
        },
        "returns": "取消订单的退款金额",
        "tool_type": "WRITE"
    },

    "cancel_flight_order": {
        "description": "用户取消已预订的机票订单",
        "preconditions": "历史对话中有订单id或者已经进行过订单查询,用户有权限取消该订单",
        "postconditions": "如果退差价,告知用户即可",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID"
        },
        "returns": "取消订单的退款金额",
        "tool_type": "WRITE"
    },

    "cancel_train_order": {
        "description": "用户取消已预订的火车票订单",
        "preconditions": "历史对话中有订单id或者已经进行过订单查询,用户有权限取消该订单",
        "postconditions": "如果退差价,告知用户即可",
        "args": {
            "order_id": "订单ID",
            "user_id": "用户ID"
        },
        "returns": "取消订单的退款金额",
        "tool_type": "WRITE"
    }
}

# Tool descriptions extracted from tools.py - English version
TOOL_DESCRIPTIONS_EN: Dict[str, Dict[str, Any]] = {
    "get_ota_hotel_info": {
        "description": "Get hotel information including hotel id, name, rating, star level, address, tags, and room list",
        "preconditions": "In hotel query and booking scenario, need to get detailed hotel information",
        "postconditions": "Return detailed hotel information, guide user to select room and place order",
        "args": {
            "hotel_id": "Hotel id"
        },
        "returns": "Hotel information",
        "tool_type": "READ"
    },

    "get_ota_attraction_info": {
        "description": "Get attraction information including attraction id, name, address, description, rating, opening hours, ticket prices, and ticket type list",
        "preconditions": "In attraction travel scenario, need to get detailed attraction information",
        "postconditions": "Return detailed attraction information, guide user to select tickets and place order",
        "args": {
            "attraction_id": "Attraction id"
        },
        "returns": "Attraction information",
        "tool_type": "READ"
    },

    "get_ota_flight_info": {
        "description": "Get flight information including flight id, flight number, departure city, arrival city, departure airport location, arrival airport location, departure time, arrival time, flight tags, and seat type list",
        "preconditions": "In flight ticket query and purchase scenario, need to get detailed flight information",
        "postconditions": "Return detailed flight information, guide user to select seats and place order",
        "args": {
            "flight_id": "Flight id"
        },
        "returns": "Flight information",
        "tool_type": "READ"
    },

    "get_ota_train_info": {
        "description": "Get train information including train id, train number, departure city, arrival city, departure station location, arrival station location, departure time, arrival time, train tags, and seat type list",
        "preconditions": "In train ticket query and purchase scenario, need to get detailed train information",
        "postconditions": "Return detailed train information, guide user to select seats and place order",
        "args": {
            "train_id": "Train id"
        },
        "returns": "Train information",
        "tool_type": "READ"
    },

    "hotel_search_recommend": {
        "description": "In hotel query and booking scenario, recommend suitable hotel options based on user location needs and preferences, provide basic hotel information including hotel id, name, rating, star level, address, and tags",
        "preconditions": "User requests hotel booking, provides hotel-related keywords or location",
        "postconditions": "Return list of hotels meeting criteria, if hotel details (room list, prices, etc.) are needed, use hotel detail query tool, guide user to select hotel",
        "args": {
            "city_name": "City name",
            "key_words": "Search keywords (matching hotel name, hotel introduction, etc.)"
        },
        "returns": "Structured output of basic hotel information",
        "tool_type": "READ"
    },

    "attractions_search_recommend": {
        "description": "Recommend suitable attraction options based on user location needs and preferences, provide basic attraction information including attraction id, name, address, description, rating, and opening hours",
        "preconditions": "User requests attraction booking, provides attraction-related keywords or location",
        "postconditions": "Return list of attractions meeting criteria, if attraction details (ticket list, prices, etc.) are needed, use attraction detail query tool, guide user to select attraction",
        "args": {
            "city_name": "City name",
            "key_words": "Search keywords (matching attraction name, location, address, features, etc.)"
        },
        "returns": "Structured output of basic attraction information",
        "tool_type": "READ"
    },

    "flight_search_recommend": {
        "description": "Recommend suitable flight options based on user location needs and preferences, provide basic flight information including flight id, flight number, departure city, arrival city, departure airport location, arrival airport location, departure time, arrival time, and flight tags",
        "preconditions": "User requests flight booking, provides flight-related keywords or location",
        "postconditions": "Return list of flights meeting criteria, if flight details (seat type list, prices, dates, etc.) are needed, use flight detail query tool, guide user to select flight",
        "args": {
            "departure": "Departure city",
            "destination": "Destination city"
        },
        "returns": "Structured output of basic flight information",
        "tool_type": "READ"
    },

    "train_ticket_search": {
        "description": "Recommend suitable train options based on user location needs and preferences, provide basic train ticket information including train id, train number, departure city, arrival city, departure station location, arrival station location, departure time, arrival time, and train tags",
        "preconditions": "User requests train booking, provides train-related keywords or location",
        "postconditions": "Return list of train tickets meeting criteria, if train ticket details (seat type list, prices, dates, etc.) are needed, use train ticket detail query tool, guide user to select train ticket",
        "args": {
            "departure": "Departure city",
            "destination": "Destination city",
            "date": "Departure date"
        },
        "returns": "Structured output of basic train information",
        "tool_type": "READ"
    },

    "create_hotel_order": {
        "description": "When user books hotel, system generates order based on user requirements (such as hotel name, check-in date, number of people, etc.)",
        "preconditions": "User is logged in and provides valid identity (user_id), user provides valid hotel name (hotel_name) and room type (room_type), system has information about target hotel, and hotel has rooms available on requested dates",
        "postconditions": "Generate order, ask user to confirm payment",
        # NOTE(review): the Chinese table names the second argument "room_id"
        # while this table uses "product_id" — confirm which one matches the
        # create_hotel_order tool signature and align the other.
        "args": {
            "hotel_id": "Hotel ID",
            "product_id": "Room ID",
            "user_id": "User ID"
        },
        "returns": "Feedback output of creating order operation",
        "tool_type": "WRITE"
    },

    "create_attraction_order": {
        "description": "User purchases tickets based on attraction and date, system returns ticket-related information and places order",
        "preconditions": "In attraction travel scenario, user requests to book attraction, provides necessary booking information",
        "postconditions": "Generate order, ask user to confirm payment",
        "args": {
            "attraction_id": "Attraction ID",
            "ticket_id": "Ticket ID",
            "user_id": "User ID",
            "date": "Visit date, format: %Y-%m-%d",
            "quantity": "Quantity"
        },
        "returns": "Feedback output of creating order operation",
        "tool_type": "WRITE"
    },

    "create_flight_order": {
        "description": "User purchases flight tickets based on flight and seat type, system returns flight ticket-related information and places order",
        "preconditions": "In flight ticket query and purchase scenario, user requests to book flight, provides necessary booking information",
        "postconditions": "Generate order, ask user to confirm payment",
        "args": {
            "flight_id": "Flight ID",
            "seat_id": "Seat ID",
            "user_id": "User ID",
            "date": "Departure date, format: %Y-%m-%d",
            "quantity": "Quantity"
        },
        "returns": "Feedback output of creating order operation",
        "tool_type": "WRITE"
    },

    "create_train_order": {
        "description": "User purchases train tickets based on train and seat type, system returns train ticket-related information and places order",
        "preconditions": "In train ticket query and purchase scenario, user requests to book train, provides necessary booking information",
        "postconditions": "Generate order, ask user to confirm payment",
        "args": {
            "train_id": "Train ID",
            "seat_id": "Seat ID",
            "user_id": "User ID",
            "date": "Departure date, format: %Y-%m-%d",
            "quantity": "Quantity"
        },
        "returns": "Feedback output of creating order operation",
        "tool_type": "WRITE"
    },

    "pay_hotel_order": {
        "description": "User pays for hotel order",
        "preconditions": "In hotel query and booking scenario, user requests payment for hotel order, order ID is determined above",
        "postconditions": "Confirm payment and update order status to paid",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Payment result information",
        "tool_type": "WRITE"
    },

    "pay_attraction_order": {
        "description": "User pays for attraction ticket order",
        "preconditions": "In attraction travel scenario, user requests payment for attraction ticket order, order ID is determined above",
        "postconditions": "Confirm payment and update order status to paid",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Payment result information",
        "tool_type": "WRITE"
    },

    "pay_flight_order": {
        "description": "User pays for flight order",
        "preconditions": "In flight ticket query and purchase scenario, user requests payment for flight order, order ID is determined above",
        "postconditions": "Confirm payment and update order status to paid",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Payment result information",
        "tool_type": "WRITE"
    },

    "pay_train_order": {
        "description": "User pays for train ticket order",
        "preconditions": "In train ticket query and purchase scenario, user requests payment for train ticket order, order ID is determined above",
        "postconditions": "Confirm payment and update order status to paid",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Payment result information",
        "tool_type": "WRITE"
    },

    "search_hotel_order": {
        "description": "Query user's hotel orders based on user ID, return order ID, order type, user ID, hotel ID, order total price, order time, update time and order status",
        "preconditions": "User needs to query hotel orders",
        "postconditions": "Return order information, convenient for later modification/cancellation",
        "args": {
            "user_id": "User ID",
            "date": "Date",
            "status": "Order status"
        },
        "returns": "Specified user's hotel order information",
        "tool_type": "READ"
    },

    "search_attraction_order": {
        "description": "Query user's attraction ticket orders based on user ID, return order ID, order type, user ID, attraction ID, order total price, order time, update time and order status",
        "preconditions": "User needs to query attraction ticket orders",
        "postconditions": "Return order information, convenient for later modification/cancellation",
        "args": {
            "user_id": "User ID",
            "date": "Date",
            "status": "Order status"
        },
        "returns": "Specified user's attraction ticket order information",
        "tool_type": "READ"
    },

    "search_flight_order": {
        "description": "Query user's flight orders based on user ID, return order ID, order type, user ID, flight ID, order total price, order time, update time and order status",
        "preconditions": "User needs to query flight orders",
        "postconditions": "Return order information, convenient for later modification/cancellation",
        "args": {
            "user_id": "User ID",
            "date": "Date",
            "status": "Order status"
        },
        "returns": "Specified user's flight order information",
        "tool_type": "READ"
    },

    "search_train_order": {
        "description": "Query user's train ticket orders based on user ID, return order ID, order type, user ID, train ID, order total price, order time, update time and order status",
        "preconditions": "User needs to query train ticket orders",
        "postconditions": "Return order information, convenient for later modification/cancellation",
        "args": {
            "user_id": "User ID",
            "date": "Date",
            "status": "Order status"
        },
        "returns": "Specified user's train ticket order information",
        "tool_type": "READ"
    },

    "get_hotel_order_detail": {
        "description": "Get hotel order details, including order ID, order type, user ID, hotel ID, order total price, order time, update time, order status and order room detailed information (room type, check-in date, price, room ID)",
        "preconditions": "User requests hotel order details, order ID is determined above",
        "postconditions": "Return order details",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Order details",
        "tool_type": "READ"
    },

    "get_attraction_order_detail": {
        "description": "Get attraction ticket order details, including order ID, order type, user ID, attraction ID, order total price, order time, update time, order status and order attraction detailed information (ticket type, date, price, ticket ID)",
        "preconditions": "User requests attraction ticket order details, order ID is determined above",
        "postconditions": "Return order details",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Order details",
        "tool_type": "READ"
    },

    "get_flight_order_detail": {
        "description": "Get flight order details, including order ID, order type, user ID, flight ID, order total price, order time, update time, order status and order flight ticket detailed information (seat type, date, price, seat ID)",
        "preconditions": "User requests flight order details, order ID is determined above",
        "postconditions": "Return order details",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Order details",
        "tool_type": "READ"
    },

    "get_train_order_detail": {
        "description": "Get train ticket order details, including order ID, order type, user ID, train ID, order total price, order time, update time, order status and order train ticket detailed information (seat type, date, price, seat ID)",
        "preconditions": "User requests train ticket order details, order ID is determined above",
        "postconditions": "Return order details",
        "args": {
            "order_id": "Order ID"
        },
        "returns": "Order details",
        "tool_type": "READ"
    },

    "modify_train_order": {
        "description": "Modify train ticket order, support changing departure date, automatically handle price difference compensation or refund",
        "preconditions": "In train ticket query and purchase scenario, user requests to modify train ticket order, order ID is determined above",
        "postconditions": "Modify order and update order status, if price difference compensation is needed, order status changes to unpaid, otherwise maintains original status, if price difference compensation is needed, guide user to pay for current order",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID",
            "new_date": "New departure date, format: %Y-%m-%d"
        },
        "returns": "(Modified order content, price difference, positive means need to compensate, negative means refund)",
        "tool_type": "WRITE"
    },

    "modify_flight_order": {
        "description": "Modify flight order, support changing departure date, automatically handle price difference compensation or refund",
        "preconditions": "In flight ticket query and purchase scenario, user requests to modify flight order, order ID is determined above",
        "postconditions": "Modify order and update order status, if price difference compensation is needed, order status changes to unpaid, otherwise maintains original status, if price difference compensation is needed, guide user to pay for current order",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID",
            "new_date": "New departure date, format: %Y-%m-%d"
        },
        "returns": "(Modified order content, price difference, positive means need to compensate, negative means refund)",
        "tool_type": "WRITE"
    },

    "cancel_hotel_order": {
        "description": "User cancels booked hotel order",
        "preconditions": "In hotel query and booking scenario, user requests to cancel hotel order, order ID is determined above",
        "postconditions": "Cancel order and update order status, if refund is needed, inform user",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID"
        },
        "returns": "Order cancellation refund amount",
        "tool_type": "WRITE"
    },

    "cancel_attraction_order": {
        "description": "User cancels booked attraction ticket order",
        "preconditions": "Order ID exists in conversation history or order query has been performed, user has permission to cancel this order",
        "postconditions": "If refund is needed, inform user",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID"
        },
        "returns": "Order cancellation refund amount",
        "tool_type": "WRITE"
    },

    "cancel_flight_order": {
        "description": "User cancels booked flight order",
        "preconditions": "Order ID exists in conversation history or order query has been performed, user has permission to cancel this order",
        "postconditions": "If refund is needed, inform user",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID"
        },
        "returns": "Order cancellation refund amount",
        "tool_type": "WRITE"
    },

    "cancel_train_order": {
        "description": "User cancels booked train ticket order",
        "preconditions": "Order ID exists in conversation history or order query has been performed, user has permission to cancel this order",
        "postconditions": "If refund is needed, inform user",
        "args": {
            "order_id": "Order ID",
            "user_id": "User ID"
        },
        "returns": "Order cancellation refund amount",
        "tool_type": "WRITE"
    }
}

# Create OTA tool schema manager
# Built once at import time from both language tables; every accessor below
# simply delegates to it.
_schema_manager = create_tool_schema_manager("ota", TOOL_DESCRIPTIONS_ZH, TOOL_DESCRIPTIONS_EN)

# For backward compatibility, provide original function interface
def get_tool_descriptions():
    """Get tool descriptions based on language configuration."""
    return _schema_manager.get_tool_descriptions()

def get_tool_description(tool_name: str) -> Dict[str, Any]:
    """Get the description of a specific tool by name."""
    return _schema_manager.get_tool_description(tool_name)

def get_all_tool_names() -> list:
    """Get a list of all available tool names."""
    return _schema_manager.get_all_tool_names()

def get_tools_by_type(tool_type: str) -> Dict[str, Dict[str, Any]]:
    """Get all tools of a specific type."""
    return _schema_manager.get_tools_by_type(tool_type)

def get_tool_count_by_type() -> Dict[str, int]:
    """Get the count of tools by type."""
    return _schema_manager.get_tool_count_by_type()

def get_tool_args(tool_name: str) -> Dict[str, str]:
    """Get the arguments of a specific tool by name."""
    return _schema_manager.get_tool_args(tool_name)

def get_tool_returns(tool_name: str) -> str:
    """Get the return value description of a specific tool by name."""
    return _schema_manager.get_tool_returns(tool_name)

# Backward compatible variable
# Both names are bound once, when this module is imported.
# NOTE(review): presumably a later change of the language configuration is not
# reflected in TOOL_DESCRIPTIONS — confirm against the schema manager.
TOOL_DESCRIPTIONS = get_tool_descriptions()
TOOL_TYPE_MAPPING = _schema_manager.tool_type_mapping
+ """ + + @classmethod + def calculate_reward( + cls, + task: Task, + full_trajectory: List[Message], + final_state: dict, + window_size: int = 10, + overlap: int = 2, + llm_evaluator: str = None, + llm_args_evaluator: dict = None, + language: str = None, + ) -> RewardInfo: + """ + Calculate the reward for the simulation by using sliding window evaluation on the full trajectory + + Args: + task: The task containing evaluation criteria + full_trajectory: Complete list of messages in the conversation + final_state: Final state of the simulation + window_size: Number of messages per window (default: 10) + overlap: Number of messages to overlap between windows (default: 2) + """ + if task.evaluation_criteria is None: + return RewardInfo( + reward=1.0, + nl_rubrics=[], + info={"note": "No evaluation criteria"}, + reward_breakdown={RewardType.NL_ASSERTION: 1.0}, + ) + + evaluation_criteria = task.evaluation_criteria + if not evaluation_criteria.expected_states and not evaluation_criteria.overall_rubrics: + return RewardInfo( + reward=1.0, + nl_rubrics=[], + info={"note": "No rubric to evaluate"}, + reward_breakdown={RewardType.NL_ASSERTION: 1.0}, + ) + + env_info = { + "system_time": "", + "database": [] + } + + if hasattr(task, 'environment') and task.environment: + time_str = task.environment.get("time", "") + if time_str: + weekday = get_weekday(time_str, language) + env_info["system_time"] = f"{time_str} {weekday or ''}" + + current_rubric_states = cls._initialize_rubric_states(evaluation_criteria) + + windows = cls._create_sliding_windows(full_trajectory, window_size, overlap) + + step = window_size - overlap + window_evaluations = [] + + for i, window in enumerate(windows): + print(f"Processing window {i+1}/{len(windows)} with {len(window)} messages") + window_start_idx = i * step + current_rubric_states, window_eval_info = cls._evaluate_window( + env_info, task, window, current_rubric_states, i+1, len(windows), window_start_idx, + llm_evaluator, 
@classmethod
def _initialize_rubric_states(cls, evaluation_criteria: EvaluationCriteria) -> dict:
    """
    Build the initial rubric-state table, with every rubric marked as not met.

    Rubrics are collected from the ``state_rubrics`` of each entry in
    ``evaluation_criteria.expected_states`` first and from
    ``evaluation_criteria.overall_rubrics`` second. A rubric that was already
    collected is skipped, so each distinct rubric gets exactly one entry.

    Args:
        evaluation_criteria: Object providing ``expected_states`` and
            ``overall_rubrics``. Rubrics must be hashable because they are
            de-duplicated through a set.

    Returns:
        dict mapping ``"rubric_<n>"`` (``n`` counts distinct rubrics from 0 in
        collection order) to a dict with the keys ``rubric``,
        ``justification`` and ``meetExpectation``.
    """
    collected = []
    for expected_state in evaluation_criteria.expected_states or []:
        # Entries without a (non-empty) state_rubrics attribute contribute nothing.
        collected.extend(getattr(expected_state, "state_rubrics", None) or [])
    collected.extend(evaluation_criteria.overall_rubrics or [])

    rubric_states = {}
    seen_rubrics = set()
    for rubric in collected:
        if rubric in seen_rubrics:
            continue
        seen_rubrics.add(rubric)
        # Keys are numbered by the count of distinct rubrics stored so far.
        rubric_states[f"rubric_{len(rubric_states)}"] = {
            "rubric": rubric,
            "justification": "Not evaluated yet",
            "meetExpectation": False,
        }

    return rubric_states


@classmethod
def _create_sliding_windows(cls, messages: List[Message], window_size: int, overlap: int = 2) -> List[List[Message]]:
    """
    Split the message list into consecutive windows that share ``overlap`` messages.

    Args:
        messages: List of messages to create windows from.
        window_size: Maximum number of messages per window.
        overlap: Number of messages shared by two neighbouring windows.

    Returns:
        List of windows. When the whole list fits into one window the list
        itself is returned as the only window. Otherwise windows start every
        ``window_size - overlap`` messages and the last window ends at the
        final message (it may be shorter than ``window_size``).

    Raises:
        ValueError: If more than one window is needed and ``window_size`` is
            not positive or ``overlap`` is not smaller than ``window_size``.
            The start index could then never advance, so the loop would not
            terminate.
    """
    total = len(messages)
    if total <= window_size:
        return [messages]

    step = window_size - overlap
    if window_size <= 0 or step <= 0:
        raise ValueError(
            f"Cannot build sliding windows with window_size={window_size} and "
            f"overlap={overlap}: overlap must be smaller than a positive window_size"
        )

    windows = []
    start = 0
    while start < total:
        window = messages[start:start + window_size]
        if window:
            windows.append(window)

        # This window already reaches the end of the trajectory.
        if start + window_size >= total:
            break

        start += step

    return windows
@classmethod
def _format_window_content(cls, window: List[Message], window_start_idx: int = 0) -> str:
    """
    Render messages as numbered ``[idx] role: content`` lines.

    Tool calls of assistant messages are rendered as ``name(arg=value, ...)``,
    joined with ``.`` and appended to the message text (separated by one
    space) or used as the text when the message has none. Messages whose
    resulting text is empty are omitted, but they still consume an index.

    Args:
        window: List of messages in the current window.
        window_start_idx: Global index of the first message in this window
            (0-based); rendered indices are 1-based.

    Returns:
        The rendered lines joined with newlines.
    """

    def render_call(tool_call) -> str:
        # Tool calls may be objects exposing attributes or plain dicts.
        if hasattr(tool_call, 'name'):
            tool_name = tool_call.name
        elif isinstance(tool_call, dict):
            tool_name = tool_call.get('name', 'unknown_tool')
        else:
            tool_name = 'unknown_tool'

        if hasattr(tool_call, 'arguments'):
            tool_args = tool_call.arguments
        elif isinstance(tool_call, dict):
            tool_args = tool_call.get('arguments', {})
        else:
            tool_args = {}

        if isinstance(tool_args, dict):
            args_str = ', '.join(f"{k}={repr(v)}" for k, v in tool_args.items())
        else:
            args_str = str(tool_args)
        return f"{tool_name}({args_str})"

    content_lines = []
    for global_idx, message in enumerate(window, start=window_start_idx + 1):
        role = getattr(message, 'role', 'unknown')
        text = getattr(message, 'content', '')

        tool_calls = getattr(message, 'tool_calls', None)
        if role == 'assistant' and tool_calls:
            calls = ".".join(render_call(tool_call) for tool_call in tool_calls)
            if text:
                text += " " + calls
            else:
                text = calls

        if text:
            content_lines.append(f"[{global_idx}] {role}: {text}")

    return "\n".join(content_lines)


@classmethod
def _format_current_rubrics(cls, current_states: dict) -> str:
    """
    Serialize the rubric-state table as the JSON list given to the LLM.

    Args:
        current_states: Mapping of rubric key to a dict with the keys
            ``rubric``, ``justification`` and ``meetExpectation``.

    Returns:
        JSON text (2-space indent, non-ASCII characters kept as-is) with one
        object per rubric; the mapping key is emitted as ``rubric_idx``.
    """
    rubrics_list = [
        {
            "rubric_idx": key,
            "rubric": state["rubric"],
            "justification": state["justification"],
            "meetExpectation": state["meetExpectation"],
        }
        for key, state in current_states.items()
    ]
    return json.dumps(rubrics_list, ensure_ascii=False, indent=2)


@classmethod
def _convert_states_to_checks(cls, final_states: dict) -> List[NLRubricCheck]:
    """
    Convert the final rubric-state table into ``NLRubricCheck`` objects.

    Args:
        final_states: Mapping of rubric key to a dict with the keys
            ``rubric``, ``justification`` and ``meetExpectation``.

    Returns:
        One check per table entry, in table order.
    """
    return [
        NLRubricCheck(
            nl_rubric=state["rubric"],
            met=state["meetExpectation"],
            justification=state["justification"],
        )
        for state in final_states.values()
    ]


@classmethod
def calculate_reward_full_traj_rubric(
    cls,
    task: Task,
    full_trajectory: List[Message],
    final_state: dict,
    llm_evaluator: str = None,
    llm_args_evaluator: dict = None,
    language: str = None,
) -> RewardInfo:
    """
    Score a simulation by judging the whole trajectory against the rubrics in one pass.

    Args:
        task: The task containing the evaluation criteria.
        full_trajectory: Complete list of messages in the conversation.
        final_state: Final state of the simulation (not read by this method).
        llm_evaluator: Evaluator model name, forwarded to ``_evaluate_trajectory``.
        llm_args_evaluator: Evaluator model arguments, forwarded likewise.
        language: Language used for the weekday label and forwarded to
            ``_evaluate_trajectory``.

    Returns:
        RewardInfo whose ``reward`` is 1.0 only when at least one rubric
        exists and all rubrics are met, and whose ``NL_ASSERTION`` breakdown
        is the fraction of rubrics met. Tasks without criteria or without any
        rubric source get a reward of 1.0. When criteria are present but
        produce no rubric at all, both reward and fraction are 0.0.
    """
    if task.evaluation_criteria is None:
        return RewardInfo(
            reward=1.0,
            nl_rubrics=[],
            info={"note": "No evaluation criteria"},
            reward_breakdown={RewardType.NL_ASSERTION: 1.0},
        )

    evaluation_criteria = task.evaluation_criteria
    if not evaluation_criteria.expected_states and not evaluation_criteria.overall_rubrics:
        return RewardInfo(
            reward=1.0,
            nl_rubrics=[],
            info={"note": "No rubric to evaluate"},
            reward_breakdown={RewardType.NL_ASSERTION: 1.0},
        )

    env_info = {
        "system_time": "",
        "database": []
    }

    if hasattr(task, 'environment') and task.environment:
        time_str = task.environment.get("time", "")
        if time_str:
            weekday = get_weekday(time_str, language)
            env_info["system_time"] = f"{time_str} {weekday or ''}"

    current_rubric_states = cls._initialize_rubric_states(evaluation_criteria)

    # The per-call evaluation info is not part of the returned RewardInfo here.
    current_rubric_states, _trajectory_eval_info = cls._evaluate_trajectory(
        env_info, task, full_trajectory, current_rubric_states, llm_evaluator, llm_args_evaluator, language
    )

    final_nl_rubric_checks = cls._convert_states_to_checks(current_rubric_states)

    met_count = sum(1 for result in final_nl_rubric_checks if result.met)
    total = len(final_nl_rubric_checks)
    # Expected states without any state_rubrics can leave the table empty;
    # guard the division instead of raising ZeroDivisionError.
    rubric_score = met_count / total if total else 0.0
    reward = 1.0 if total > 0 and met_count == total else 0.0

    return RewardInfo(
        reward=reward,
        nl_rubrics=final_nl_rubric_checks,
        reward_breakdown={RewardType.NL_ASSERTION: rubric_score},
        info={"evaluation_method": "full_trajectory_with_rubrics"}
    )
updated_states = copy.deepcopy(current_states) + result_data = evaluator_extracter(assistant_message.content) + + if result_data: + for result in result_data: + rubric_idx = result.get("rubric_idx") + if rubric_idx and rubric_idx in updated_states: + updated_states[rubric_idx]["justification"] = result.get("justification", + "No justification provided") + updated_states[rubric_idx]["meetExpectation"] = result.get("meetExpectation", + updated_states[rubric_idx][ + "meetExpectation"]) + else: + print(f"Warning: Failed to parse LLM response for total trajectory, keeping current states") + + return updated_states, trajectory_evaluation_info + + @classmethod + def calculate_reward_sliding_wo_rubric( + cls, + task: Task, + full_trajectory: List[Message], + final_state: dict, + window_size: int = 10, + overlap: int = 2, + llm_evaluator: str = None, + llm_args_evaluator: dict = None, + language: str = None, + ) -> RewardInfo: + if task.evaluation_criteria is None: + return RewardInfo( + reward=1.0, + nl_rubrics=[], + info={"note": "No evaluation criteria"}, + reward_breakdown={RewardType.NL_ASSERTION: 1.0}, + ) + + evaluation_criteria = task.evaluation_criteria + if not evaluation_criteria.expected_states and not evaluation_criteria.overall_rubrics: + return RewardInfo( + reward=1.0, + nl_rubrics=[], + info={"note": "No rubric to evaluate"}, + reward_breakdown={RewardType.NL_ASSERTION: 1.0}, + ) + + env_info = { + "system_time": "", + "database": [] + } + + if hasattr(task, 'environment') and task.environment: + time_str = task.environment.get("time", "") + if time_str: + weekday = get_weekday(time_str, language) + env_info["system_time"] = f"{time_str} {weekday or ''}" + + windows = cls._create_sliding_windows(full_trajectory, window_size, overlap) + current_evaluation = { + "justification": "Not evaluated yet", + "meetExpectation": False + } + + step = window_size - overlap + window_evaluations = [] + + memory = "" + for i, window in enumerate(windows): + 
@classmethod
def _evaluate_window_sliding_wo_rubric(
    cls,
    env_info: dict,
    task: Task,
    memory: str,
    current_evaluation,
    window: List[Message],
    window_idx: int,
    total_windows: int,
    window_start_idx: int = 0,
    llm_evaluator: str = None,
    llm_args_evaluator: dict = None,
    language: str = None,
) -> tuple[dict, dict, str]:
    """
    Evaluate one window without rubrics and carry the evaluator's memory forward.

    Args:
        env_info: Environment description inserted into the system prompt.
        task: Task whose ``instructions`` are inserted into the system prompt.
        memory: Memory text returned for the previous window ("" for the first).
        current_evaluation: Evaluation dict carried over from the previous window.
        window: Messages of the current window.
        window_idx: 1-based position of this window.
        total_windows: Total number of windows.
        window_start_idx: Global index of the first message in this window.
        llm_evaluator: Evaluator model; defaults to ``DEFAULT_LLM_EVALUATOR``.
        llm_args_evaluator: Model arguments; defaults to
            ``models[DEFAULT_LLM_EVALUATOR]``.
        language: Language used to select the prompt templates.

    Returns:
        tuple: (updated_evaluation, window_evaluation_info, memory). When the
        model response cannot be parsed into a non-empty dict, the incoming
        evaluation is returned unchanged and a warning is printed; ``memory``
        is only replaced when the response carries a ``memory`` entry.
    """
    if llm_evaluator is None:
        llm_evaluator = DEFAULT_LLM_EVALUATOR
    if llm_args_evaluator is None:
        llm_args_evaluator = models[DEFAULT_LLM_EVALUATOR]

    window_content = cls._format_window_content(window, window_start_idx)

    current_evaluation_str = json.dumps(current_evaluation, ensure_ascii=False, indent=2)

    prompts = get_prompts(language)
    system_prompt = prompts.sliding_window_eval_no_rubrics_eval_template.format(
        env_info=env_info,
        user_instruction=task.instructions,
        window_idx=window_idx,
        total_windows=total_windows
    )

    # NOTE(review): the blank lines are reproduced exactly as they appear in
    # this file; they may stand for section markup lost upstream - confirm
    # against the prompt template before changing the layout.
    user_prompt = (
        "\n# Input\n\n"
        f"{window_content}\n\n\n"
        f"{memory}\n\n\n"
        f"{current_evaluation_str}\n\n"
    )

    messages = [
        SystemMessage(role="system", content=system_prompt),
        UserMessage(role="user", content=user_prompt),
    ]

    assistant_message = generate(
        model=llm_evaluator,
        messages=messages,
        **llm_args_evaluator,
    )

    window_evaluation_info = {
        "window_idx": window_idx,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "assistant_message_content": assistant_message.content
    }

    updated_states = current_evaluation
    result_data = evaluator_extracter(assistant_message.content)

    # Accept a one-element-list answer, as the full-trajectory variant does.
    if isinstance(result_data, list) and result_data:
        result_data = result_data[0]

    parsed = False
    if isinstance(result_data, dict):
        # Keep the previous memory when the model did not return one.
        memory = result_data.pop("memory", memory)
        if result_data:
            updated_states = result_data
            parsed = True

    if not parsed:
        print(f"Warning: Failed to parse LLM response for window {window_idx}")

    return updated_states, window_evaluation_info, memory


@classmethod
def _convert_states_to_checks_no_rubric(cls, final_states: dict) -> NLRubricCheck:
    """
    Wrap a rubric-free evaluation dict in a single ``NLRubricCheck``.

    Args:
        final_states: Dict with the keys ``meetExpectation`` and ``justification``.

    Returns:
        The check built from those two values (no rubric text is set).
    """
    return NLRubricCheck(
        met=final_states["meetExpectation"],
        justification=final_states["justification"]
    )
cls._evaluate_trajectory_full_traj_wo_rubric( + env_info, task, full_trajectory, llm_evaluator, llm_args_evaluator, language + ) + + final_nl_rubric_checks = cls._convert_states_to_checks_no_rubric(final_evaluation) + + all_expectations_met = final_nl_rubric_checks.met + reward = 1.0 if all_expectations_met else 0.0 + + return RewardInfo( + reward=reward, + nl_rubrics=[final_nl_rubric_checks], + info={"evaluation_method": "full_trajectory_no_rubrics"} + ) + + @classmethod + def _evaluate_trajectory_full_traj_wo_rubric( + cls, + env_info: dict, + task: Task, + trajectory: List[Message], + llm_evaluator: str = None, + llm_args_evaluator: dict = None, + language: str = None, + ) -> tuple[dict, dict]: + if llm_evaluator is None: + llm_evaluator = DEFAULT_LLM_EVALUATOR + if llm_args_evaluator is None: + llm_args_evaluator = models[DEFAULT_LLM_EVALUATOR] + + trajectory_content = cls._format_window_content(trajectory) + + prompts = get_prompts(language) + system_prompt = prompts.full_trajectory_no_rubrics_eval_template.format( + env_info=env_info, + user_instruction=task.instructions + ) + + user_prompt = f""" + # Input + + {trajectory_content} + + """ + + messages = [ + SystemMessage(role="system", content=system_prompt), + UserMessage(role="user", content=user_prompt), + ] + + assistant_message = generate( + model=llm_evaluator, + messages=messages, + **llm_args_evaluator, + ) + + trajectory_evaluation_info = { + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "assistant_message_content": assistant_message.content + } + + result_data = evaluator_extracter(assistant_message.content) + + if result_data and isinstance(result_data, dict): + updated_states = result_data + elif result_data and isinstance(result_data, list) and len(result_data) > 0: + # If result_data is a list, take the first element + updated_states = result_data[0] + else: + print(f"Warning: Failed to parse LLM response for total trajectory, keeping current states") + updated_states = { + 
"justification": "Not evaluated yet", + "meetExpectation": False + } + + return updated_states, trajectory_evaluation_info \ No newline at end of file diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py b/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py new file mode 100644 index 0000000..37676ae --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py @@ -0,0 +1,440 @@ +import time +import uuid +from copy import deepcopy +from datetime import datetime, timedelta +from enum import Enum +from typing import Any, Optional, Dict + +from loguru import logger + +from vita.agent.base import BaseAgent, is_valid_agent_history_message +from vita.data_model.message import ( + AssistantMessage, + Message, + MultiToolMessage, + ToolMessage, + UserMessage, +) +from vita.data_model.simulation import SimulationRun, TerminationReason +from vita.data_model.tasks import Task +from vita.environment.db import DB +from vita.environment.environment import Environment +from vita.user.base import BaseUser, is_valid_user_history_message +from vita.user.user_simulator import UserSimulator, UserState +from vita.utils.llm_utils import get_cost +from vita.utils.utils import format_time, get_now, DATA_DIR +from vita.config import DEFAULT_LANGUAGE + + +class Role(str, Enum): + AGENT = "agent" + USER = "user" + ENV = "env" + + +def get_default_first_agent_message(language: str = None) -> AssistantMessage: + """Get the default first agent message based on language""" + if language is None: + language = DEFAULT_LANGUAGE + + content = "你好,请问需要什么服务?" if language == "chinese" else "Hello, how can I help you?" + return AssistantMessage( + role="assistant", content=content, cost=0.0 + ) + + +class Orchestrator: + """ + Orchestrator for the simulation given a task. + Passes messages between the Agent, User, and Environment. 
+ """ + + def __init__( + self, + domain: str, + agent: BaseAgent, + user: BaseUser, + environment: Environment, + task: Task, + max_steps: int = 100, + max_errors: int = 10, + seed: Optional[int] = None, + solo_mode: bool = False, + language: str = None, + ): + self.domain = domain + self.agent = agent + self.user = user + self.environment = environment + self.task = task + self.seed = seed + self.solo_mode = solo_mode + self.language = language + self.agent_state: Optional[Any] = None + self.user_state: Optional[UserState] = None + self.trajectory: list[Message] = [] + self.max_steps = max_steps + self.max_errors = max_errors + self.step_count = 0 + self.done = False + self.termination_reason: Optional[TerminationReason] = None + self.num_errors = 0 + self.from_role: Optional[Role] = None + self.to_role: Optional[Role] = None + self.message: Optional[Message] = None + + def initialize(self): + """ + Initialize the orchestrator. + - If the tasks specifies an initial state, use it to initialize the environment. + - Initialize the agent and user states. + - Send the first message (default message from the agent to the user). 
+ """ + message_history = ( + deepcopy(self.task.message_history) + if self.task is not None and self.task.message_history is not None + else [] + ) + for msg in message_history: + msg.turn_idx = None + + message_history = self._add_timestamps(message_history) + + if self.seed is not None: + self.agent.set_seed(self.seed) + self.user.set_seed(self.seed) + + if len(message_history) > 0: + self.validate_message_history(message_history) + + last_message = message_history[-1] + if isinstance(last_message, AssistantMessage): + self.from_role = Role.AGENT + if not last_message.is_tool_call(): + self.to_role = Role.USER + else: + self.to_role = Role.ENV + self.agent_state = self.agent.get_init_state( + message_history=[ + msg + for msg in message_history + if is_valid_agent_history_message(msg) + ] + ) + self.user_state = self.user.get_init_state( + message_history=[ + msg + for msg in message_history[:-1] + if is_valid_user_history_message(msg) + ] + ) + self.message = last_message + if self.agent.is_stop(last_message): + self.done = True + self.termination_reason = TerminationReason.AGENT_STOP + elif isinstance(last_message, UserMessage): + self.from_role = Role.USER + if not last_message.is_tool_call(): + self.to_role = Role.AGENT + else: + self.to_role = Role.ENV + self.user_state = self.user.get_init_state( + message_history=[ + msg + for msg in message_history + if is_valid_user_history_message(msg) + ] + ) + self.agent_state = self.agent.get_init_state( + message_history=[ + msg + for msg in message_history[:-1] + if is_valid_agent_history_message(msg) + ] + ) + self.message = last_message + self.done = UserSimulator.is_stop(last_message) + if self.done: + self.termination_reason = TerminationReason.USER_STOP + elif isinstance(last_message, ToolMessage): + self.from_role = Role.ENV + if last_message.requestor == "assistant": + self.to_role = Role.AGENT + self.agent_state = self.agent.get_init_state( + message_history=[ + msg + for msg in message_history[:-1] + if 
def run(self) -> SimulationRun:
    """
    Run the simulation until it terminates and package the outcome.

    The orchestrator is initialized, then ``step()`` is called until the run
    is done. After each step that did not itself end the run, the step and
    error limits are checked; hitting either ends the run with ``MAX_STEPS``
    or ``TOO_MANY_ERRORS`` (the error limit wins when both are hit).

    Returns:
        SimulationRun: The simulation run, with timing, costs, the sorted
        trajectory, the seed and the old/new environment states.
        ``reward_info`` is left as ``None``.
    """
    start_time = get_now()
    start = time.perf_counter()
    self.initialize()
    while not self.done:
        self.step()
        if self.done:
            # step() already recorded why the run ended (user/agent stop or
            # invalid agent message). A limit reached on that same step must
            # not overwrite the real termination reason.
            break
        if self.step_count >= self.max_steps:
            self.done = True
            self.termination_reason = TerminationReason.MAX_STEPS
        if self.num_errors >= self.max_errors:
            self.done = True
            self.termination_reason = TerminationReason.TOO_MANY_ERRORS
    duration = time.perf_counter() - start
    messages = self.get_trajectory()
    res = get_cost(messages)
    # get_cost may return None instead of an (agent_cost, user_cost) pair.
    agent_cost, user_cost = (None, None) if res is None else res

    db = self.environment.tools.db
    return SimulationRun(
        id=str(uuid.uuid4()),
        task_id=self.task.id,
        start_time=start_time,
        end_time=get_now(),
        duration=duration,
        termination_reason=self.termination_reason.value,
        reward_info=None,
        user_cost=user_cost,
        agent_cost=agent_cost,
        messages=messages,
        seed=self.seed,
        states=self.get_states(db, db.time)
    )
Sending message from {self.from_role} to {self.to_role}" + ) + logger.debug( + f"Step {self.step_count}.\nFrom role: {self.from_role}\nTo role: {self.to_role}\nMessage: {self.message}" + ) + if self.from_role in [Role.AGENT, Role.ENV] and self.to_role == Role.USER: + user_msg, self.user_state = self.user.generate_next_message( + self.message, self.user_state + ) + user_msg.validate() + if UserSimulator.is_stop(user_msg): + self.done = True + self.termination_reason = TerminationReason.USER_STOP + self.trajectory.append(user_msg) + self.message = user_msg + self.from_role = Role.USER + if user_msg.is_tool_call(): + self.to_role = Role.ENV + else: + self.to_role = Role.AGENT + elif ( + self.from_role == Role.USER or self.from_role == Role.ENV + ) and self.to_role == Role.AGENT: + # Retry up to 3 times if agent generates invalid message + max_retries = 3 + retry_count = 0 + agent_msg = None + original_agent_state = deepcopy(self.agent_state) # Save original state + + while retry_count < max_retries: + # Use a copy of the original state for each retry + current_agent_state = deepcopy(original_agent_state) + agent_msg, updated_agent_state = self.agent.generate_next_message( + self.message, current_agent_state + ) + + # Check if the message is valid (has text content or is tool call) + if agent_msg.has_text_content() or agent_msg.is_tool_call(): + # Only update the actual agent state if we get a valid message + self.agent_state = updated_agent_state + break + + retry_count += 1 + logger.warning(f"Agent generated invalid message (attempt {retry_count}/{max_retries}): {agent_msg}") + + # If all retries failed, terminate with INVALID_AGENT_MESSAGE + if retry_count >= max_retries: + self.done = True + self.termination_reason = TerminationReason.INVALID_AGENT_MESSAGE + return + + agent_msg.validate() + if self.agent.is_stop(agent_msg): + self.done = True + self.termination_reason = TerminationReason.AGENT_STOP + self.trajectory.append(agent_msg) + self.message = agent_msg + 
def get_trajectory(self) -> list[Message]:
    """
    Get the trajectory of the simulation.

    Returns:
        Copies of the recorded messages, sorted by ``timestamp`` (messages
        with equal timestamps keep their recorded order) and numbered through
        ``turn_idx`` starting at 0. The messages stored on the orchestrator
        are not modified.
    """
    # One deep copy is enough: the copies are private to this call, so
    # turn_idx can be set on them directly.
    trajectory = sorted(
        deepcopy(self.trajectory),
        key=lambda x: x.timestamp,
    )
    for i, msg in enumerate(trajectory):
        msg.turn_idx = i
    return trajectory


def get_states(self, db: DB, env_time: str) -> Dict[str, Any]:
    """
    Split the records of the database into old and new states.

    Args:
        db: Database whose ``orders``, ``books`` and ``reservations``
            mappings (each optional) provide the records.
        env_time: Reference time; records with an ``update_time`` earlier
            than it are old, all others are new.

    Returns:
        Dict with the keys ``old_states`` and ``new_states``.
    """
    from vita.utils import str_to_datetime

    states = []
    for collection in ("orders", "books", "reservations"):
        if hasattr(db, collection):
            states += list(getattr(db, collection).values())

    states_dict = {"old_states": [], "new_states": []}
    if not states:
        return states_dict

    # Parse the reference time once instead of once per record.
    cutoff = str_to_datetime(env_time)
    for state in states:
        if str_to_datetime(state.update_time) < cutoff:
            states_dict["old_states"].append(state)
        else:
            states_dict["new_states"].append(state)
    return states_dict


@classmethod
def validate_message_history(cls, message_history: list[Message]):
    """
    Validate a message history.

    - Should only contain AssistantMessage, UserMessage, ToolMessage.
    - Every assistant/user message must pass its own ``validate()``.
    - A new tool-call message may not arrive while tool messages for the
      previous one are still outstanding.
    - Tool messages are only accepted while some are expected, and their
      ``requestor`` must match the role that issued the tool calls.

    Raises:
        ValueError: If any of the rules above is violated.
    """
    num_expected_tool_messages = 0
    requestor = None
    for msg in message_history:
        if isinstance(msg, (AssistantMessage, UserMessage)):
            msg.validate()
            if msg.is_tool_call():
                if num_expected_tool_messages > 0:
                    raise ValueError(
                        f"{num_expected_tool_messages} tool messages are missing. Got {msg.role} message."
                    )
                num_expected_tool_messages = len(msg.tool_calls)
                requestor = msg.role
            else:
                # A plain text message resets the expectation.
                num_expected_tool_messages = 0
                requestor = None
        elif isinstance(msg, ToolMessage):
            if num_expected_tool_messages == 0 or requestor is None:
                raise ValueError("No tool messages expected.")
            if requestor != msg.requestor:
                raise ValueError(
                    f"Got tool message from {msg.requestor}, expected {requestor}."
                )
            num_expected_tool_messages -= 1
        else:
            raise ValueError(f"Invalid message type: {type(msg)}")


def _add_timestamps(
    self, message_history: list[Message]
) -> list[Message]:
    """
    Add timestamps to the message history.
    This is used to sort the messages by timestamp.

    Timestamps are spaced one second apart and end one second before the
    current time, so the given order is preserved when sorting.

    Args:
        message_history: Messages to stamp; they are modified in place.

    Returns:
        The same list that was passed in.
    """
    time_offset = datetime.now() - timedelta(seconds=len(message_history))
    for i, msg in enumerate(message_history):
        msg.timestamp = format_time(time_offset + timedelta(seconds=i))
    return message_history
If the user indicates no, generate '###STOP###' mark to end the conversation diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml new file mode 100644 index 0000000..7a630b2 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml @@ -0,0 +1,34 @@ +name: solo_agent_system_prompt +chinese: |- + # 环境 + - 当前时间:{time} + + # 工具使用规范 + - 根据任务需求和提供的信息,确定需要调用的工具及参数 + - 按照逻辑顺序执行必要的工具调用来完成任务 + - 参考Precondition和Postcondition确保任务正确完成 + + # 任务要求 + - 你需要根据提供的完整任务描述和用户信息,按照顺序一次性完成用户的需求 + - 任务中涉及到的先下单后取消订单、先下单后修改订单等操作,请严格按照任务描述中的要求顺序执行 + - 所有必要的信息都已在任务描述中提供,包括用户偏好、约束条件等 + - 执行过程中不能与用户进行交互 + - 默认对于需要用户确认的逻辑,都认为用户已经确认 + - 在完成用户所有的需求以后,生成 '###STOP###' 标记来结束对话 + +english: |- + # Environment + - Current time: {time} + + # Tool Usage Guidelines + - Determine the tools and parameters to be called based on task requirements and provided information + - Execute necessary tool calls in logical order to complete tasks + - Refer to Precondition and Postcondition to ensure tasks are completed correctly + + # Task Requirements + - You need to complete the user's requirements in order at once based on the complete task description and user information provided + - For operations involving placing orders first and then canceling orders, placing orders first and then modifying orders, etc., please strictly follow the order requirements in the task description + - All necessary information has been provided in the task description, including user preferences, constraints, etc. 
+ - No interaction with users during execution + - By default, for logic that requires user confirmation, it is considered that the user has already confirmed + - After completing all user requirements, generate '###STOP###' mark to end the conversation diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/run.py b/examples/AgenticBenchmarks/VitaBench/src/vita/run.py new file mode 100644 index 0000000..36f25e9 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/run.py @@ -0,0 +1,920 @@ +import json +import multiprocessing +import random +import threading +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Optional +from datetime import datetime + +from loguru import logger + +from vita.agent.llm_agent import LLMAgent, LLMSoloAgent +from vita.data_model.simulation import ( + AgentInfo, + Info, + Results, + RunConfig, + SimulationRun, + UserInfo, +) + +from vita.data_model.tasks import Task +from vita.data_model.simulation import EvaluationType +from vita.environment.environment import get_cross_environment, EnvironmentInfo +from vita.evaluator.evaluator import evaluate_simulation +from vita.metrics.agent_metrics import compute_metrics +from vita.orchestrator.orchestrator import Orchestrator +from vita.registry import RegistryInfo, registry +from vita.user.user_simulator import get_global_user_sim_guidelines +from vita.utils.display import ConsoleDisplay +from vita.utils.pydantic_utils import get_pydantic_hash +from vita.utils.utils import DATA_DIR, get_commit_hash, get_now, show_dict_diff, global_time +from vita.utils.csv_utils import save_results_to_csv + + +def get_options() -> RegistryInfo: + """ + Returns options for the simulator. 
def get_tasks(
    task_set_name: str,
    task_ids: Optional[list[str]] = None,
    num_tasks: Optional[int] = None,
    language: str = None,
) -> list[Task]:
    """
    Load the tasks of a task set, optionally restricted by id and/or count.

    Args:
        task_set_name: Name handed to ``load_tasks`` to select the task set.
        task_ids: If given, only tasks whose ``id`` is listed are kept.
        num_tasks: If given, at most this many tasks are returned, taken from
            the front of the (possibly id-filtered) task list.
        language: Forwarded unchanged to ``load_tasks``.

    Returns:
        The selected tasks, in the order produced by ``load_tasks``.

    Raises:
        ValueError: If any id in ``task_ids`` is absent from the task set.
    """
    # Load once and derive every selection from the same list.
    tasks = load_tasks(task_set_name=task_set_name, language=language)

    if task_ids is not None:
        wanted = set(task_ids)
        tasks = [task for task in tasks if task.id in wanted]
        # Compare id sets rather than lengths so that a duplicated entry in
        # task_ids is not reported as a missing task.
        missing_tasks = wanted - {task.id for task in tasks}
        if missing_tasks:
            raise ValueError(
                f"Not all tasks were found for task set {task_set_name}: {missing_tasks}"
            )

    if num_tasks is not None:
        # Slice the already-filtered list: combining task_ids with num_tasks
        # must not throw the id filter away.
        tasks = tasks[:num_tasks]

    return tasks
RunConfig) -> Results: + """ + Run simulations for a domain + Returns: + Results: The simulation results + """ + config.validate() + ConsoleDisplay.display_run_config(config) + + # Check if this is a re-evaluation mode with optional re-run + if hasattr(config, 're_evaluate_file') and config.re_evaluate_file: + results = re_evaluate_simulation(config) + return results + + if config.task_set_name is None: + task_set_name = config.domain + else: + task_set_name = config.task_set_name + tasks = get_tasks(task_set_name, config.task_ids, config.num_tasks, config.language) + + num_trials = config.num_trials + save_to = config.save_to + if save_to is None: + save_to = f"{make_run_name(config)}.json" + save_to = DATA_DIR / "simulations" / save_to + config.save_to = save_to + + # Run simulations with the specified evaluation type + simulation_results = run_tasks( + domain=config.domain, + tasks=tasks, + agent=config.agent, + user=config.user, + llm_agent=config.llm_agent, + llm_args_agent=config.llm_args_agent, + llm_user=config.llm_user, + llm_args_user=config.llm_args_user, + num_trials=num_trials, + max_steps=config.max_steps, + max_errors=config.max_errors, + save_to=save_to, + console_display=True, + evaluation_type=config.evaluation_type, + max_concurrency=config.max_concurrency, + seed=config.seed, + log_level=config.log_level, + enable_think=config.enable_think, + llm_evaluator=config.llm_evaluator, + llm_args_evaluator=config.llm_args_evaluator, + language=config.language, + ) + + metrics = compute_metrics(simulation_results) + ConsoleDisplay.display_agent_metrics(metrics) + + if config.csv_output_file and simulation_results.simulations: + try: + csv_output = config.csv_output_file + save_results_to_csv(simulation_results, csv_output, config, metrics) + ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]") + except Exception as e: + ConsoleDisplay.console.print(f"\n[bold red]Error saving to CSV: {e}[/bold red]") + + 
def run_tasks(
    domain: str,
    tasks: list[Task],
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    num_trials: int = 1,
    max_steps: int = 100,
    max_errors: int = 10,
    save_to: Optional[str | Path] = None,
    console_display: bool = True,
    evaluation_type: EvaluationType = "trajectory",
    max_concurrency: int = 1,
    seed: Optional[int] = 300,
    log_level: Optional[str] = "INFO",
    enable_think: bool = False,
    llm_evaluator: Optional[str] = None,
    llm_args_evaluator: Optional[dict] = None,
    language: str = None,
) -> Results:
    """
    Run every task ``num_trials`` times and collect the evaluated simulations.

    Each (task, trial) pair is executed through ``run_task`` on a thread pool.
    A failing pair is logged and reported but does not abort the batch.

    Args:
        domain (str): The domain to run the simulation on.
        tasks (list[Task]): The tasks to run.
        agent (str): The agent to run the simulation on.
        user (str): The user to run the simulation on.
        llm_agent (str): The model to use for the agent.
        llm_args_agent (dict): The arguments to pass to the LLM for the agent.
        llm_user (str): The model to use for the user.
        llm_args_user (dict): The arguments to pass to the LLM for the user.
        num_trials (int): The number of trials to run for each task.
        max_steps (int): The maximum number of steps to run the simulation.
        max_errors (int): The maximum number of errors to allow in the simulation.
        save_to (str | Path): Path of the JSON checkpoint file. If the file
            already exists the user is asked whether to resume from it.
        console_display (bool): Whether to display each simulation in the console.
        evaluation_type (EvaluationType): The type of evaluation to use.
        max_concurrency (int): The maximum number of concurrent simulations to run.
        seed (int): Seed for the RNG that draws one seed per trial.
        log_level (str): The log level to use.
        enable_think (bool): Whether to enable think mode for the agent LLM.
        llm_evaluator (str): Forwarded to ``run_task``.
        llm_args_evaluator (dict): Forwarded to ``run_task``.
        language (str): Forwarded to ``get_info`` and ``run_task``.

    Returns:
        The results object holding all successful simulations (including the
        ones restored from an existing checkpoint when resuming).

    Raises:
        ValueError: On empty ``tasks``, non-positive ``num_trials`` /
            ``max_steps`` / ``max_errors``, or when resuming is declined after
            a config change or the stored task set differs.
        FileExistsError: If ``save_to`` exists and the user declines to resume.
    """
    if isinstance(save_to, str):
        save_to = Path(save_to)
    # Set log level from config
    logger.remove()
    logger.add(lambda msg: print(msg), level=log_level)
    if len(tasks) == 0:
        raise ValueError("No tasks to run")
    if num_trials <= 0:
        raise ValueError("Number of trials must be greater than 0")
    if max_steps <= 0:
        raise ValueError("Max steps must be greater than 0")
    if max_errors <= 0:
        raise ValueError("Max errors must be greater than 0")

    random.seed(seed)

    # One seed per trial; every task of a given trial shares that seed.
    seeds = [random.randint(0, 1000000) for _ in range(num_trials)]
    if llm_args_agent is not None and "seed" in llm_args_agent:
        logger.warning("Each trial will modify the seed for the agent")

    if llm_args_user is not None and "seed" in llm_args_user:
        logger.warning("Each trial will modify the seed for the user")

    # Serialises read-modify-write cycles on the checkpoint file.
    lock = threading.Lock()

    info = get_info(
        domain=domain,
        agent=agent,
        user=user,
        llm_agent=llm_agent,
        llm_args_agent=llm_args_agent,
        llm_user=llm_user,
        llm_args_user=llm_args_user,
        num_trials=num_trials,
        max_steps=max_steps,
        max_errors=max_errors,
        seed=seed,
        language=language,
    )
    simulation_results = Results(
        info=info,
        tasks=tasks,
        simulations=[],
    )
    done_runs = set()
    if save_to is not None:
        # If save_to already exists, check if the user wants to resume the run.
        if save_to.exists():
            response = (
                ConsoleDisplay.console.input(
                    "[yellow]File [bold]{}[/bold] already exists. Do you want to resume the run? (y/n)[/yellow] ".format(
                        save_to
                    )
                )
                .lower()
                .strip()
            )
            if response != "y":
                raise FileExistsError(
                    f"File {save_to} already exists. Please delete it or use a different save_to name."
                )
            # The checkpoint is written as UTF-8 (see _save), so it must be
            # read as UTF-8 regardless of the platform default encoding.
            with open(save_to, "r", encoding="utf-8") as fp:
                prev_simulation_results = Results.model_validate_json(fp.read())
            # Check if the run config has changed
            if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash(
                simulation_results.info
            ):
                diff = show_dict_diff(
                    prev_simulation_results.info.model_dump(),
                    simulation_results.info.model_dump(),
                )
                ConsoleDisplay.console.print(
                    f"The run config has changed.\n\n{diff}\n\nDo you want to resume the run? (y/n)"
                )
                response = (
                    ConsoleDisplay.console.input(
                        "[yellow]File [bold]{}[/bold] already exists. Do you want to resume the run? (y/n)[/yellow] ".format(
                            save_to
                        )
                    )
                    .lower()
                    .strip()
                )
                if response != "y":
                    raise ValueError(
                        "The run config has changed. Please delete the existing file or use a different save_to name."
                    )
            # Check if the task set has changed.
            # NOTE(review): zip() stops at the shorter list, so a differing
            # number of tasks is not detected here - confirm this is intended.
            if not all(
                get_pydantic_hash(task) == get_pydantic_hash(prev_task)
                for task, prev_task in zip(
                    sorted(simulation_results.tasks, key=lambda x: x.id),
                    sorted(prev_simulation_results.tasks, key=lambda x: x.id),
                )
            ):
                raise ValueError(
                    "The task set has changed. Please delete the existing file or use a different save_to name."
                )
            # Check which of the runs have already been done
            done_runs = {
                (sim.trial, sim.task_id, sim.seed)
                for sim in prev_simulation_results.simulations
            }
            simulation_results = prev_simulation_results
            ConsoleDisplay.console.print(
                f"[bold yellow]Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.[/bold yellow]"
            )
        # Create new save file
        else:
            # Create parent directories if needed
            if not save_to.parent.exists():
                save_to.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving simulation batch to {save_to}")
            with open(save_to, "w", encoding="utf-8") as fp:
                fp.write(simulation_results.model_dump_json(indent=2))

    def _save(simulation: SimulationRun):
        """Append one finished simulation to the checkpoint file (if any)."""

        def serialize_datetime(obj):
            if isinstance(obj, datetime):
                return obj.isoformat()
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        if save_to is None:
            return
        with lock:
            # Read with the same encoding the file is written with; the dump
            # below uses ensure_ascii=False, so non-ASCII text is stored raw.
            with open(save_to, "r", encoding="utf-8") as fp:
                ckpt = json.load(fp)

            simulation_dict = simulation.model_dump()

            ckpt["simulations"].append(simulation_dict)
            with open(save_to, "w", encoding="utf-8") as fp:
                json.dump(ckpt, fp, indent=2, ensure_ascii=False, default=serialize_datetime)

    def _run(task: Task, trial: int, seed: int, progress_str: str) -> dict:
        """Run one (task, trial) pair; never raises, returns a status dict."""
        ConsoleDisplay.console.print(
            f"[bold green]{progress_str} Running task {task.id}, trial {trial + 1}[/bold green]"
        )
        try:
            simulation = run_task(
                domain=domain,
                task=task,
                agent=agent,
                user=user,
                llm_agent=llm_agent,
                llm_args_agent=llm_args_agent,
                llm_user=llm_user,
                llm_args_user=llm_args_user,
                max_steps=max_steps,
                max_errors=max_errors,
                evaluation_type=evaluation_type,
                seed=seed,
                max_retries=3,  # Each task retries 3 times
                enable_think=enable_think,
                llm_evaluator=llm_evaluator,
                llm_args_evaluator=llm_args_evaluator,
                language=language,
            )
            simulation.trial = trial
            if console_display:
                ConsoleDisplay.display_simulation(simulation, show_details=False)
            _save(simulation)
            return {"status": "success", "simulation": simulation}
        except Exception as e:
            # Deliberate best-effort boundary: one failing task must not stop
            # the remaining ones, so the error is recorded and reported.
            logger.error(f"Error running task {task.id}, trial {trial}: {e}")
            logger.warning(f"Task {task.id}, trial {trial} failed but continuing with other tasks")
            if console_display:
                ConsoleDisplay.console.print(f"[bold red]Task {task.id}, trial {trial} failed: {e}[/bold red]")
            return {"status": "failed", "task_id": task.id, "trial": trial, "error": str(e)}

    args = []
    for trial in range(num_trials):
        for i, task in enumerate(tasks):
            if (trial, task.id, seeds[trial]) in done_runs:
                ConsoleDisplay.console.print(
                    f"[bold yellow]Skipping task {task.id}, trial {trial} because it has already been run.[/bold yellow]"
                )
                continue
            # 1-based position so the last task reads "N/N" instead of "N-1/N".
            progress_str = f"{i + 1}/{len(tasks)} (trial {trial + 1}/{num_trials})"
            args.append((task, trial, seeds[trial], progress_str))

    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        res = list(executor.map(_run, *zip(*args)))

    # Separate successful and failed tasks
    successful_sims = []
    failed_sims = []
    for sim_result in res:
        if sim_result["status"] == "success":
            successful_sims.append(sim_result["simulation"])
        else:
            failed_sims.append(sim_result)

    # Only add successful tasks to results
    simulation_results.simulations.extend(successful_sims)

    # Count successful and failed tasks
    ConsoleDisplay.console.print(
        f"\n✨ [bold green]Successfully completed all simulations![/bold green]\n"
        f"📊 [bold blue]Statistics:[/bold blue]\n"
        f"  ✅ Successful tasks: {len(successful_sims)}\n"
        f"  ❌ Failed tasks: {len(failed_sims)}\n"
        f"  📝 Total tasks: {len(res)}\n"
        f"To review the simulations, run: [bold blue]vita view[/bold blue]"
    )

    if failed_sims:
        ConsoleDisplay.console.print(f"\n[bold red]Failed tasks:[/bold red]")
        for failed_result in failed_sims:
            ConsoleDisplay.console.print(f"  - Task {failed_result['task_id']}, Trial {failed_result['trial']}: {failed_result['error']}")

        # Display all failed task IDs (sorted so the output is reproducible)
        failed_task_ids = sorted({failed_result['task_id'] for failed_result in failed_sims})
        ConsoleDisplay.console.print(f"\n[bold red]Failed task IDs:[/bold red] {', '.join(failed_task_ids)}")

    return simulation_results
domain: str, + task: Task, + agent: str, + user: str, + llm_agent: Optional[str] = None, + llm_args_agent: Optional[dict] = None, + llm_user: Optional[str] = None, + llm_args_user: Optional[dict] = None, + max_steps: int = 100, + max_errors: int = 10, + evaluation_type: EvaluationType = "trajectory", + seed: Optional[int] = None, + max_retries: int = 3, # Add maximum retry count parameter + enable_think: bool = False, + llm_evaluator: Optional[str] = None, + llm_args_evaluator: Optional[dict] = None, + language: str = None, +) -> SimulationRun: + """ + Runs tasks for a given domain. + If llm_as_judge is True, the LLM will be used to annotate the simulation run. + Calculates the reward for the simulation run. + Args: + domain (str): The domain to run the simulation on. + task (Task): The task to run. + agent (str): The agent to run the simulation on. + user (str): The user to run the simulation on. + llm_agent (str): The model to use for the agent. + llm_args_agent (dict): The arguments to pass to the LLM for the agent. + llm_user (str): The model to use for the user. + llm_args_user (dict): The arguments to pass to the LLM for the user. + max_steps (int): The maximum number of steps to run the simulation. + max_errors (int): The maximum number of errors to allow in the simulation. + evaluation_type (EvaluationType): The type of evaluation to use. + seed (int): The seed to use for the simulation. + max_retries (int): The maximum number of retries if an error occurs. + Returns: + The simulation run. 
def _run_task_internal(
    domain: str,
    task: Task,
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    max_steps: int = 100,
    max_errors: int = 10,
    evaluation_type: EvaluationType = "trajectory",
    seed: Optional[int] = None,
    enable_think: bool = False,
    llm_evaluator: Optional[str] = None,
    llm_args_evaluator: Optional[dict] = None,
    language: str = None,
) -> SimulationRun:
    """
    Execute a single simulation attempt and attach its evaluation.

    Builds the environment, agent and user for ``task``, lets the
    orchestrator run the conversation, evaluates the outcome and stores the
    resulting reward info on the returned simulation. No retry handling
    happens here.

    Raises:
        ValueError: If the registered agent class is neither an ``LLMAgent``
            nor an ``LLMSoloAgent``.
    """
    global global_time

    _clear_global_state()

    logger.info(
        f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}"
    )

    # A comma in the domain string selects the cross-domain environment.
    if "," in domain:
        environment = get_cross_environment(domain, task.environment, language)
    else:
        build_environment = registry.get_env_constructor(domain)
        environment = build_environment(task.environment, language)

    agent_cls = registry.get_agent_constructor(agent)

    sim_time = environment.tools.db.time
    # NOTE(review): this rebinds the module-level name in this module only -
    # confirm whether other modules are expected to observe the new value.
    global_time = sim_time
    logger.info(f"|| Time Set To: {sim_time}")

    # The LLMAgent check comes first on purpose: the order decides solo_mode
    # for any class that would satisfy both checks.
    if issubclass(agent_cls, LLMAgent):
        solo_mode = False
    elif issubclass(agent_cls, LLMSoloAgent):
        solo_mode = True
    else:
        raise ValueError(
            f"Unknown agent type: {agent_cls}. Should be LLMAgent or LLMSoloAgent"
        )

    # Both agent kinds take the same constructor arguments.
    agent_instance = agent_cls(
        tools=environment.get_tools(),
        domain_policy=environment.get_policy(),
        llm=llm_agent,
        llm_args=llm_args_agent,
        time=sim_time,
        enable_think=enable_think,
        language=language,
    )

    user_cls = registry.get_user_constructor(user)
    user_instance = user_cls(
        persona=str(task.user_scenario.user_profile),
        instructions=str(task.instructions),
        llm=llm_user,
        llm_args=llm_args_user,
        language=language,
    )

    simulation = Orchestrator(
        domain=domain,
        agent=agent_instance,
        user=user_instance,
        environment=environment,
        task=task,
        max_steps=max_steps,
        max_errors=max_errors,
        seed=seed,
        solo_mode=solo_mode,
        language=language,
    ).run()

    reward_info = evaluate_simulation(
        domain=domain,
        task=task,
        simulation=simulation,
        evaluation_type=evaluation_type,
        llm_evaluator=llm_evaluator,
        llm_args_evaluator=llm_args_evaluator,
        language=language,
    )
    simulation.reward_info = reward_info

    logger.info(
        f"FINISHED SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent_instance.__class__.__name__}, User: {user_instance.__class__.__name__}. "
        f"Reward: {reward_info.reward} | {reward_info.reward_breakdown}"
    )

    return simulation
def re_evaluate_simulation(config: RunConfig) -> Results:
    """
    Re-evaluate simulations from a saved simulation file, with optional re-running of specific tasks.

    Args:
        config (RunConfig): The run configuration containing:
            - re_evaluate_file (str): Path to the simulation file to load
            - evaluation_type (EvaluationType): The type of evaluation to use
            - save_to (Optional[str | Path]): Path to save the re-evaluation results
            - re_run (bool): Whether to re-run tasks specified by task_ids
            - task_ids (Optional[list[str]]): Task IDs to re-run (only used if re_run is True)

    Returns:
        Results: The re-evaluation results. Simulations whose re-evaluation
        failed, or whose task could not be found, are kept with their
        original reward info.

    Raises:
        FileNotFoundError: If ``config.re_evaluate_file`` does not exist.
    """
    re_evaluate_file = config.re_evaluate_file
    evaluation_type = config.evaluation_type
    save_to = config.save_to
    re_run = getattr(config, 're_run', False)
    task_ids_to_rerun = config.task_ids if re_run else None

    # Load the original simulation results
    simulation_path = Path(re_evaluate_file)
    if not simulation_path.exists():
        raise FileNotFoundError(f"Simulation file not found: {re_evaluate_file}")

    # Result files are written as UTF-8 by run_tasks, so read them as UTF-8
    # instead of relying on the platform default encoding.
    with open(simulation_path, "r", encoding="utf-8") as fp:
        original_results = Results.model_validate_json(fp.read())

    logger.info(f"Loaded simulation file: {re_evaluate_file}")
    logger.info(f"Found {len(original_results.simulations)} simulations")

    # Handle re-running specific tasks if requested
    if re_run and task_ids_to_rerun:
        logger.info(f"Re-running tasks: {task_ids_to_rerun}")

        # Get tasks to re-run
        if config.task_set_name is None:
            task_set_name = config.domain
        else:
            task_set_name = config.task_set_name

        tasks_to_rerun = get_tasks(task_set_name, task_ids_to_rerun, None, config.language)

        # Run the specific tasks
        rerun_results = run_tasks(
            domain=config.domain,
            tasks=tasks_to_rerun,
            agent=config.agent,
            user=config.user,
            llm_agent=config.llm_agent,
            llm_args_agent=config.llm_args_agent,
            llm_user=config.llm_user,
            llm_args_user=config.llm_args_user,
            num_trials=config.num_trials,
            max_steps=config.max_steps,
            max_errors=config.max_errors,
            save_to=None,  # Don't save intermediate results
            console_display=True,
            evaluation_type=evaluation_type,
            max_concurrency=config.max_concurrency,
            seed=config.seed,
            log_level=config.log_level,
            enable_think=config.enable_think,
            llm_evaluator=config.llm_evaluator,
            llm_args_evaluator=config.llm_args_evaluator,
            language=config.language,
        )

        # Remove old simulations for the re-run task IDs (set lookup instead
        # of scanning the id list once per simulation)
        rerun_id_set = set(task_ids_to_rerun)
        original_simulations = [
            sim for sim in original_results.simulations
            if sim.task_id not in rerun_id_set
        ]

        # Combine original simulations (excluding re-run tasks) with new simulations
        combined_simulations = original_simulations + rerun_results.simulations

        # Update original_results with combined simulations
        original_results.simulations = combined_simulations

        logger.info(f"Combined {len(original_simulations)} existing simulations with {len(rerun_results.simulations)} re-run simulations")

    logger.info(f"Total simulations to re-evaluate: {len(original_results.simulations)}")

    # Extend the task list with re-run tasks the stored results do not know yet
    final_tasks = original_results.tasks
    if re_run and task_ids_to_rerun:
        existing_task_ids = {task.id for task in original_results.tasks}
        new_tasks = [task for task in tasks_to_rerun if task.id not in existing_task_ids]
        if new_tasks:
            final_tasks = original_results.tasks + new_tasks
            logger.info(f"Added {len(new_tasks)} new tasks to the task list")

    # Create new results object for re-evaluation
    re_eval_results = Results(
        timestamp=get_now(),
        info=original_results.info,
        tasks=final_tasks,
        simulations=[],
    )

    def _re_evaluate_single(simulation, task_dict, domain_name, progress_str):
        """Re-evaluate a single simulation; never raises, returns a status dict."""
        logger.info(f"{progress_str} Re-evaluating simulation: {simulation.task_id}")

        task = task_dict.get(simulation.task_id)
        if task is None:
            logger.warning(f"Task {simulation.task_id} not found, skipping simulation")
            return {"status": "skipped", "simulation": simulation, "reason": "task_not_found"}

        try:
            reward_info = evaluate_simulation(
                simulation=simulation,
                task=task,
                evaluation_type=evaluation_type,
                domain=domain_name,
                llm_evaluator=config.llm_evaluator,
                llm_args_evaluator=config.llm_args_evaluator,
                language=config.language,
            )

            # Create a new simulation run with updated reward info
            re_eval_simulation = SimulationRun(
                id=simulation.id,
                task_id=simulation.task_id,
                timestamp=simulation.timestamp,
                start_time=simulation.start_time,
                end_time=simulation.end_time,
                duration=simulation.duration,
                termination_reason=simulation.termination_reason,
                agent_cost=simulation.agent_cost,
                user_cost=simulation.user_cost,
                reward_info=reward_info,  # Updated reward info
                messages=simulation.messages,  # Keep original messages
                states=simulation.states,  # Keep original states
                trial=simulation.trial,
                seed=simulation.seed,
            )

            logger.info(f"Re-evaluation completed for {simulation.task_id}: reward = {reward_info.reward}")
            return {"status": "success", "simulation": re_eval_simulation}

        except Exception as e:
            # Best-effort boundary: a failing evaluation is reported to the
            # caller via the status dict so the other simulations still run.
            logger.error(f"Error re-evaluating simulation {simulation.task_id}: {e}")
            return {"status": "failed", "simulation": simulation, "error": str(e)}

    # Create task dictionary for quick lookup
    task_dict = {task.id: task for task in final_tasks}
    domain_name = original_results.info.environment_info.domain_name

    # One argument tuple per simulation, transposed below for executor.map
    args = []
    for i, simulation in enumerate(original_results.simulations):
        progress_str = f"({i + 1}/{len(original_results.simulations)})"
        args.append((simulation, task_dict, domain_name, progress_str))

    # Re-evaluate concurrently on a thread pool
    max_concurrency = getattr(config, 'max_concurrency', 4)  # Default concurrency is 4
    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        results = list(executor.map(_re_evaluate_single, *zip(*args)))

    # Process results
    successful_count = 0
    failed_count = 0
    skipped_count = 0

    for result in results:
        if result["status"] == "success":
            re_eval_results.simulations.append(result["simulation"])
            successful_count += 1
        elif result["status"] == "failed":
            # Keep the original simulation and report the error; nothing is raised
            re_eval_results.simulations.append(result["simulation"])
            failed_count += 1
            print(f"Error of {result['simulation'].task_id} trial {result['simulation'].trial} re-evaluate: {result['error']}")
        elif result["status"] == "skipped":
            # For skipped cases, add original simulation
            re_eval_results.simulations.append(result["simulation"])
            skipped_count += 1

    # Output statistics
    ConsoleDisplay.console.print(
        f"\n✨ [bold green]Re-evaluation completed![/bold green]\n"
        f"📊 [bold blue]Statistics:[/bold blue]\n"
        f"  ✅ Successfully re-evaluated: {successful_count}\n"
        f"  ❌ Failed: {failed_count}\n"
        f"  ⏭️  Skipped: {skipped_count}\n"
        f"  📝 Total: {len(results)}"
    )

    metrics = compute_metrics(re_eval_results)
    ConsoleDisplay.display_agent_metrics(metrics)

    if config.csv_output_file and re_eval_results.simulations:
        try:
            csv_output = config.csv_output_file
            save_results_to_csv(re_eval_results, csv_output, config, metrics)
            ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]")
        except Exception as e:
            # CSV export is optional; report the problem but still return results
            ConsoleDisplay.console.print(f"\n[bold red]Error saving to CSV: {e}[/bold red]")

    # Save results if save_to is specified
    if save_to is not None:
        if isinstance(save_to, str):
            save_to = Path(save_to)

        # Create parent directories if needed
        if not save_to.parent.exists():
            save_to.parent.mkdir(parents=True, exist_ok=True)

        # Generate filename if not provided
        if save_to.is_dir() or save_to.name == "":
            original_name = simulation_path.stem
            save_to = save_to / f"{original_name}_re_eval_{evaluation_type}.json"

        logger.info(f"Saving re-evaluation results to: {save_to}")
        re_eval_results.save(save_to)

    return re_eval_results
class UserSimulator(BaseUser):
    """LLM-backed user simulator; all conversation state lives in ``UserState``."""

    def __init__(
        self,
        tools: Optional[list[Tool]] = None,
        instructions: Optional[str] = None,
        persona: Optional[str] = None,
        llm: Optional[str] = None,
        llm_args: Optional[dict] = None,
        language: str = None,
    ):
        super().__init__(instructions=instructions, llm=llm, llm_args=llm_args)
        self.language = language
        self.persona = persona
        self.tools = tools

    @property
    def global_simulation_guidelines(self) -> str:
        """
        Guideline template for this simulator's language.
        """
        return get_global_user_sim_guidelines(self.language)

    @property
    def system_prompt(self) -> str:
        """
        Guideline template filled in with the persona and the instructions.
        """
        if self.instructions is None:
            logger.warning("No instructions provided for user simulator")
        return self.global_simulation_guidelines.format(
            persona=self.persona,
            instructions=self.instructions,
        )

    def get_init_state(
        self, message_history: Optional[list[Message]] = None
    ) -> UserState:
        """
        Build the starting state: one system message plus the given history.
        """
        history = [] if message_history is None else message_history
        assert all(is_valid_user_history_message(entry) for entry in history), (
            "Invalid user message history. User messages must be of type UserMessage, AssistantMessage, or ToolMessage to User."
        )

        system_message = SystemMessage(role="system", content=self.system_prompt)
        return UserState(
            system_messages=[system_message],
            messages=history,
        )

    @classmethod
    def is_stop(cls, message: UserMessage) -> bool:
        """
        Tell whether the message contains one of the conversation-ending markers.
        """
        # A tool call never terminates the conversation.
        if message.is_tool_call():
            return False
        assert message.content is not None
        text = message.content
        return STOP in text or TRANSFER in text or OUT_OF_SCOPE in text

    def generate_next_message(
        self, message: ValidUserInputMessage, state: UserState
    ) -> Tuple[UserMessage, UserState]:
        return self._generate_next_message(message, state)

    def _generate_next_message(
        self, message: ValidUserInputMessage, state: UserState
    ) -> Tuple[UserMessage, UserState]:
        """Produce the simulated user's reply to the incoming message.

        Args:
            message: The assistant or tool message.
            state: The user simulator's state.

        Returns:
            A tuple containing the user message and the updated user state.
        """
        # Record the incoming message(s) before querying the model.
        if isinstance(message, MultiToolMessage):
            state.messages.extend(message.tool_messages)
        else:
            state.messages.append(message)

        # The model is prompted with the role-flipped history from the state.
        prompt_messages = state.system_messages + state.flip_roles()
        llm_reply = generate(
            model=self.llm,
            messages=prompt_messages,
            tools=self.tools,
            **self.llm_args,
        )

        reply_text = llm_reply.content
        logger.debug(f"Response: {reply_text}")

        user_message = UserMessage(
            role="user",
            content=reply_text,
            cost=llm_reply.cost,
            usage=llm_reply.usage,
            raw_data=llm_reply.raw_data,
        )

        # Re-issue any tool calls with the user as requestor.
        if llm_reply.tool_calls is not None:
            user_message.tool_calls = [
                ToolCall(
                    id=call.id,
                    name=call.name,
                    arguments=call.arguments,
                    requestor="user",
                )
                for call in llm_reply.tool_calls
            ]

        state.messages.append(user_message)
        return user_message, state
super().__init__(instructions=instructions, llm=llm, llm_args=llm_args) + self.tools = tools + self.persona = persona + self.language = language + + @property + def system_prompt(self) -> str: + prompts = get_prompts(self.language) + return prompts.dummy_user_system_prompt.format(persona=self.persona, instructions=self.instructions) + + def get_init_state( + self, message_history: Optional[list[Message]] = None + ) -> UserState: + """ + Get the initial state of the user simulator. + """ + if message_history is None: + message_history = [] + assert all(is_valid_user_history_message(m) for m in message_history), ( + "Invalid user message history. User messages must be of type UserMessage, AssistantMessage, or ToolMessage to User." + ) + + user_state = UserState( + system_messages=[SystemMessage(role="system", content=self.system_prompt)], + messages=message_history, + ) + return user_state + + @classmethod + def is_stop(cls, message: UserMessage) -> bool: + """ + Check if the message is a stop message. + """ + if message.is_tool_call(): + return False + assert message.content is not None + return ( + STOP in message.content + or TRANSFER in message.content + or OUT_OF_SCOPE in message.content + ) + + def generate_next_message( + self, message: ValidUserInputMessage, state: UserState + ) -> Tuple[UserMessage, UserState]: + return self._generate_next_message(message, state) + + def _generate_next_message( + self, message: ValidUserInputMessage, state: UserState + ) -> Tuple[UserMessage, UserState]: + """Get the response from the user simulator. + + Args: + message: The assistant or tool message. + state: The user simulator's state. + + Returns: + A tuple containing the user message and the updated user state. 
+ """ + if isinstance(message, MultiToolMessage): + state.messages.extend(message.tool_messages) + else: + state.messages.append(message) + messages = state.system_messages + state.flip_roles() + + assistant_message = generate( + model=self.llm, + messages=messages, + tools=self.tools, + **self.llm_args, + ) + + user_response = assistant_message.content + logger.debug(f"Response: {user_response}") + + user_message = UserMessage( + role="user", + content=user_response, + cost=assistant_message.cost, + usage=assistant_message.usage, + raw_data=assistant_message.raw_data, + ) + + if assistant_message.tool_calls is not None: + user_message.tool_calls = [] + for tool_call in assistant_message.tool_calls: + user_message.tool_calls.append( + ToolCall( + id=tool_call.id, + name=tool_call.name, + arguments=tool_call.arguments, + requestor="user", + ) + ) + + state.messages.append(user_message) + return user_message, state \ No newline at end of file diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/utils/utils.py b/examples/AgenticBenchmarks/VitaBench/src/vita/utils/utils.py new file mode 100644 index 0000000..1b902cf --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/utils/utils.py @@ -0,0 +1,240 @@ +import re +import hashlib +import json +import subprocess +from typing import Dict, Union +from datetime import datetime, timedelta +from pathlib import Path + +from deepdiff import DeepDiff +from dotenv import load_dotenv +from loguru import logger +from thefuzz import fuzz, process +from json_repair import repair_json + +from vita.config import DEFAULT_LANGUAGE + +global_time = None + +res = load_dotenv() +if not res: + logger.warning("No .env file found") + +SOURCE_DIR = Path(__file__).parents[3] +DATA_DIR = SOURCE_DIR / "data" +DOMAIN_DIR = DATA_DIR / "vita" / "domains" + +def get_task_file_path(domain: str, language: str = None) -> Path: + """Return corresponding task file path based on language parameter""" + if language is None: + language = 
DEFAULT_LANGUAGE + + if language == "english": + return DOMAIN_DIR / domain / "tasks_en.json" + else: + return DOMAIN_DIR / domain / "tasks.json" + +# Task file path definitions +DELIVERY_TASK_SET_PATH = get_task_file_path("delivery") +INSTORE_TASK_SET_PATH = get_task_file_path("instore") +CROSS_TASK_SET_PATH = get_task_file_path("cross_domain") +OTA_TASK_SET_PATH = get_task_file_path("ota") + + +def get_hash(obj: Union[dict, str]) -> str: + """ + Generate a unique hash for dict. + Returns a hex string representation of the hash. + """ + if isinstance(obj, dict): + hash_string = json.dumps(obj, sort_keys=True, default=str, ensure_ascii=False) + else: + hash_string = obj + return hashlib.sha256(hash_string.encode()).hexdigest() + + +def show_dict_diff(dict1: dict, dict2: dict) -> str: + """ + Show the difference between two dictionaries. + """ + diff = DeepDiff(dict1, dict2) + return diff + + +def get_now(format: str = "%Y%m%d_%H%M%S") -> str: + """ + Returns the current date and time in the format YYYYMMDD_HHMMSS. + """ + global global_time + if global_time is not None: + return global_time + now = datetime.now() + return format_time(now, format=format) + + +def str_to_datetime(time: str) -> datetime: + return datetime.strptime(time, "%Y-%m-%d %H:%M:%S") + +def get_weekday(date: str, language: str = None) -> str: + """Get weekday in the specified language""" + if language is None: + language = DEFAULT_LANGUAGE + + if language == "english": + weekday_dict = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday', 3: 'Thursday', 4: 'Friday', 5: 'Saturday', 6: 'Sunday'} + weekday = str_to_datetime(date).weekday() + return weekday_dict[weekday] + else: + weekday_dict = {0: '一', 1: '二', 2: '三', 3: '四', 4: '五', 5: '六', 6: '日'} + weekday = str_to_datetime(date).weekday() + return f"星期{weekday_dict[weekday]}" + +def format_time(time: datetime, format: str = "%Y%m%d_%H%M%S") -> str: + """ + Format the time in the format YYYYMMDD_HHMMSS. 
+ """ + return time.strftime(format) + + +def check_time_format(time: str, format="%Y-%m-%d %H:%M:%S") -> bool: + try: + datetime.strptime(time, format) + return True + except ValueError: + return False + + +def check_date_format(date: str) -> bool: + try: + datetime.strptime(date, "%Y-%m-%d") + return True + except ValueError: + return False + + +def get_date_between(start_date: str, end_date: str) -> list[str]: + """ + Get Date List between start_date and end_date. (include start_date and end_date) + """ + date_list = [] + start_datetime = datetime.strptime(start_date, "%Y-%m-%d") + end_datetime = datetime.strptime(end_date, "%Y-%m-%d") + while start_datetime < end_datetime: + date_list.append(start_datetime.strftime("%Y-%m-%d")) + start_datetime += timedelta(days=1) + return date_list + + +def get_commit_hash() -> str: + """ + Get the commit hash of the current directory. + """ + try: + commit_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], text=True) + .strip() + .split("\n")[0] + ) + except Exception as e: + logger.error(f"Failed to get git hash: {e}") + commit_hash = "unknown" + return commit_hash + + +def edit_distance_score(s1: str, s2: str): + """Calculate the edit distance between two strings.""" + dp = [[0] * (len(s2) + 1) for _ in range(len(s1) + 1)] + for i in range(len(s1) + 1): + dp[i][0] = i + for j in range(len(s2) + 1): + dp[0][j] = j + for i in range(1, len(s1) + 1): + for j in range(1, len(s2) + 1): + if s1[i - 1] == s2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + else: + dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1 + return 1 - dp[len(s1)][len(s2)] / max(len(s1), len(s2)) + + +def rerank(keywords: str, docs: Dict[str, str], with_score: bool = False): + # Ensure there are no duplicate values in docs + robust_docs = {} + val_set = set() + for key, val in docs.items(): + while val in val_set: + # Add a dummy suffix to the value + val += "-" + + val_set.add(val) + robust_docs[key] = val + + candidates = [doc for doc 
in robust_docs.values()] + doc_dict_reverse = {val: key for key, val in robust_docs.items()} + + docs_sorted = process.extract(keywords, candidates, limit=None, scorer=fuzz.partial_ratio) + if with_score: + id_doc_sorted = [(doc_dict_reverse[doc], doc, score) for doc, score in docs_sorted] + else: + id_doc_sorted = [(doc_dict_reverse[doc], doc) for doc, _ in docs_sorted] + + return id_doc_sorted + + +def fuzzy_match(x: str, y: str) -> bool: + if fuzz.partial_ratio(x, y) >= 40: + return True + else: + return False + + +def fuzzy_ratio_match(x: str, y: str) -> bool: + if fuzz.ratio(x, y) >= 20: + return True + else: + return False + + +def json_check(json_str: str) -> bool: + try: + json.loads(json_str) + return True + except (ValueError, json.JSONDecodeError) as e: + print(f"Format error: {e}") + return False + + +def extract_json_fields(json_str): + """Simple function to extract JSON fields""" + rubrics_pattern = r'"rubrics":\s*"(.*?)"(?=,|\s*})' + reasoning_pattern = r'"reasoning":\s*"(.*?)"(?=,|\s*})' + meet_expectation_pattern = r'"meetExpectation":\s*(true|false)' + + rubrics_list = re.findall(rubrics_pattern, json_str, re.DOTALL) + reasoning_list = re.findall(reasoning_pattern, json_str, re.DOTALL) + meet_expectation_list = re.findall(meet_expectation_pattern, json_str) + + results = [] + max_length = max(len(rubrics_list), len(reasoning_list), len(meet_expectation_list)) + + for i in range(max_length): + obj_result = {} + if i < len(rubrics_list): + obj_result['rubrics'] = rubrics_list[i] + if i < len(reasoning_list): + obj_result['reasoning'] = reasoning_list[i] + if i < len(meet_expectation_list): + obj_result['meetExpectation'] = meet_expectation_list[i] == 'true' + + results.append(obj_result) + + return results + + +def evaluator_extracter(content: str) -> list[dict]: + """ + Extract the result from the content. 
+ """ + good_json_string = repair_json(content) + result_data = json.loads(good_json_string) + return result_data From 9e19c8d89a2df2fdf4f84d33296770f6f4f9dde2 Mon Sep 17 00:00:00 2001 From: ashmitkx <66110457+ashmitkx@users.noreply.github.com> Date: Wed, 24 Jun 2026 14:04:31 +0000 Subject: [PATCH 2/2] add vitabench diff --- .../AgenticBenchmarks/VitaBench/README.md | 114 ++++++ .../VitaBench/src/vita/agent/llm_agent.py | 20 +- .../VitaBench/src/vita/cli.py | 42 ++- .../src/vita/data_model/simulation.py | 46 +++ .../vita/domains/ota/completeness/__init__.py | 35 ++ .../vita/domains/ota/completeness/checker.py | 291 +++++++++++++++ .../ota/completeness/constraint_extractor.py | 167 +++++++++ .../vita/domains/ota/completeness/schema.py | 152 ++++++++ .../ota/soundness_judge_harness/__init__.py | 218 +++++++++++ .../constraint_extractor.py | 81 ++++ .../ota/soundness_judge_harness/judge.py | 220 +++++++++++ .../soundness_judge_harness/memory_store.py | 127 +++++++ .../ota/soundness_judge_harness/schema.py | 116 ++++++ .../ota/soundness_judge_llm/__init__.py | 13 + .../domains/ota/soundness_judge_llm/judge.py | 127 +++++++ .../VitaBench/src/vita/domains/ota/tools.py | 31 +- .../src/vita/domains/ota/tools_schema.py | 67 +++- .../src/vita/domains/ota/verifier/__init__.py | 205 ++++++++++ .../src/vita/domains/ota/verifier/utils.py | 88 +++++ .../src/vita/evaluator/evaluator_traj.py | 12 +- .../src/vita/orchestrator/orchestrator.py | 122 +++++- .../src/vita/prompts/agent_system_prompt.yaml | 4 + .../completeness_extraction_template.yaml | 81 ++++ .../prompts/date_resolution_template.yaml | 51 +++ ...arness_constraint_extraction_template.yaml | 113 ++++++ .../harness_memory_writer_template.yaml | 33 ++ .../harness_soundness_judge_template.yaml | 51 +++ .../prompts/solo_agent_system_prompt.yaml | 4 + .../prompts/soundness_judge_template.yaml | 72 ++++ .../VitaBench/src/vita/run.py | 121 +++++- .../VitaBench/src/vita/scripts/README.md | 94 +++++ 
.../vita/scripts/preextract_completeness.py | 183 +++++++++ .../scripts/preextract_constraints_harness.py | 200 ++++++++++ .../vita/scripts/pregenerate_solo_messages.py | 153 ++++++++ .../src/vita/scripts/preresolve_dates.py | 349 ++++++++++++++++++ .../VitaBench/src/vita/user/user_simulator.py | 67 ++++ .../VitaBench/src/vita/utils/utils.py | 73 +++- 37 files changed, 3879 insertions(+), 64 deletions(-) create mode 100644 examples/AgenticBenchmarks/VitaBench/README.md create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/__init__.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/checker.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/constraint_extractor.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/schema.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/__init__.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/constraint_extractor.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/judge.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/memory_store.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/schema.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_llm/__init__.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_llm/judge.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/__init__.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/utils.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/completeness_extraction_template.yaml create mode 
100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/date_resolution_template.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_constraint_extraction_template.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_memory_writer_template.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_soundness_judge_template.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/prompts/soundness_judge_template.yaml create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/scripts/README.md create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_completeness.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_constraints_harness.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/scripts/pregenerate_solo_messages.py create mode 100644 examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preresolve_dates.py diff --git a/examples/AgenticBenchmarks/VitaBench/README.md b/examples/AgenticBenchmarks/VitaBench/README.md new file mode 100644 index 0000000..12b7dc9 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/README.md @@ -0,0 +1,114 @@ +# Running the OTA verifier on VitaBench + +--- + +**Note:** The code provided in this folder is built on top of the original code for VitaBench, found at [https://github.com/meituan-longcat/vitabench](https://github.com/meituan-longcat/vitabench). In each file, we have mentioned the changes we have made, and the code we have used verbatim, relative to the same file in the original VitaBench repo. + +## 1. Clone the upstream VitaBench repo + +```bash +git clone https://github.com/meituan-longcat/vitabench.git +cd vitabench +``` + +This README and the files alongside it are an overlay on top of the upstream +`main` branch. 
The overlay keeps the upstream directory layout, so every file +lives at the same path it would occupy inside a VitaBench checkout. + +## 2. Apply the overlay + +Because the overlay mirrors the upstream layout, copy its `src/` tree straight +over your clone — files at matching paths are replaced, new files are added: + +```bash +SRC=/path/to/this/overlay # the directory containing this README +DST=/path/to/vitabench # your upstream clone + +cp -r "$SRC/src" "$DST/" # merge the overlay sources into the clone +``` + +### Modified files + +| Path | What changed | +|---|---| +| `src/vita/cli.py` | Adds the `--soundness-mode`, `--completeness-mode`, `--solo-user-mode` and `--solo-user-file` run flags. | +| `src/vita/data_model/simulation.py` | Adds the matching `RunConfig` fields (+ validation) and a `soundness_log` field on `SimulationRun`. | +| `src/vita/run.py` | Threads the new flags through the run pipeline, builds the OTA verifier, and resolves solo user messages. | +| `src/vita/orchestrator/orchestrator.py` | Runs the verifier inline: blocking soundness check before each tool call, and a completeness check on stop. | +| `src/vita/agent/llm_agent.py` | Solo agent honours `language`; relaxes the tool-call-only guard so the orchestrator can nudge instead of crashing. | +| `src/vita/user/user_simulator.py` | `DummyUser` can replay a pregenerated opening message instead of calling the LLM each run. | +| `src/vita/domains/ota/tools.py` | Adds an optional `override` flag to every OTA WRITE tool so the agent can bypass a soundness block when confident. | +| `src/vita/domains/ota/tools_schema.py` | Documents the new `override` argument (Chinese + English). | +| `src/vita/evaluator/evaluator_traj.py` | Flattens nested-list LLM rubric output before scoring. | +| `src/vita/utils/utils.py` | Hardens `evaluator_extracter` JSON extraction (think-block stripping, fenced/balanced-block fallback). 
| +| `src/vita/prompts/agent_system_prompt.yaml` | Adds an "always respond in English" instruction. | +| `src/vita/prompts/solo_agent_system_prompt.yaml` | Adds an "always respond in English" instruction. | + +### New files + +| Path | Purpose | +|---|---| +| `src/vita/domains/ota/verifier/` | `OTAVerifier` + `create_verifier()` factory that wires the soundness and completeness checks together. | +| `src/vita/domains/ota/soundness_judge_llm/` | LLM-judge soundness checker (`--soundness-mode llm`). | +| `src/vita/domains/ota/soundness_judge_harness/` | NL-constraint "harness" soundness checker with running memory (`--soundness-mode harness`). | +| `src/vita/domains/ota/completeness/` | Completeness checker that compares the final orders against extracted constraints at stop. | +| `src/vita/prompts/*.yaml` | New prompt templates: soundness/harness judges, constraint & completeness extraction, memory writer, and date resolution. | +| `src/vita/scripts/` | Offline preprocessing scripts and their guide — see [`src/vita/scripts/README.md`](src/vita/scripts/README.md). | + +## 3. Python environment + +Follow the upstream VitaBench README. + +## 4. Offline preprocessing (optional) + +Some verifier modes consume artifacts produced by the scripts in +`src/vita/scripts/` (resolved dates, extracted constraints, pregenerated solo +user messages). The dependency order and exact commands are documented in +[`src/vita/scripts/README.md`](src/vita/scripts/README.md). You only need these +if you run `--soundness-mode harness`, `--completeness-mode on`, or +`--solo-user-mode file`. + +## 5. Environment variables + +```bash +# Max times the agent is sent back after a failed completeness check (default 1) +export VITA_MAX_COMPLETENESS_RETRIES=1 +``` + +## 6. 
Run + +Reference command (OTA domain, solo agent, dummy user, harness soundness + +completeness checks on): + +```bash +vita run \ + --domain ota \ + --agent llm_solo_agent \ + --user dummy_user \ + --agent-llm MODEL_NAME \ + --evaluator-llm MODEL_NAME \ + --language english \ + --soundness-mode harness \ + --completeness-mode on \ + --num-tasks 100 +``` + +Flags (the four overlay flags are added by this overlay; the rest are upstream): + +| Flag | Meaning | +|---|---| +| `--domain ota` | Run the OTA domain. The verifier only activates for `ota`. | +| `--agent llm_solo_agent` | Solo-mode agent: no user-simulator turn; it works the ticket autonomously via tool calls. | +| `--user dummy_user` | No-op user that only issues the opening message. | +| `--agent-llm MODEL_NAME` | Model (from `models.yaml`) the agent runs on. | +| `--evaluator-llm MODEL_NAME` | Model used by the rubric evaluator. | +| `--language english` | Prompt/task language (`english` or `chinese`). | +| `--num-tasks 100` | Number of tasks to run. | +| `--soundness-mode {llm,harness,off}` | Soundness checker before each write tool call. `llm` = LLM judge, `harness` = NL-constraint judge with memory, `off` = disabled. Default `off`. | +| `--completeness-mode {on,off}` | When `on`, run a completeness check at stop and send the agent back (up to `VITA_MAX_COMPLETENESS_RETRIES`) if requirements are unmet. Default `off`. | +| `--solo-user-mode {live,file}` | Solo opening message: `live` generates it via LLM each run (introduces variance); `file` loads a deterministic pregenerated message. Default `live`. | +| `--solo-user-file PATH` | JSON mapping `task_id -> message`, required when `--solo-user-mode=file`. Produced by `src/vita/scripts/pregenerate_solo_messages.py`. | + +Results are written to `data/simulations/`. See the upstream README for the full +list of base flags (`--num-trials`, `--max-steps`, `--task-ids`, `--csv-output`, +…). 
diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/agent/llm_agent.py b/examples/AgenticBenchmarks/VitaBench/src/vita/agent/llm_agent.py index 8a8bd0a..44f725a 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/agent/llm_agent.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/agent/llm_agent.py @@ -1,3 +1,14 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/agent/llm_agent.py. +Everything is verbatim from the original except for the following changes: + +1. ``LLMSoloAgent`` now stores ``self.language`` and builds its system prompt + with ``get_prompts(self.language)`` (was ``get_prompts()``), so solo runs + honour the requested language. +2. Commented out the ``raise ValueError("LLMSoloAgent only supports tool calls + before ###STOP###.")`` guard in ``generate_next_message`` — the orchestrator + now nudges the agent back instead of hard-failing on a stray text turn. +""" from copy import deepcopy from typing import List, Optional @@ -147,10 +158,11 @@ def __init__( self.llm_args = deepcopy(llm_args) if llm_args is not None else {} self.time = time + " " + get_weekday(time, language) self.enable_think = enable_think + self.language = language @property def system_prompt(self) -> str: - prompts = get_prompts() + prompts = get_prompts(self.language) if self.time is not None: return prompts.solo_agent_system_prompt.format( time=self.time @@ -210,8 +222,8 @@ def generate_next_message( enable_think=self.enable_think, **self.llm_args, ) - if not assistant_message.is_tool_call() and not self.is_stop(assistant_message): - raise ValueError("LLMSoloAgent only supports tool calls before ###STOP###.") + # if not assistant_message.is_tool_call() and not self.is_stop(assistant_message): + # raise ValueError("LLMSoloAgent only supports tool calls before ###STOP###.") state.messages.append(assistant_message) return assistant_message, state @@ -222,4 +234,4 @@ def set_seed(self, 
seed: int): cur_seed = self.llm_args.get("seed", None) if cur_seed is not None: logger.warning(f"Seed is already set to {cur_seed}, resetting it to {seed}") - self.llm_args["seed"] = seed \ No newline at end of file + self.llm_args["seed"] = seed diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/cli.py b/examples/AgenticBenchmarks/VitaBench/src/vita/cli.py index a613101..572b318 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/cli.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/cli.py @@ -1,3 +1,11 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/cli.py. +Everything is verbatim from the original except for the following changes: + +1. Added four ``run`` CLI arguments: ``--soundness-mode``, ``--completeness-mode``, + ``--solo-user-mode`` and ``--solo-user-file``. +2. Forwarded those four arguments into ``run_domain`` from ``main()``. +""" import argparse from typing import get_args @@ -175,7 +183,33 @@ def add_run_args(parser): action="store_true", help="Re-run tasks specified by --task-ids. If used with --re-evaluate-file, will re-run specified tasks and then re-evaluate all tasks together.", ) - + parser.add_argument( + "--soundness-mode", + type=str, + choices=["llm", "harness", "off"], + default="off", + help="Soundness check mode: 'llm' (LLM judge), 'harness' (NL constraint harness), 'off' (disabled). Default is 'off'.", + ) + parser.add_argument( + "--completeness-mode", + type=str, + choices=["on", "off"], + default="off", + help="Completeness check mode: 'on' (check final order completeness at stop), 'off' (disabled). Default is 'off'.", + ) + parser.add_argument( + "--solo-user-mode", + type=str, + choices=["live", "file"], + default="live", + help="Solo agent user message mode: 'live' (generate via LLM each run, introduces variance) or 'file' (load from --solo-user-file, errors if a task is missing). 
Default is 'live'.", + ) + parser.add_argument( + "--solo-user-file", + type=str, + default=None, + help="Path to a JSON file mapping task_id -> pregenerated user message. Required when --solo-user-mode=file.", + ) def main(): @@ -212,7 +246,11 @@ def main(): csv_output_file=getattr(args, 'csv_output', None), enable_think=args.enable_think, language=args.language, - re_run=getattr(args, 're_run', False) + re_run=getattr(args, 're_run', False), + soundness_mode=args.soundness_mode, + completeness_mode=args.completeness_mode, + solo_user_mode=args.solo_user_mode, + solo_user_file=args.solo_user_file, ) ) ) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py b/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py index b0a5587..4417338 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/data_model/simulation.py @@ -1,3 +1,13 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/data_model/simulation.py. +Everything is verbatim from the original except for the following changes: + +1. Added four fields to ``RunConfig``: ``soundness_mode``, ``completeness_mode``, + ``solo_user_mode`` and ``solo_user_file``. +2. Added validation in ``RunConfig.validate()`` for ``solo_user_mode`` / + ``solo_user_file``. +3. Added the ``soundness_log`` field to ``SimulationRun``. 
+""" import json from copy import deepcopy from enum import Enum @@ -220,6 +230,34 @@ class RunConfig(BaseModel): default=False, ), ] + soundness_mode: Annotated[ + str, + Field( + description="Soundness check mode: 'llm', 'harness', or 'off'", + default="off", + ), + ] + completeness_mode: Annotated[ + str, + Field( + description="Completeness check mode: 'on' or 'off'", + default="off", + ), + ] + solo_user_mode: Annotated[ + str, + Field( + description="Solo agent user message mode: 'live' (generate via LLM each run) or 'file' (load from solo_user_file, error if missing)", + default="live", + ), + ] + solo_user_file: Annotated[ + Optional[str], + Field( + description="Path to a JSON file mapping task_id -> pregenerated user message. Required when solo_user_mode='file'.", + default=None, + ), + ] def validate(self) -> None: """ @@ -231,6 +269,11 @@ def validate(self) -> None: raise ValueError("--re-run can only be used with --re-evaluate-file") if not self.task_ids: raise ValueError("--re-run requires --task-ids to specify which tasks to re-run") + # Validate solo user mode + if self.solo_user_mode not in ("live", "file"): + raise ValueError("--solo-user-mode must be 'live' or 'file'") + if self.solo_user_mode == "file" and not self.solo_user_file: + raise ValueError("--solo-user-file must be provided when --solo-user-mode=file") class NLRubricCheck(BaseModel): @@ -354,6 +397,9 @@ class SimulationRun(BaseModel): seed: Optional[int] = Field( description="Seed used for the simulation.", default=None ) + soundness_log: Optional[list] = Field( + description="Log of all LLM soundness judge calls (tool name, args, reason).", default=None + ) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/__init__.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/__init__.py new file mode 100644 index 0000000..ae59313 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/__init__.py @@ 
-0,0 +1,35 @@ +""" +OTA Completeness checking. + +Two parts: + - Extraction (offline): derive minimal booking-count constraints from a task. + - Checking (runtime): verify the final order state satisfies those constraints. + +Usage: + from vita.domains.ota.completeness import check_completeness, CompletenessConstraints + from vita.domains.ota.completeness import extract_completeness_constraints +""" + +from vita.domains.ota.completeness.schema import ( + AttractionCompleteness, + CancelCompleteness, + CompletenessConstraints, + FlightCompleteness, + HotelCompleteness, + ModifyCompleteness, + TrainCompleteness, +) +from vita.domains.ota.completeness.checker import check_completeness +from vita.domains.ota.completeness.constraint_extractor import extract_completeness_constraints + +__all__ = [ + "check_completeness", + "extract_completeness_constraints", + "CompletenessConstraints", + "HotelCompleteness", + "FlightCompleteness", + "TrainCompleteness", + "AttractionCompleteness", + "CancelCompleteness", + "ModifyCompleteness", +] diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/checker.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/checker.py new file mode 100644 index 0000000..5316ca8 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/checker.py @@ -0,0 +1,291 @@ +""" +Order-based completeness checker for OTA tasks. + +Checks final order state against completeness constraints to verify +that all required bookings exist. Does NOT check attribute correctness +(room type, seat class, etc.) — that's the soundness judge's job. + +Operates on Order objects (Pydantic models) from the live DB, pre-split +into old_states / new_states by the orchestrator's get_states(). 
+""" + +from __future__ import annotations + +import logging +from datetime import datetime, timedelta +from typing import Any + +from vita.domains.ota.completeness.schema import ( + AttractionCompleteness, + CancelCompleteness, + CompletenessConstraints, + FlightCompleteness, + HotelCompleteness, + ModifyCompleteness, + TrainCompleteness, +) + +logger = logging.getLogger(__name__) + + +def _parse_date(date_str: str) -> datetime | None: + """Try to parse a date string in YYYY-MM-DD format.""" + try: + return datetime.strptime(date_str, "%Y-%m-%d") + except (ValueError, TypeError): + return None + + +def _get(obj: Any, key: str, default: Any = "") -> Any: + """Attribute or dict access — works with Order objects and plain dicts.""" + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _get_products(order: Any) -> list: + """Get the products list from an order (object or dict).""" + prods = _get(order, "products", []) + return prods if prods else [] + + +def check_completeness( + new_orders: list, + old_orders: list, + constraints: CompletenessConstraints, + environment: dict | None = None, +) -> list[str]: + """ + Check that all completeness constraints are satisfied by the final orders. + + Args: + new_orders: Orders created/modified during the simulation (from get_states new_states). + old_orders: Pre-existing orders (from get_states old_states). + constraints: Completeness constraints (conditional bucket is excluded). + environment: Task environment dict for looking up flight/train cities by store_id. + + Returns: + List of missing-booking descriptions. Empty = all complete. 
+ """ + environment = environment or {} + missing: list[str] = [] + + # Active (non-cancelled) new orders + active_new = [o for o in new_orders if _get(o, "status") != "cancelled"] + + all_orders = list(old_orders) + list(new_orders) + + for con in constraints.hotel: + missing.extend(_check_hotel(con, active_new)) + for con in constraints.flight: + missing.extend(_check_flight(con, active_new, environment)) + for con in constraints.train: + missing.extend(_check_train(con, active_new, environment)) + for con in constraints.attraction: + missing.extend(_check_attraction(con, active_new)) + for con in constraints.cancel: + missing.extend(_check_cancel(con, all_orders)) + for con in constraints.modify: + missing.extend(_check_modify(con, new_orders, old_orders)) + + return missing + + +_ORDER_TYPE_TO_CATALOG = {"flight": "flights", "train": "trains"} + + +def _resolve_city(order: Any, environment: dict, field: str) -> str: + """Look up a city field from the environment via order's store_id.""" + catalog_key = _ORDER_TYPE_TO_CATALOG.get(_get(order, "order_type", ""), "") + if not catalog_key: + return "" + catalog = environment.get(catalog_key, {}) + store_id = _get(order, "store_id", "") + entry = catalog.get(store_id) + if entry is None: + return "" + return _get(entry, field, "") + + +# ────────────────────────────────────────────── +# Hotel completeness +# ────────────────────────────────────────────── + +def _check_hotel(con: HotelCompleteness, active_orders: list) -> list[str]: + """ + Check hotel completeness: right date range, enough room-nights. + City is not checked (hotel addresses are unstructured). 
+ """ + hotel_orders = [o for o in active_orders if _get(o, "order_type") == "hotel"] + num_nights = con.num_nights or 1 + required_room_nights = con.num_rooms * num_nights + + expected_dates: set[str] | None = None + if con.check_in_date: + checkin = _parse_date(con.check_in_date) + if checkin: + expected_dates = { + (checkin + timedelta(days=i)).strftime("%Y-%m-%d") + for i in range(num_nights) + } + + matched_count = 0 + for order in hotel_orders: + for product in _get_products(order): + product_date = _get(product, "date", "") + if expected_dates and product_date not in expected_dates: + continue + matched_count += _get(product, "quantity", 1) + + if matched_count >= required_room_nights: + return [] + + prefix = f"[{con.id}]" + detail = ( + f"check-in={con.check_in_date or 'any'}, " + f"{con.num_rooms} room(s) x {num_nights} night(s)" + ) + if not hotel_orders: + return [f"{prefix} No hotel booked ({detail})"] + return [ + f"{prefix} Hotel: " + f"{matched_count}/{required_room_nights} room-night orders ({detail})" + ] + + +# ────────────────────────────────────────────── +# Flight completeness +# ────────────────────────────────────────────── + +def _check_flight(con: FlightCompleteness, active_orders: list, environment: dict) -> list[str]: + """Check flight completeness: right route, right date, enough total quantity.""" + flight_orders = [o for o in active_orders if _get(o, "order_type") == "flight"] + + matched_qty = 0 + for order in flight_orders: + if con.departure_city: + dep = _resolve_city(order, environment, "departure_city") + if dep.lower() != con.departure_city.lower(): + continue + if con.arrival_city: + arr = _resolve_city(order, environment, "arrival_city") + if arr.lower() != con.arrival_city.lower(): + continue + for product in _get_products(order): + if con.date and _get(product, "date", "") != con.date: + continue + matched_qty += _get(product, "quantity", 0) + + if matched_qty >= con.quantity: + return [] + + prefix = f"[{con.id}]" + route 
= f"{con.departure_city or '*'} -> {con.arrival_city or '*'}" + detail = f"route={route}, date={con.date or 'any'}, qty={con.quantity}" + if not flight_orders: + return [f"{prefix} No flight booked ({detail})"] + return [f"{prefix} Flight {route} on {con.date}: {matched_qty}/{con.quantity} tickets ({detail})"] + + +# ────────────────────────────────────────────── +# Train completeness +# ────────────────────────────────────────────── + +def _check_train(con: TrainCompleteness, active_orders: list, environment: dict) -> list[str]: + """Check train completeness: right route, right date, enough total quantity.""" + train_orders = [o for o in active_orders if _get(o, "order_type") == "train"] + + matched_qty = 0 + for order in train_orders: + if con.departure_city: + dep = _resolve_city(order, environment, "departure_city") + if dep.lower() != con.departure_city.lower(): + continue + if con.arrival_city: + arr = _resolve_city(order, environment, "arrival_city") + if arr.lower() != con.arrival_city.lower(): + continue + for product in _get_products(order): + if con.date and _get(product, "date", "") != con.date: + continue + matched_qty += _get(product, "quantity", 0) + + if matched_qty >= con.quantity: + return [] + + prefix = f"[{con.id}]" + route = f"{con.departure_city or '*'} -> {con.arrival_city or '*'}" + detail = f"route={route}, date={con.date or 'any'}, qty={con.quantity}" + if not train_orders: + return [f"{prefix} No train booked ({detail})"] + return [f"{prefix} Train {route} on {con.date}: {matched_qty}/{con.quantity} tickets ({detail})"] + + +# ────────────────────────────────────────────── +# Attraction completeness +# ────────────────────────────────────────────── + +def _check_attraction(con: AttractionCompleteness, active_orders: list) -> list[str]: + """Check attraction completeness: right date, enough total quantity.""" + attr_orders = [o for o in active_orders if _get(o, "order_type") == "attraction"] + + matched_qty = 0 + for order in attr_orders: 
+ for product in _get_products(order): + if con.date and _get(product, "date", "") != con.date: + continue + matched_qty += _get(product, "quantity", 0) + + if matched_qty >= con.quantity: + return [] + + prefix = f"[{con.id}]" + detail = f"date={con.date or 'any'}, qty={con.quantity}" + if not attr_orders: + return [f"{prefix} No attraction booked ({detail})"] + return [f"{prefix} Attraction on {con.date}: insufficient quantity ({detail})"] + + +# ────────────────────────────────────────────── +# Cancel completeness +# ────────────────────────────────────────────── + +def _check_cancel(con: CancelCompleteness, all_orders: list) -> list[str]: + """Check that the specified pre-existing order was cancelled.""" + for order in all_orders: + if _get(order, "order_id") == con.order_id: + status = _get(order, "status") + if status == "cancelled": + return [] + return [ + f"[{con.id}] Order {con.order_id} ({con.entity_type}) should be cancelled " + f"but has status '{status}'" + ] + + return [] + + +# ────────────────────────────────────────────── +# Modify completeness +# ────────────────────────────────────────────── + +def _check_modify(con: ModifyCompleteness, new_orders: list, old_orders: list) -> list[str]: + """ + Check that the specified pre-existing order was modified. + If it moved to new_orders (update_time changed), it was modified. 
+ """ + + # search in new_orders first - if it's there, it's modified, no need to check old_orders + for order in new_orders: + if _get(order, "order_id") == con.order_id: + return [] + + # if it's not in new_orders, check if it exists in old_orders - if it does, then it was not modified + for order in old_orders: + if _get(order, "order_id") == con.order_id: + return [ + f"[{con.id}] Order {con.order_id} ({con.entity_type}) should be modified " + f"but was not updated" + ] + + return [] diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/constraint_extractor.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/constraint_extractor.py new file mode 100644 index 0000000..a52b0ac --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/constraint_extractor.py @@ -0,0 +1,167 @@ +""" +LLM-based completeness constraint extractor for OTA tasks. + +Extracts minimal constraints (counts, cities, routes, dates) needed to +verify that all required bookings were made. Attribute-level details +(room type, seat class, etc.) are left to the soundness judge. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from vita.config import DEFAULT_LLM_EVALUATOR, models +from vita.data_model.message import SystemMessage, UserMessage +from vita.data_model.tasks import Task +from vita.domains.ota.completeness.schema import CompletenessConstraints +from vita.domains.ota.verifier.utils import _extract_json +from vita.prompts import get_prompts +from vita.utils.llm_utils import generate + +logger = logging.getLogger(__name__) + + +def _format_available_cities(environment: dict) -> str: + """Extract unique flight/train city names from the environment.""" + cities: set[str] = set() + for flight in environment.get("flights", {}).values(): + f = flight if isinstance(flight, dict) else vars(flight) + for key in ("departure_city", "arrival_city"): + if v := f.get(key): + cities.add(v) + for train in environment.get("trains", {}).values(): + t = train if isinstance(train, dict) else vars(train) + for key in ("departure_city", "arrival_city"): + if v := t.get(key): + cities.add(v) + if not cities: + return "(none)" + return ", ".join(sorted(cities)) + + +def _format_existing_orders(environment: dict) -> str: + """Format pre-existing orders for the LLM prompt.""" + orders = environment.get("orders", {}) + if not orders: + return "(no existing orders)" + + lines = [] + for order_id, order in orders.items(): + # Handle both dict and object forms + if isinstance(order, dict): + otype = order.get("order_type", "unknown") + status = order.get("status", "unknown") + products = order.get("products", []) + else: + otype = getattr(order, "order_type", "unknown") + status = getattr(order, "status", "unknown") + products = getattr(order, "products", []) + + product_summary = [] + for p in products: + if isinstance(p, dict): + date = p.get("date", "N/A") + qty = p.get("quantity", 1) + else: + date = getattr(p, "date", "N/A") + qty = getattr(p, "quantity", 1) + product_summary.append(f"date={date}, 
qty={qty}") + + lines.append( + f"- order_id={order_id}, type={otype}, status={status}, " + f"products=[{'; '.join(product_summary)}]" + ) + return "\n".join(lines) + + +def extract_completeness_constraints( + task: Task, + llm_model: Optional[str] = None, + llm_args: Optional[dict] = None, + language: str = "english", +) -> CompletenessConstraints: + """ + Extract completeness constraints from a task's instructions via a single LLM call. + + Args: + task: The Task object (instructions + environment). + llm_model: LLM model name for the extraction call. + llm_args: Model-specific arguments (temperature, base_url, headers, etc.). + language: Prompt language ("english" or "chinese"). + + Returns: + CompletenessConstraints with all booking count requirements. + """ + if llm_model is None: + llm_model = DEFAULT_LLM_EVALUATOR + if llm_args is None: + llm_args = dict(models.get(llm_model, models.get("default", {}))) + + # Ensure enough output tokens for thinking + JSON response + llm_args["max_tokens"] = max(llm_args.get("max_tokens", 0), 16384) + + env = task.environment or {} + system_time = env.get("time", "unknown") + + user_profile = {} + if task.user_scenario and task.user_scenario.user_profile: + up = task.user_scenario.user_profile + user_profile = up if isinstance(up, dict) else {} + + user_historical_behaviors = env.get("user_historical_behaviors", {}) + existing_orders = _format_existing_orders(env) + available_cities = _format_available_cities(env) + + # Build the JSON schema for the LLM to follow + json_schema = json.dumps( + CompletenessConstraints.model_json_schema(), + indent=2, + ensure_ascii=False, + ) + + # Load the prompt template + prompts = get_prompts(language) + system_prompt = prompts.completeness_extraction_template.format( + system_time=system_time, + user_profile=json.dumps(user_profile, indent=2, ensure_ascii=False), + user_historical_behaviors=json.dumps( + user_historical_behaviors, indent=2, ensure_ascii=False + ), + 
existing_orders=existing_orders, + available_cities=available_cities, + json_schema=json_schema, + ) + + user_prompt = ( + f"Task ID: {task.id}\n\n" + f"Instructions:\n{task.instructions}" + ) + + messages = [ + SystemMessage(role="system", content=system_prompt), + UserMessage(role="user", content=user_prompt), + ] + + # Single LLM call + logger.info("Extracting completeness constraints for task %s with %s", task.id, llm_model) + response = generate(model=llm_model, messages=messages, enable_think=True, **llm_args) + + # Parse and validate the response + raw_content = response.content or "" + try: + parsed = _extract_json(raw_content) + if parsed is None: + raise ValueError("No JSON object found in LLM response") + if isinstance(parsed, list): + parsed = parsed[0] if parsed else {} + if isinstance(parsed, dict) and "task_id" not in parsed: + parsed["task_id"] = task.id + constraints = CompletenessConstraints.model_validate(parsed) + except Exception as e: + raise ValueError( + "LLM extraction failed for task %s: %s. Raw: %s" % (task.id, e, raw_content) + ) + + return constraints diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/schema.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/schema.py new file mode 100644 index 0000000..2a02edc --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/completeness/schema.py @@ -0,0 +1,152 @@ +""" +Pydantic models for OTA completeness constraints. + +These are intentionally minimal — only fields needed to verify that +the right *number* of bookings exist. Attribute-level correctness +(room type, seat type, ticket type, etc.) is handled by the LLM +soundness judge. 
+""" + +from __future__ import annotations + +from typing import List, Literal, Optional, Union + +from pydantic import BaseModel, Field + + +# ────────────────────────────────────────────── +# Booking completeness constraints +# ────────────────────────────────────────────── + +class HotelCompleteness(BaseModel): + """Hotel booking completeness requirement.""" + + id: str = Field(description="Unique constraint ID, e.g. 'hotel_1'") + check_in_date: Optional[str] = Field( + default=None, + description="Check-in date in YYYY-MM-DD format", + ) + num_rooms: int = Field(default=1, description="Number of rooms to book") + num_nights: Optional[int] = Field(default=None, description="Number of nights to stay") + description: str = "" + + +class FlightCompleteness(BaseModel): + """Flight booking completeness requirement.""" + + id: str = Field(description="Unique constraint ID, e.g. 'flight_1'") + departure_city: Optional[str] = None + arrival_city: Optional[str] = None + date: Optional[str] = Field(default=None, description="Departure date YYYY-MM-DD") + quantity: int = Field(default=1, description="Number of tickets") + description: str = "" + + +class TrainCompleteness(BaseModel): + """Train booking completeness requirement.""" + + id: str = Field(description="Unique constraint ID, e.g. 'train_1'") + departure_city: Optional[str] = None + arrival_city: Optional[str] = None + date: Optional[str] = Field(default=None, description="Departure date YYYY-MM-DD") + quantity: int = Field(default=1, description="Number of tickets") + description: str = "" + + +class AttractionCompleteness(BaseModel): + """Attraction ticket completeness requirement.""" + + id: str = Field(description="Unique constraint ID, e.g. 
'attraction_1'") + date: Optional[str] = Field(default=None, description="Visit date YYYY-MM-DD") + quantity: int = Field(default=1, description="Number of tickets") + description: str = "" + + +# ────────────────────────────────────────────── +# Cancel / modify completeness constraints +# ────────────────────────────────────────────── + +class CancelCompleteness(BaseModel): + """Requirement that a pre-existing order be cancelled.""" + + id: str + entity_type: Literal["hotel", "flight", "train", "attraction"] + order_id: str = Field( + description="Order ID of the pre-existing order that should be cancelled", + ) + description: str = "" + + +class ModifyCompleteness(BaseModel): + """Requirement that a pre-existing order be modified.""" + + id: str + entity_type: Literal["flight", "train"] + order_id: str = Field( + description="Order ID of the pre-existing order that should be modified", + ) + description: str = "" + + +# ────────────────────────────────────────────── +# Type aliases +# ────────────────────────────────────────────── + +AnyBookingCompleteness = Union[ + HotelCompleteness, + FlightCompleteness, + TrainCompleteness, + AttractionCompleteness, +] + +AnyCompleteness = Union[ + HotelCompleteness, + FlightCompleteness, + TrainCompleteness, + AttractionCompleteness, + CancelCompleteness, + ModifyCompleteness, +] + + +# ────────────────────────────────────────────── +# Conditional (excluded from completeness checking) +# ────────────────────────────────────────────── + +# Bookings that depend on ANY runtime condition (weather, price, +# availability, distance, etc.) are placed here and excluded from +# completeness checking. The soundness judge handles whether the +# agent took the correct branch. + + +# ────────────────────────────────────────────── +# Top-level extraction result +# ────────────────────────────────────────────── + +class CompletenessConstraints(BaseModel): + """Completeness constraints extracted from task instructions. 
+ + Only captures *what must exist* at the end of the conversation + (counts, cities, routes, dates). Attribute correctness (room type, + seat type, etc.) is left to the soundness judge. + + Bookings that depend on any runtime condition should go in + ``conditional`` so they are excluded from completeness checking. + """ + + task_id: str + + hotel: List[HotelCompleteness] = Field(default_factory=list) + flight: List[FlightCompleteness] = Field(default_factory=list) + train: List[TrainCompleteness] = Field(default_factory=list) + attraction: List[AttractionCompleteness] = Field(default_factory=list) + + cancel: List[CancelCompleteness] = Field(default_factory=list) + modify: List[ModifyCompleteness] = Field(default_factory=list) + + conditional: List[AnyBookingCompleteness] = Field( + default_factory=list, + description="Bookings that depend on any runtime condition " + "(weather, price, availability, distance, etc.). " + "These are EXCLUDED from completeness checking.", + ) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/__init__.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/__init__.py new file mode 100644 index 0000000..d625194 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/__init__.py @@ -0,0 +1,218 @@ +""" +Soundness Judge Harness — NL-based constraint checking during simulation. + +Three-phase architecture: + 1. Constraint Extraction (offline): Extract NL constraints from task instructions + 2. Memory Store (runtime): SLM distills read tool call results into relevant facts + 3. 
Judgment (runtime): On each write tool call, evaluate constraints against memory + +Usage: + from vita.domains.ota.soundness_judge_harness import create_harness + harness = create_harness(task, llm_model="claude-sonnet-4.6") + + # During simulation — called by orchestrator: + # On read tool calls: + harness.observe_tool_response(tool_name, tool_args, tool_response) + # On write tool calls: + feedback = harness.check_soundness(tool_name, tool_args) +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Optional + +from vita.config import models as model_configs +from vita.data_model.tasks import Task +from vita.domains.ota.soundness_judge_harness.judge import ConstraintJudge +from vita.domains.ota.soundness_judge_harness.memory_store import MemoryWriter +from vita.domains.ota.soundness_judge_harness.schema import ( + ConstraintVerdict, + ExtractedConstraintSet, + JudgmentResult, +) + +logger = logging.getLogger(__name__) + + +class SoundnessJudgeHarness: + """ + Orchestrates the three-phase NL soundness checking pipeline. + + Exposes the same check_soundness() interface expected by the Orchestrator, + plus observe_tool_response() for memory updates on read calls. + """ + + WRITE_PREFIXES = ("create_", "cancel_", "modify_") + + def __init__( + self, + constraints: ExtractedConstraintSet, + memory_writer: MemoryWriter, + judge: ConstraintJudge, + ): + self.constraints = constraints + self.memory_writer = memory_writer + self.judge = judge + self.soundness_call_log: list[dict] = [] + + def observe_tool_response( + self, + tool_call_name: str, + tool_call_args: dict, + tool_response: str, + ) -> None: + """ + Called after every tool call execution. + Updates memory for read calls; no-op for write calls. 
+ """ + self.memory_writer.process_tool_response( + tool_call_name, tool_call_args, tool_response + ) + + def check_soundness( + self, + tool_call_name: str, + tool_call_args: dict, + trajectory: Optional[list] = None, # unused, kept for interface compat + ) -> Optional[str]: + """ + Evaluate a write tool call against constraints using accumulated memory. + + Returns blocking feedback string if violations detected, None otherwise. + """ + if not any(tool_call_name.startswith(p) for p in self.WRITE_PREFIXES): + return None + if tool_call_args.get("override") is True: + return None + + result = self.judge.judge( + tool_call_name, + tool_call_args, + self.memory_writer.get_memory(), + ) + + # Log every judgment, including current memory snapshot + self.soundness_call_log.append({ + "tool_call_name": result.tool_call_name, + "tool_call_args": result.tool_call_args, + "memory": self.memory_writer.get_memory().model_dump(), + "judgments": [ + {"constraint_id": j.constraint_id, "verdict": j.verdict.value, "reasoning": j.reasoning} + for j in result.judgments + ], + "has_violation": result.has_violation, + }) + + if not result.has_violation: + return None + + # Format blocking feedback + violations = result.violated_constraints + lines = [ + "[SOUNDNESS CHECK] Order blocked — constraint violations detected:", + ] + for v in violations: + lines.append(f" - [{v.constraint_id}] {v.reasoning}") + lines.append( + "Please review your tool call against the user instructions. " + "Fix the issue or pass override=true if you believe this is a mistake." + ) + return "\n".join(lines) + + +def create_harness( + task: Task, + llm_model: Optional[str] = None, + memory_model: Optional[str] = None, + language: str = "english", + constraints_file: Optional[str] = None, +) -> SoundnessJudgeHarness: + """ + Build a SoundnessJudgeHarness for a task. + + Args: + task: The task to create the harness for. + llm_model: Model for judgment. 
+ memory_model: Model for memory writing (can be smaller/cheaper). + Defaults to llm_model if not specified. + language: Prompt language. + constraints_file: Path to pre-extracted constraints JSON. + Required — no live extraction fallback. + """ + if llm_model is None: + from vita.config import DEFAULT_LLM_EVALUATOR + llm_model = DEFAULT_LLM_EVALUATOR + + if memory_model is None: + memory_model = llm_model + + # Phase 1: Load pre-extracted constraints (no fallback) + constraints = _load_constraints(task, constraints_file) + + if constraints is None: + raise ValueError( + f"Task {task.id}: No constraints found. " + f"Please provide a valid constraints_file or run the pre-extraction script." + ) + + # Phase 2: Create memory writer (seeded with user profile) + # Memory writer: no model config passed — keeps thinking disabled, cheap calls + user_profile = "" + if task.user_scenario and task.user_scenario.user_profile: + profile = task.user_scenario.user_profile + if isinstance(profile, dict): + user_profile = json.dumps(profile, ensure_ascii=False, indent=2) + else: + user_profile = str(profile) + + # memory_llm_args = dict(model_configs.get(memory_model, model_configs.get("default", {}))) + memory_writer = MemoryWriter( + constraints=constraints, + llm_model=memory_model, + # llm_args=memory_llm_args, + user_profile=user_profile or None, + language=language, + ) + + # Phase 3: Create judge (pass full model config for thinking etc.) 
+ judge_llm_args = dict(model_configs.get(llm_model, model_configs.get("default", {}))) + judge = ConstraintJudge( + constraints=constraints, + llm_model=llm_model, + llm_args=judge_llm_args, + language=language, + ) + + logger.info("Task %s: soundness judge harness created (%s / memory: %s)", task.id, llm_model, memory_model) + return SoundnessJudgeHarness(constraints, memory_writer, judge) + + +def _load_constraints( + task: Task, + constraints_file: Optional[str], +) -> ExtractedConstraintSet: + """Load pre-extracted constraints from file. No live extraction fallback.""" + if constraints_file is None: + return None + + path = Path(constraints_file) + if not path.exists(): + raise FileNotFoundError( + f"Task {task.id}: constraints file not found at {constraints_file}. " + f"Run the pre-extraction script first." + ) + + with open(path) as f: + data = json.load(f) + + task_data = data.get(task.id) + if task_data is None: + raise FileNotFoundError( + f"Task {task.id} not found in constraints file {constraints_file}. " + f"Run the pre-extraction script for this task." + ) + task_data["task_id"] = task.id + return ExtractedConstraintSet.model_validate(task_data) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/constraint_extractor.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/constraint_extractor.py new file mode 100644 index 0000000..1ded5ef --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/constraint_extractor.py @@ -0,0 +1,81 @@ +""" +Constraint extractor for the NL soundness judge harness. + +Preserves natural language descriptions and categorizes constraints by type +(date, price, city, etc.) for LLM-based judgment. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Optional + +from vita.config import DEFAULT_LLM_EVALUATOR, models +from vita.data_model.message import SystemMessage, UserMessage +from vita.data_model.tasks import Task +from vita.domains.ota.soundness_judge_harness.schema import ExtractedConstraintSet +from vita.domains.ota.verifier.utils import _extract_json +from vita.prompts import get_prompts +from vita.utils.llm_utils import generate + +logger = logging.getLogger(__name__) + + +def extract_constraints( + task: Task, + user_message: str, + llm_model: Optional[str] = None, + llm_args: Optional[dict] = None, + language: str = "english", +) -> ExtractedConstraintSet: + """ + Extract NL constraints from a user's travel request. + + Args: + task: The Task object (used for metadata: id, environment, user_profile). + user_message: The user simulation message to extract constraints from. + + Returns an ExtractedConstraintSet with categorized, natural-language constraints. 
+ """ + if llm_model is None: + llm_model = DEFAULT_LLM_EVALUATOR + + user_scenario = task.user_scenario + user_profile = "" + if user_scenario and user_scenario.user_profile: + profile = user_scenario.user_profile + if isinstance(profile, dict): + user_profile = json.dumps(profile, ensure_ascii=False, indent=2) + else: + user_profile = str(profile) + + system_time = "" + if task.environment and "system_time" in task.environment: + system_time = task.environment["system_time"] + + system_content = get_prompts(language).harness_constraint_extraction_template.format( + system_time=system_time or "(not specified)", + user_profile=user_profile or "(none)", + ) + + # Use user_message from user sim file + messages = [ + SystemMessage(role="system", content=system_content), + UserMessage(role="user", content=user_message), + ] + + kwargs = dict(llm_args or {}) + kwargs.setdefault("temperature", 0) + + response = generate(llm_model, messages, enable_think=True, **kwargs) + raw = response.content if hasattr(response, "content") else str(response) + + parsed = _extract_json(raw) + if parsed is None: + raise ValueError(f"Failed to parse constraint extraction output for task {task.id}") + + # Ensure task_id is set + parsed["task_id"] = task.id + + return ExtractedConstraintSet.model_validate(parsed) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/judge.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/judge.py new file mode 100644 index 0000000..8d88eb9 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_harness/judge.py @@ -0,0 +1,220 @@ +""" +Judgment module for the NL soundness judge harness. + +Given a write tool call, evaluates each relevant constraint individually +against the accumulated memory. One LLM call per constraint, results +ANDed in code (any violation → overall violation). 
+""" + +from __future__ import annotations + +import json +import logging +import re +from typing import Optional + +from vita.data_model.message import SystemMessage, UserMessage +from vita.domains.ota.soundness_judge_harness.schema import ( + Constraint, + ConstraintJudgment, + ConstraintVerdict, + EntityType, + ExtractedConstraintSet, + JudgmentResult, + TaskMemory, +) +from vita.prompts import get_prompts +from vita.utils.llm_utils import generate + +logger = logging.getLogger(__name__) + +# Map tool call name prefixes to entity types +_TOOL_TO_ENTITY = { + "hotel": EntityType.HOTEL, + "flight": EntityType.FLIGHT, + "train": EntityType.TRAIN, + "attraction": EntityType.ATTRACTION, +} + + +def _entity_type_from_tool(tool_call_name: str) -> Optional[EntityType]: + """Infer entity type from tool call name.""" + for key, etype in _TOOL_TO_ENTITY.items(): + if key in tool_call_name: + return etype + return None + + +def _parse_single_judgment(text: str) -> Optional[dict]: + """Parse a single JSON verdict from LLM response.""" + text = text.split("")[-1].strip() + + # Try fenced blocks + fence_matches = re.findall(r"```(?:json)?\s*(.*?)```", text, re.DOTALL) + for candidate in reversed(fence_matches): + try: + result = json.loads(candidate) + if isinstance(result, dict): + return result + except json.JSONDecodeError: + continue + + # Try raw JSON object + brace_start = text.find("{") + if brace_start == -1: + return None + + depth = 0 + in_string = False + escape = False + for i in range(brace_start, len(text)): + ch = text[i] + if escape: + escape = False + continue + if ch == "\\" and in_string: + escape = True + continue + if ch == '"' and not escape: + in_string = not in_string + continue + if in_string: + continue + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(text[brace_start : i + 1]) + except json.JSONDecodeError: + return None + return None + + +class ConstraintJudge: + """ + Evaluates write tool calls 
def _get_relevant_constraints(self, tool_call_name: str) -> list[Constraint]:
    """Select the constraints that apply to ``tool_call_name``.

    A constraint is kept when it is cross-cutting (its ``entity_type`` is
    ``None``) or when its ``entity_type`` equals the entity type inferred
    from the tool call name. The order of ``self.constraints.constraints``
    is preserved.
    """
    target = _entity_type_from_tool(tool_call_name)
    return [
        candidate
        for candidate in self.constraints.constraints
        if candidate.entity_type is None or candidate.entity_type == target
    ]
def _parse_response(self, constraint_id: str, raw: str) -> ConstraintJudgment:
    """Turn a single-constraint LLM response into a ``ConstraintJudgment``.

    Args:
        constraint_id: ID of the constraint the response is about.
        raw: Raw text returned by the judge model.

    Returns:
        A judgment whose verdict is ``UNDETERMINED`` when the response cannot
        be parsed, or when its ``verdict`` field is missing, not a string, or
        not a known ``ConstraintVerdict`` value.
    """
    parsed = _parse_single_judgment(raw)
    if parsed is None:
        logger.warning("Failed to parse judge response for %s", constraint_id)
        return ConstraintJudgment(
            constraint_id=constraint_id,
            verdict=ConstraintVerdict.UNDETERMINED,
            reasoning="Parse failure — could not interpret judge response",
        )

    # The model may emit ``"verdict": null`` or a non-string value; calling
    # ``.lower()`` on that raised AttributeError instead of degrading to
    # UNDETERMINED.
    raw_verdict = parsed.get("verdict")
    if isinstance(raw_verdict, str):
        try:
            verdict = ConstraintVerdict(raw_verdict.strip().lower())
        except ValueError:
            # Unknown label: treat as "cannot tell" rather than guessing.
            verdict = ConstraintVerdict.UNDETERMINED
    else:
        verdict = ConstraintVerdict.UNDETERMINED

    # ``reasoning`` is declared as ``str``; normalise null / non-string values
    # so a valid verdict is not lost to a field-type error.
    raw_reasoning = parsed.get("reasoning")
    if raw_reasoning is None:
        reasoning = ""
    elif isinstance(raw_reasoning, str):
        reasoning = raw_reasoning
    else:
        reasoning = str(raw_reasoning)

    return ConstraintJudgment(
        constraint_id=constraint_id,
        verdict=verdict,
        reasoning=reasoning,
    )
def process_tool_response(
    self,
    tool_call_name: str,
    tool_call_args: dict,
    tool_response: str,
) -> None:
    """Distill a read tool call's response into the task memory.

    Write calls (as decided by ``is_read_call``) are ignored. For read calls
    the response is truncated to ``max_response_len`` characters, sent to the
    configured model together with the pre-rendered constraints, and the
    stripped reply is appended to memory unless it is empty or exactly
    ``NOTHING_RELEVANT``. Any exception raised while generating or recording
    the summary is logged as a warning and otherwise ignored.
    """
    if not self.is_read_call(tool_call_name):
        return

    # Keep the prompt bounded: cut over-long responses and mark the cut.
    limit = self.max_response_len
    if len(tool_response) > limit:
        shown_response = tool_response[:limit] + "... (truncated)"
    else:
        shown_response = tool_response

    template = get_prompts(self.language).harness_memory_writer_template
    system_content = template.format(constraints_text=self._constraints_text)

    call_repr = f"{tool_call_name}({json.dumps(tool_call_args, ensure_ascii=False)})"
    user_content = f"## Tool Call\n{call_repr}\n\n## Response\n{shown_response}"

    messages = [
        SystemMessage(role="system", content=system_content),
        UserMessage(role="user", content=user_content),
    ]

    # Caller-supplied args win; temperature only defaults to 0.
    call_kwargs = {"temperature": 0, **self.llm_args}

    try:
        response = generate(self.llm_model, messages, enable_think=False, **call_kwargs)
        if hasattr(response, "content"):
            summary = response.content
        else:
            summary = str(response)
        summary = summary.strip()

        if not summary or summary == "NOTHING_RELEVANT":
            return
        self.memory.append(source_tool=tool_call_name, summary=summary)
        logger.debug("Memory updated from %s: %s", tool_call_name, summary[:80])
    except Exception as e:
        logger.warning("Memory writer failed for %s: %s", tool_call_name, e)
+""" + +from __future__ import annotations + +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel, Field + + +class ConstraintCategory(str, Enum): + DATE = "date" + PRICE = "price" + CITY = "city" + QUANTITY = "quantity" + DURATION = "duration" + ENTITY_ATTRIBUTE = "entity_attribute" # room type, seat class, ticket type, etc. + OTHER = "other" + + +class EntityType(str, Enum): + HOTEL = "hotel" + FLIGHT = "flight" + TRAIN = "train" + ATTRACTION = "attraction" + + +class Constraint(BaseModel): + """A single constraint extracted from user instructions.""" + + id: str = Field(description="Unique ID, e.g. 'c1', 'c2'") + category: ConstraintCategory + entity_type: Optional[EntityType] = Field( + default=None, + description="Which booking entity this applies to, if specific", + ) + description: str = Field( + description="Natural language description of the constraint, " + "preserving the user's intent as closely as possible", + ) + + +class ExtractedConstraintSet(BaseModel): + """Full set of constraints extracted from a task's instructions.""" + + task_id: str + constraints: List[Constraint] + + +# ────────────────────────────────────────────── +# Memory store schema +# ────────────────────────────────────────────── + + +class MemoryEntry(BaseModel): + """A single fact recorded from a read tool call.""" + + source_tool: str = Field(description="Tool call that produced this info, e.g. 
class TaskMemory(BaseModel):
    """Ordered list of facts collected while a single task is simulated."""

    entries: List[MemoryEntry] = Field(default_factory=list)

    def append(self, source_tool: str, summary: str):
        """Record ``summary`` as a new entry attributed to ``source_tool``."""
        entry = MemoryEntry(source_tool=source_tool, summary=summary)
        self.entries.append(entry)

    def render(self) -> str:
        """Return the entries as numbered lines, or a placeholder when empty."""
        if not self.entries:
            return "(no observations recorded yet)"
        return "\n".join(
            f"[{position}] ({entry.source_tool}) {entry.summary}"
            for position, entry in enumerate(self.entries, 1)
        )
+ +Uses an LLM to evaluate whether a write tool call (create/pay/cancel/modify) is +consistent with the user's instructions, given the tool call history so far. + +Usage: + from vita.domains.ota.soundness_judge_llm import SoundnessJudge, SoundnessJudgeConfig +""" + +from vita.domains.ota.soundness_judge_llm.judge import SoundnessJudge, SoundnessJudgeConfig + +__all__ = ["SoundnessJudge", "SoundnessJudgeConfig"] diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_llm/judge.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_llm/judge.py new file mode 100644 index 0000000..db8b229 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/soundness_judge_llm/judge.py @@ -0,0 +1,127 @@ +""" +LLM-based soundness judge for OTA verifier. + +Uses an LLM to evaluate whether a write tool call (create/pay/cancel/modify) is +consistent with the user's instructions, given the tool call history so far. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from typing import Optional + +from vita.data_model.message import SystemMessage, UserMessage +from vita.prompts import get_prompts +from vita.utils.llm_utils import generate +from vita.domains.ota.verifier.utils import _extract_tool_history, _extract_json + +logger = logging.getLogger(__name__) + + +@dataclass +class SoundnessJudgeConfig: + """Configuration for the LLM-based soundness judge.""" + + llm_model: str + """Model name to use for judging (e.g. 
'claude-sonnet-4.6').""" + + llm_args: dict = field(default_factory=dict) + """Extra kwargs forwarded to the generate() call (temperature, max_tokens, etc.).""" + + language: str = "english" + """Language for the prompt template ('english' or 'chinese').""" + + max_response_len: int = 2000 + """Max characters per tool response shown to the judge (longer responses are truncated).""" + + +def _format_tool_trace(tool_trace: list[dict], max_response_len: int = 2000) -> str: + """Format tool trace into a readable string for the LLM.""" + if not tool_trace: + return "(no prior tool calls)" + + lines = [] + for i, tc in enumerate(tool_trace, 1): + lines.append(f"[{i}] {tc['name']}({json.dumps(tc['arguments'], ensure_ascii=False)})") + resp = tc["response"] + if len(resp) > max_response_len: + resp = resp[:max_response_len] + "... (truncated)" + lines.append(f" → {resp}") + return "\n".join(lines) + + +class SoundnessJudge: + """ + LLM-based judge that evaluates whether a write tool call + is consistent with the user's original instructions. + """ + + def __init__( + self, + user_instruction: str, + config: SoundnessJudgeConfig, + ): + self.user_instruction = user_instruction + self.config = config + + prompts = get_prompts(config.language) + self.system_prompt = prompts.soundness_judge_template + + def judge( + self, + tool_call_name: str, + tool_call_args: dict, + trajectory: list, + ) -> tuple[str, Optional[str]]: + """ + Judge whether a write tool call should be allowed or blocked. + + Returns a (verdict, reason) tuple where verdict is 'ALLOW' or 'BLOCK', + and reason is None when not applicable. 
def _parse_verdict(raw: str) -> tuple[str, Optional[str]]:
    """Parse the LLM's JSON verdict, failing open to ``ALLOW``.

    Args:
        raw: Raw text returned by the judge model.

    Returns:
        ``(verdict, reason)`` where ``verdict`` is ``"ALLOW"`` or ``"BLOCK"``
        and ``reason`` is the model's explanation as a string, or ``None``
        when it is absent or empty. Anything that cannot be interpreted
        (unparseable output, missing / non-string / unknown verdict) yields
        ``"ALLOW"``.
    """
    parsed = _extract_json(raw)
    if not isinstance(parsed, dict):
        logger.warning("Failed to parse LLM judge response, defaulting to ALLOW: %s", raw[:200])
        return "ALLOW", None

    # ``"verdict": null`` (or any non-string) used to raise AttributeError on
    # ``.upper()``. The caller invokes this outside its try block, so the
    # error escaped instead of defaulting to ALLOW.
    raw_verdict = parsed.get("verdict")
    verdict = raw_verdict.strip().upper() if isinstance(raw_verdict, str) else "ALLOW"
    if verdict not in ("ALLOW", "BLOCK"):
        verdict = "ALLOW"

    # Honour the declared Optional[str]: empty values become None, anything
    # else is rendered as text.
    raw_reason = parsed.get("reason")
    reason = str(raw_reason) if raw_reason else None
    return verdict, reason
Added an optional ``override: bool = False`` parameter to every WRITE tool + (create/modify/cancel hotel, attraction, flight and train orders) so the + agent can bypass a soundness block when it is confident its action is correct. +""" from typing import List, Optional, Union @@ -335,7 +344,7 @@ def train_ticket_search(self, departure: str, destination: str, date: str) -> st return trains_repr @is_tool(tool_type=ToolType.WRITE) - def create_hotel_order(self, hotel_id: str, room_id: str, user_id: str) -> str: + def create_hotel_order(self, hotel_id: str, room_id: str, user_id: str, override: bool = False) -> str: assert hotel_id, "Hotel ID cannot be empty" assert room_id, "Room ID cannot be empty" assert user_id, "User ID cannot be empty" @@ -386,7 +395,7 @@ def create_hotel_order(self, hotel_id: str, room_id: str, user_id: str) -> str: return f"Failed to create order: {response}" @is_tool(tool_type=ToolType.WRITE) - def create_attraction_order(self, attraction_id: str, ticket_id: str, user_id: str, date: str, quantity: int) -> str: + def create_attraction_order(self, attraction_id: str, ticket_id: str, user_id: str, date: str, quantity: int, override: bool = False) -> str: assert attraction_id, "Attraction ID cannot be empty" assert ticket_id, "Ticket ID cannot be empty" assert user_id, "User ID cannot be empty" @@ -443,7 +452,7 @@ def create_attraction_order(self, attraction_id: str, ticket_id: str, user_id: s return f"Failed to create order: {response}" @is_tool(tool_type=ToolType.WRITE) - def create_flight_order(self, flight_id: str, seat_id: str, user_id: str, date: str, quantity: int) -> str: + def create_flight_order(self, flight_id: str, seat_id: str, user_id: str, date: str, quantity: int, override: bool = False) -> str: assert flight_id, "Flight ID cannot be empty" assert seat_id, "Seat ID cannot be empty" assert user_id, "User ID cannot be empty" @@ -500,7 +509,7 @@ def create_flight_order(self, flight_id: str, seat_id: str, user_id: str, date: return 
f"Failed to create order: {response}" @is_tool(tool_type=ToolType.WRITE) - def create_train_order(self, train_id: str, seat_id: str, user_id: str, date: str, quantity: int) -> str: + def create_train_order(self, train_id: str, seat_id: str, user_id: str, date: str, quantity: int, override: bool = False) -> str: assert train_id, "Train ID cannot be empty" assert seat_id, "Seat ID cannot be empty" assert user_id, "User ID cannot be empty" @@ -869,7 +878,7 @@ def get_train_order_detail(self, order_id: str) -> str: return repr(order) @is_tool(tool_type=ToolType.WRITE) - def modify_train_order(self, order_id: str, user_id: str, new_date: str) -> str: + def modify_train_order(self, order_id: str, user_id: str, new_date: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert new_date, "New departure date cannot be empty" @@ -951,7 +960,7 @@ def modify_train_order(self, order_id: str, user_id: str, new_date: str) -> str: return f"Modification failed: {response}" @is_tool(tool_type=ToolType.WRITE) - def modify_flight_order(self, order_id: str, user_id: str, new_date: str) -> str: + def modify_flight_order(self, order_id: str, user_id: str, new_date: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert new_date, "New departure date cannot be empty" @@ -1033,7 +1042,7 @@ def modify_flight_order(self, order_id: str, user_id: str, new_date: str) -> str return f"Modification failed: {response}" @is_tool(tool_type=ToolType.WRITE) - def cancel_hotel_order(self, order_id: str, user_id: str) -> str: + def cancel_hotel_order(self, order_id: str, user_id: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert self._check_user(user_id), "User ID does not match" @@ -1065,7 +1074,7 @@ def cancel_hotel_order(self, order_id: str, user_id: str) -> str: return 
f"Cancellation failed: {response}" @is_tool(tool_type=ToolType.WRITE) - def cancel_attraction_order(self, order_id: str, user_id: str) -> str: + def cancel_attraction_order(self, order_id: str, user_id: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert self._check_user(user_id), "User ID does not match" @@ -1097,7 +1106,7 @@ def cancel_attraction_order(self, order_id: str, user_id: str) -> str: return f"Cancellation failed: {response}" @is_tool(tool_type=ToolType.WRITE) - def cancel_flight_order(self, order_id: str, user_id: str) -> str: + def cancel_flight_order(self, order_id: str, user_id: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert self._check_user(user_id), "User ID does not match" @@ -1129,7 +1138,7 @@ def cancel_flight_order(self, order_id: str, user_id: str) -> str: return f"Cancellation failed: {response}" @is_tool(tool_type=ToolType.WRITE) - def cancel_train_order(self, order_id: str, user_id: str) -> str: + def cancel_train_order(self, order_id: str, user_id: str, override: bool = False) -> str: assert order_id, "Order ID cannot be empty" assert user_id, "User ID cannot be empty" assert self._check_user(user_id), "User ID does not match" diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools_schema.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools_schema.py index aaf8b3f..f54e0f3 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools_schema.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/tools_schema.py @@ -2,6 +2,13 @@ This file contains the descriptions and mappings for all tools decorated with @is_tool in the OTATools class. + +VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/domains/ota/tools_schema.py. 
+Everything is verbatim from the original except for the following change: + +1. Added the ``override`` argument description (Chinese and English) to the + schema of every WRITE tool that gained the ``override`` parameter in tools.py. """ from typing import Dict, Any @@ -109,7 +116,8 @@ "args": { "hotel_id": "酒店ID", "room_id": "房间ID", - "user_id": "用户ID" + "user_id": "用户ID", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "创建订单操作的反馈输出", "tool_type": "WRITE" @@ -124,7 +132,8 @@ "ticket_id": "门票ID", "user_id": "用户ID", "date": "参观日期,格式为 %Y-%m-%d", - "quantity": "数量" + "quantity": "数量", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "创建订单操作的反馈输出", "tool_type": "WRITE" @@ -139,7 +148,8 @@ "seat_id": "座位ID", "user_id": "用户ID", "date": "出发日期", - "quantity": "数量" + "quantity": "数量", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "创建订单操作的反馈输出", "tool_type": "WRITE" @@ -154,7 +164,8 @@ "seat_id": "座位ID", "user_id": "用户ID", "date": "出发日期", - "quantity": "数量" + "quantity": "数量", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "创建订单操作的反馈输出", "tool_type": "WRITE" @@ -307,7 +318,8 @@ "args": { "order_id": "订单ID", "user_id": "用户ID", - "new_date": "新的出发日期,格式为 %Y-%m-%d" + "new_date": "新的出发日期,格式为 %Y-%m-%d", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "(修改后的订单内容, 差价,正为需补差价,负为退差价)", "tool_type": "WRITE" @@ -320,7 +332,8 @@ "args": { "order_id": "订单ID", "user_id": "用户ID", - "new_date": "新的出发日期,格式为 %Y-%m-%d" + "new_date": "新的出发日期,格式为 %Y-%m-%d", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "(修改后的订单内容, 差价,正为需补差价,负为退差价)", "tool_type": "WRITE" @@ -332,7 +345,8 @@ "postconditions": "取消订单并更新订单状态,若需退差价,告知用户即可", "args": { "order_id": "订单ID", - "user_id": "用户ID" + "user_id": "用户ID", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "取消订单的退款金额", "tool_type": "WRITE" @@ -344,7 +358,8 @@ "postconditions": "如果退差价,告知用户即可", "args": { "order_id": "订单ID", - "user_id": 
"用户ID" + "user_id": "用户ID", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "取消订单的退款金额", "tool_type": "WRITE" @@ -356,7 +371,8 @@ "postconditions": "如果退差价,告知用户即可", "args": { "order_id": "订单ID", - "user_id": "用户ID" + "user_id": "用户ID", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "取消订单的退款金额", "tool_type": "WRITE" @@ -368,7 +384,8 @@ "postconditions": "如果退差价,告知用户即可", "args": { "order_id": "订单ID", - "user_id": "用户ID" + "user_id": "用户ID", + "override": "默认false,保持false除非你收到验证反馈且确信你的操作是正确的" }, "returns": "取消订单的退款金额", "tool_type": "WRITE" @@ -477,7 +494,8 @@ "args": { "hotel_id": "Hotel ID", "product_id": "Room ID", - "user_id": "User ID" + "user_id": "User ID", + "override": "Default false. Only set to true if you received soundness feedback and are confident your action is correct" }, "returns": "Feedback output of creating order operation", "tool_type": "WRITE" @@ -492,7 +510,8 @@ "ticket_id": "Ticket ID", "user_id": "User ID", "date": "Visit date, format: %Y-%m-%d", - "quantity": "Quantity" + "quantity": "Quantity", + "override": "Default false. Only set to true if you received soundness feedback and are confident your action is correct" }, "returns": "Feedback output of creating order operation", "tool_type": "WRITE" @@ -507,7 +526,8 @@ "seat_id": "Seat ID", "user_id": "User ID", "date": "Departure date, format: %Y-%m-%d", - "quantity": "Quantity" + "quantity": "Quantity", + "override": "Default false. Only set to true if you received soundness feedback and are confident your action is correct" }, "returns": "Feedback output of creating order operation", "tool_type": "WRITE" @@ -522,7 +542,8 @@ "seat_id": "Seat ID", "user_id": "User ID", "date": "Departure date, format: %Y-%m-%d", - "quantity": "Quantity" + "quantity": "Quantity", + "override": "Default false. 
Only set to true if you received soundness feedback and are confident your action is correct" }, "returns": "Feedback output of creating order operation", "tool_type": "WRITE" @@ -675,7 +696,8 @@ "args": { "order_id": "Order ID", "user_id": "User ID", - "new_date": "New departure date, format: %Y-%m-%d" + "new_date": "New departure date, format: %Y-%m-%d", + "override": "Default false. Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "(Modified order content, price difference, positive means need to compensate, negative means refund)", "tool_type": "WRITE" @@ -688,7 +710,8 @@ "args": { "order_id": "Order ID", "user_id": "User ID", - "new_date": "New departure date, format: %Y-%m-%d" + "new_date": "New departure date, format: %Y-%m-%d", + "override": "Default false. Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "(Modified order content, price difference, positive means need to compensate, negative means refund)", "tool_type": "WRITE" @@ -700,7 +723,8 @@ "postconditions": "Cancel order and update order status, if refund is needed, inform user", "args": { "order_id": "Order ID", - "user_id": "User ID" + "user_id": "User ID", + "override": "Default false. Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "Order cancellation refund amount", "tool_type": "WRITE" @@ -712,7 +736,8 @@ "postconditions": "If refund is needed, inform user", "args": { "order_id": "Order ID", - "user_id": "User ID" + "user_id": "User ID", + "override": "Default false. Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "Order cancellation refund amount", "tool_type": "WRITE" @@ -724,7 +749,8 @@ "postconditions": "If refund is needed, inform user", "args": { "order_id": "Order ID", - "user_id": "User ID" + "user_id": "User ID", + "override": "Default false. 
Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "Order cancellation refund amount", "tool_type": "WRITE" @@ -736,7 +762,8 @@ "postconditions": "If refund is needed, inform user", "args": { "order_id": "Order ID", - "user_id": "User ID" + "user_id": "User ID", + "override": "Default false. Only set to true if you received validation feedback and are confident your action is correct" }, "returns": "Order cancellation refund amount", "tool_type": "WRITE" diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/__init__.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/__init__.py new file mode 100644 index 0000000..2ee65c9 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/__init__.py @@ -0,0 +1,205 @@ +""" +OTA Verifier — runs during simulation. + +Two independent subsystems: + - Soundness (LLM judge): validates each write tool call BEFORE execution. + - Completeness (rule-based): checks final order state when the agent stops. + +Usage: + from vita.domains.ota.verifier import create_verifier + verifier = create_verifier(task, language="english") +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Optional + +from vita.data_model.tasks import Task +from vita.config import models as model_configs +from vita.domains.ota.completeness import check_completeness, CompletenessConstraints +from vita.domains.ota.soundness_judge_llm import SoundnessJudge, SoundnessJudgeConfig +from vita.domains.ota.soundness_judge_harness import SoundnessJudgeHarness + +logger = logging.getLogger(__name__) + + +class OTAVerifier: + """ + Live verifier with two subsystems: + - Soundness: LLM judge blocks bad write tool calls. + - Completeness: rule-based check that all required bookings exist at stop time. 
def check_soundness(self, tool_call_name: str, tool_call_args: dict, trajectory: Optional[list] = None) -> Optional[str]:
    """Vet a write tool call before it is executed.

    Calls whose name does not start with one of ``WRITE_PREFIXES`` and calls
    carrying ``override=True`` are let through. Everything else is routed by
    ``soundness_mode`` ("off", "llm" or "harness"); an unrecognised mode is
    logged and the call is allowed.

    Returns:
        Whatever the selected handler returns, or ``None`` when the call is
        not checked.
    """
    is_write = tool_call_name.startswith(self.WRITE_PREFIXES)
    # Only the literal boolean True bypasses the check; truthy strings do not.
    if not is_write or tool_call_args.get("override") is True:
        return None

    mode = self.soundness_mode
    if mode == "off":
        return None
    if mode == "llm":
        return self._check_soundness_llm(tool_call_name, tool_call_args, trajectory)
    if mode == "harness":
        return self._check_soundness_harness(tool_call_name, tool_call_args, trajectory)

    logger.warning("Unknown soundness_mode '%s', allowing call", self.soundness_mode)
    return None
by LLM judge:\n" + f" - {reason}\n" + f"Please review your tool call against the user instructions and the conversation history.\n" + f"Fix the issue or pass override=true if you believe this is a mistake." + ) + + def _check_soundness_harness(self, tool_call_name: str, tool_call_args: dict, trajectory: Optional[list]) -> Optional[str]: + """Delegate soundness check to the NL constraint harness.""" + result = self.soundness_harness.check_soundness(tool_call_name, tool_call_args, trajectory) + # Point our call log at the harness's log (same list object; nothing is copied or merged) + self.soundness_call_log = self.soundness_harness.soundness_call_log + return result + + def observe_tool_response(self, tool_call_name: str, tool_call_args: dict, tool_response: str) -> None: + """Forward tool responses to the harness memory store (no-op if not in harness mode).""" + if self.soundness_harness: + self.soundness_harness.observe_tool_response(tool_call_name, tool_call_args, tool_response) + + def check_on_stop(self, trajectory: list, remaining: int = 0, new_orders: list = (), old_orders: list = ()) -> Optional[str]: + """ + Called when the agent decides to stop. Runs a full completeness check. + Returns feedback string if bookings are incomplete, None if all good. + """ + if self.completeness_mode == "off": + return None + missing = check_completeness( + new_orders=new_orders, + old_orders=old_orders, + constraints=self.completeness_constraints, + environment=self.environment, + ) + if not missing: + return None + + lines = ["\n\n[COMPLETENESS CHECK]\nSome bookings appear to still be incomplete:"] + lines.extend(f" - {m}" for m in missing) + lines.append( + "Please review and complete the missing bookings before ending the conversation. " + f"This message will appear {remaining} more time(s) before your stop is accepted unconditionally." 
+ ) + return "\n".join(lines) + + +def create_verifier( + task: Task, + llm_model: Optional[str] = None, + language: str = "english", + constraints_file: Optional[str] = None, + soundness_mode: str = "off", + completeness_mode: str = "on", + solo_user_message: Optional[str] = None, +) -> OTAVerifier: + """ + Build an OTAVerifier for a task. + + Args: + soundness_mode: "llm" for LLM judge, "harness" for NL constraint harness, "off" to disable. + completeness_mode: "on" to check final order completeness at stop, "off" to disable. + """ + environment = task.environment or {} + completeness_constraints = _load_constraints(task.id, llm_model, constraints_file) + if completeness_constraints is None and completeness_mode != "off": + raise FileNotFoundError( + f"Task {task.id}: no pre-computed completeness constraints found. " + f"Run the pre-extraction script first." + ) + + soundness_judge = None + soundness_harness = None + if soundness_mode == "llm": + try: + judge_config = SoundnessJudgeConfig( + llm_model=llm_model, + llm_args=dict(model_configs.get(llm_model, model_configs.get("default", {}))), + language=language, + ) + soundness_judge = SoundnessJudge( + user_instruction=task.instructions, + config=judge_config, + ) + except Exception as e: + raise RuntimeError( + f"Task {task.id}: failed to create soundness judge ({llm_model}): {e}" + ) from e + logger.info("Task %s: soundness judge enabled (%s)", task.id, llm_model) + elif soundness_mode == "harness": + from vita.domains.ota.soundness_judge_harness import create_harness + soundness_harness = create_harness(task, llm_model=llm_model, language=language) + logger.info("Task %s: soundness harness enabled (%s)", task.id, llm_model) + + return OTAVerifier(completeness_constraints, environment, soundness_mode=soundness_mode, completeness_mode=completeness_mode, soundness_judge=soundness_judge, soundness_harness=soundness_harness) + + +def _load_constraints( + task_id: str, + llm_model: Optional[str], + 
constraints_file: Optional[str], +) -> Optional[CompletenessConstraints]: + """Load pre-computed completeness constraints from disk.""" + if constraints_file is None: + return None + + constraints_path = Path(constraints_file) + if not constraints_path.exists(): + return None + + try: + with open(constraints_path) as f: + all_constraints = json.load(f) + if task_id in all_constraints: + raw = all_constraints[task_id] + if "_error" not in raw: + c = CompletenessConstraints.model_validate(raw) + logger.info("Task %s: loaded completeness constraints", task_id) + return c + except Exception as e: + logger.warning("Failed to load completeness constraints: %s", e) + return None diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/utils.py b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/utils.py new file mode 100644 index 0000000..275b673 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/domains/ota/verifier/utils.py @@ -0,0 +1,88 @@ +"""Shared utility functions for the OTA verifier.""" + +import json +import re + + +def _extract_json(text: str) -> dict | list | None: + """Extract JSON object/array from text that may have thinking prefix. + + When the LLM emits multiple JSON blocks (e.g. it reconsiders and outputs a + corrected version), we take the *last* valid JSON object. + """ + # Strip <think>...</think> blocks + text = text.split("</think>")[-1].strip() + + # Try all ```json ... 
``` fences, take the last one that parses + fence_matches = re.findall(r"```(?:json)?\s*(.*?)```", text, re.DOTALL) + for candidate in reversed(fence_matches): + try: + return json.loads(candidate) + except json.JSONDecodeError: + continue + + # Fallback: find balanced top-level { } blocks, take the last valid one + candidates: list[str] = [] + depth = 0 + start = -1 + in_string = False + escape = False + for i, ch in enumerate(text): + if escape: + escape = False + continue + if ch == '\\' and in_string: + escape = True + continue + if ch == '"' and not escape: + in_string = not in_string + continue + if in_string: + continue + if ch == '{': + if depth == 0: + start = i + depth += 1 + elif ch == '}': + depth -= 1 + if depth == 0 and start != -1: + candidates.append(text[start:i + 1]) + start = -1 + + for candidate in reversed(candidates): + try: + return json.loads(candidate) + except json.JSONDecodeError: + continue + + return None + + +def _extract_tool_history(trajectory: list) -> list[dict]: + """ + Extract tool calls paired with their responses from a trajectory. 
+ """ + tool_calls_by_id: dict[str, dict] = {} + + for msg in trajectory: + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + tool_calls_by_id[tc.id] = { + "name": tc.name, + "arguments": tc.arguments, + "tool_call_id": tc.id, + "response": "", + } + + if hasattr(msg, "role") and getattr(msg, "role", None) == "tool": + msg_id = getattr(msg, "id", "") + if msg_id in tool_calls_by_id: + tool_calls_by_id[msg_id]["response"] = getattr(msg, "content", "") or "" + + if hasattr(msg, "tool_messages"): + for tm in msg.tool_messages: + tm_id = getattr(tm, "id", "") + if tm_id in tool_calls_by_id: + tool_calls_by_id[tm_id]["response"] = getattr(tm, "content", "") or "" + + return list(tool_calls_by_id.values()) diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/evaluator/evaluator_traj.py b/examples/AgenticBenchmarks/VitaBench/src/vita/evaluator/evaluator_traj.py index 52c1792..5794bee 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/evaluator/evaluator_traj.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/evaluator/evaluator_traj.py @@ -1,3 +1,10 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/evaluator/evaluator_traj.py. +Everything is verbatim from the original except for the following change: + +1. In ``TrajectoryEvaluator``, flatten ``result_data`` when the LLM response is + parsed as a nested list before iterating over rubric results. 
+""" import json import copy from typing import List @@ -238,6 +245,9 @@ def _evaluate_window( result_data = evaluator_extracter(assistant_message.content) if result_data: + # # Flatten in case the LLM response was parsed as a nested list + if isinstance(result_data, list) and result_data and isinstance(result_data[0], list): + result_data = [item for sublist in result_data for item in sublist] for result in result_data: rubric_idx = result.get("rubric_idx") if rubric_idx and rubric_idx in updated_states: @@ -734,4 +744,4 @@ def _evaluate_trajectory_full_traj_wo_rubric( "meetExpectation": False } - return updated_states, trajectory_evaluation_info \ No newline at end of file + return updated_states, trajectory_evaluation_info diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py b/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py index 37676ae..f517e0a 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/orchestrator/orchestrator.py @@ -1,3 +1,20 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/orchestrator/orchestrator.py. +Everything is verbatim from the original except for the following changes: + +1. Added a ``verifier`` constructor argument plus completeness-retry state + (``VITA_MAX_COMPLETENESS_RETRIES``) and ``verifier_feedback_history``. +2. Attached the verifier's ``soundness_log`` to the finalized ``SimulationRun``. +3. On agent stop, run the verifier's completeness check and, if it returns + feedback, inject a synthetic ``completeness_check`` tool_call/tool_result pair + to send the agent back instead of stopping. +4. In solo mode, nudge the agent back with a user message when it emits + non-tool-call text. +5. 
Before executing each agent tool call, run the verifier's soundness check; + on feedback, return it as an error ToolMessage and skip execution (blocking), + and feed successful tool responses to the harness via ``observe_tool_response``. +""" +import os import time import uuid from copy import deepcopy @@ -12,6 +29,7 @@ AssistantMessage, Message, MultiToolMessage, + ToolCall, ToolMessage, UserMessage, ) @@ -61,6 +79,7 @@ def __init__( seed: Optional[int] = None, solo_mode: bool = False, language: str = None, + verifier: Optional[Any] = None, ): self.domain = domain self.agent = agent @@ -82,6 +101,10 @@ def __init__( self.from_role: Optional[Role] = None self.to_role: Optional[Role] = None self.message: Optional[Message] = None + self.verifier = verifier + self.verifier_feedback_history: list = [] + self._completeness_retries = 0 + self._max_completeness_retries = int(os.environ.get("VITA_MAX_COMPLETENESS_RETRIES", 1)) def initialize(self): """ @@ -248,7 +271,8 @@ def run(self) -> SimulationRun: agent_cost=agent_cost, messages=messages, seed=self.seed, - states=self.get_states(self.environment.tools.db, self.environment.tools.db.time) + states=self.get_states(self.environment.tools.db, self.environment.tools.db.time), + soundness_log=self.verifier.soundness_call_log if self.verifier and hasattr(self.verifier, "soundness_call_log") else None, ) return simulation_run @@ -314,23 +338,113 @@ def step(self): return agent_msg.validate() + completeness_feedback = None if self.agent.is_stop(agent_msg): - self.done = True - self.termination_reason = TerminationReason.AGENT_STOP + # Run completeness check before allowing stop + remaining = self._max_completeness_retries - self._completeness_retries + if remaining > 0 and self.verifier: + states = self.get_states(self.environment.tools.db, self.environment.tools.db.time) + completeness_feedback = self.verifier.check_on_stop(self.trajectory, remaining - 1, new_orders=states["new_states"], old_orders=states["old_states"]) + if 
completeness_feedback: + self._completeness_retries += 1 + else: + self.done = True + self.termination_reason = TerminationReason.AGENT_STOP + else: + self.done = True + self.termination_reason = TerminationReason.AGENT_STOP self.trajectory.append(agent_msg) self.message = agent_msg self.from_role = Role.AGENT - if agent_msg.is_tool_call(): + if completeness_feedback: + # Inject a matching fake tool_call into the agent message + # so it and the feedback below form a valid tool_use → tool_result pair + check_id = f"completeness_check_{self._completeness_retries}" + if agent_msg.tool_calls is None: + agent_msg.tool_calls = [] + agent_msg.tool_calls.append(ToolCall( + id=check_id, + name="completeness_check", + arguments={}, + requestor="assistant", + )) + feedback_tool_msg = ToolMessage( + id=check_id, + name="completeness_check", + content=completeness_feedback.strip(), + requestor="assistant", + role="tool", + error=True, + ) + self.trajectory.append(feedback_tool_msg) + self.message = feedback_tool_msg + self.from_role = Role.ENV + self.to_role = Role.AGENT + elif agent_msg.is_tool_call(): + self.to_role = Role.ENV + elif self.solo_mode and not self.agent.is_stop(agent_msg): + # Agent tried to talk to a user that isn't there — nudge it back + logger.warning(f"Solo mode agent produced non-tool-call text, nudging back: {agent_msg}") + nudge_msg = UserMessage( + role="user", + content=( + "There is no user to communicate with. " + "You are operating solo — refer to the task instructions in previous messages. " + "Continue making tool calls to complete the task, or call ###STOP### if done." 
+ ), + ) + self.trajectory.append(nudge_msg) + self.message = nudge_msg + self.from_role = Role.USER + self.to_role = Role.AGENT else: self.to_role = Role.USER elif self.from_role in [Role.AGENT, Role.USER] and self.to_role == Role.ENV: if not self.message.is_tool_call(): raise ValueError("Agent or User should send tool call to environment") + # Exclude the current agent message from the base trajectory so the judge + # does not see all sibling tool calls as already-executed prior history. + trajectory_before_current = self.trajectory[:-1] tool_msgs = [] for tool_call in self.message.tool_calls: + soundness_feedback = None + if self.verifier: + # Build context = prior trajectory + a mock assistant message containing + # only the already-executed sibling tool_calls (so their IDs are registered + # by _extract_tool_history) + their real responses. + if tool_msgs: + mock_assistant = deepcopy(self.message) + mock_assistant.tool_calls = self.message.tool_calls[:len(tool_msgs)] + trajectory_for_judge = trajectory_before_current + [mock_assistant] + tool_msgs + else: + trajectory_for_judge = trajectory_before_current + soundness_feedback = self.verifier.check_soundness( + tool_call.name, tool_call.arguments, trajectory_for_judge + ) + if soundness_feedback: # Blocking behavior: return feedback as error, skip execution + tool_msgs.append(ToolMessage( + id=tool_call.id, + name=tool_call.name, + content=soundness_feedback, + requestor=tool_call.requestor, + role="tool", + error=True, + )) + continue tool_msg = self.environment.get_response(tool_call) + # OLD: nudge behavior — prepend feedback to tool response + # if soundness_feedback: + # tool_msg.content = ( + # soundness_feedback.strip() + # + "\n\n" + # + (tool_msg.content or "") + # ) tool_msgs.append(tool_msg) + # Feed tool response to harness memory (if the verifier supports it) + if self.verifier and hasattr(self.verifier, "observe_tool_response"): + self.verifier.observe_tool_response( + tool_call.name, 
tool_call.arguments or {}, tool_msg.content or "" + ) # Increment error count if tool call failed if tool_msg.error: self.num_errors += 1 diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/agent_system_prompt.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/agent_system_prompt.yaml index 3b20305..c221302 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/agent_system_prompt.yaml +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/agent_system_prompt.yaml @@ -26,3 +26,7 @@ english: |- - Only use information from the above context, prohibit constructing information without basis and replying to users - Focus on completing user needs, prohibit divergent guidance to users to propose new needs - After completing the user's task requirements, ask if there are any other needs. If the user indicates no, generate '###STOP###' mark to end the conversation + + # Language + - Always respond in English. + - Even tool calls should be in English. The database is in English. diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/completeness_extraction_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/completeness_extraction_template.yaml new file mode 100644 index 0000000..f235a0d --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/completeness_extraction_template.yaml @@ -0,0 +1,81 @@ +name: completeness_extraction_template +chinese: |- + 你是一个在线旅行社(OTA)任务的完整性约束提取系统。 + + 给定用户的旅行规划指令,提取所有预订需求的**数量和位置**信息,用于验证智能体是否完成了所有必要的预订。 + + 有些预订取决于运行时条件(天气、价格、库存、距离等,如“天气好就去A,否则去B”或“如果票价超过800元就选B”)。这些预订必须放入conditional列表中,绝对不能放在hotel/flight/train/attraction列表中。只有无条件的预订才放入主列表。 + + ## 重要规则 + 1. 使用模拟时间将所有相对日期解析为绝对YYYY-MM-DD格式。模拟时间: {system_time}。"这个月"指模拟时间的当月,"下个月"指下一个月,"下周六"指从模拟时间起的下一个周六。 + 2. 计算总票数时,不区分票种(成人票、儿童票等),只需统计总数量。 + 3. 只提取数量/位置/日期信息,不需要提取具体的酒店名称、房型、座位等级、票种等属性。 + 4. 
对于取消/修改操作,使用提供的现有订单信息找到对应的order_id。 + + ## 字段说明 + - HotelCompleteness: check_in_date(入住日期), num_rooms(房间数), num_nights(入住天数) + - FlightCompleteness: departure_city(出发城市), arrival_city(到达城市), date(出发日期), quantity(票数) + - TrainCompleteness: departure_city(出发城市), arrival_city(到达城市), date(出发日期), quantity(票数) + - AttractionCompleteness: date(日期), quantity(票数) + - CancelCompleteness: entity_type(实体类型), order_id(要取消的订单ID) + - ModifyCompleteness: entity_type(实体类型), order_id(要修改的订单ID) + + ## 可用的航班/火车城市 + {available_cities} + 航班和火车约束中的departure_city和arrival_city必须使用上述列表中的确切城市名称。 + + ## 现有订单 + {existing_orders} + + ## 用户资料 + {user_profile} + + ## 用户历史行为 + {user_historical_behaviors} + + ## JSON Schema + 请严格按照以下schema输出: + {json_schema} + + 只输出JSON,不要输出其他内容。 + +english: |- + You are a completeness constraint extraction system for an Online Travel Agency (OTA) task evaluator. + + Given a user's travel planning instructions, extract the **quantity and location** information for all booking requirements. This is used to verify that the agent completed all necessary bookings — not whether each booking has the right attributes. + + Some bookings depend on runtime conditions (weather, price, availability, distance, etc. — e.g., "if the weather is good go to A, otherwise go to B" or "if tickets exceed 800 yuan, do X instead"). These MUST go in the conditional list — never in the hotel/flight/train/attraction lists. Only unconditional bookings belong in the main lists. + + ## Important Rules + 1. Always respond in English. All field values, keywords, descriptions, and notes must be written in English. The entire database is in English. + 2. Resolve ALL relative dates to absolute YYYY-MM-DD format using the simulated time: {system_time}. "This month" means the month of the simulated time. "Next month" means the following month. "Next Saturday" means the upcoming Saturday from the simulated time. + 3. Count total tickets regardless of ticket type (adult, child, etc.) 
— just the total quantity. + 4. Only extract quantity/location/date information. Do NOT extract specific hotel names, room types, seat classes, ticket types, or other attributes. + 5. For cancel/modify operations, use the provided existing orders to find the matching order_id. + + ## Field Descriptions + - HotelCompleteness: check_in_date, num_rooms, num_nights + - FlightCompleteness: departure_city, arrival_city, date, quantity (number of tickets) + - TrainCompleteness: departure_city, arrival_city, date, quantity (number of tickets) + - AttractionCompleteness: date, quantity (number of tickets) + - CancelCompleteness: entity_type, order_id (ID of the pre-existing order to cancel) + - ModifyCompleteness: entity_type, order_id (ID of the pre-existing order to modify) + + ## Available Flight/Train Cities + {available_cities} + The departure_city and arrival_city in flight and train constraints MUST use the exact city names from the list above. + + ## Existing Orders + {existing_orders} + + ## User Profile + {user_profile} + + ## User Historical Behaviors + {user_historical_behaviors} + + ## JSON Schema + Output strictly according to this schema: + {json_schema} + + Output ONLY the JSON, no other text. diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/date_resolution_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/date_resolution_template.yaml new file mode 100644 index 0000000..a9f2dd8 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/date_resolution_template.yaml @@ -0,0 +1,51 @@ +name: date_resolution_template +chinese: |- + 识别文本中的所有**相对日期表达**,解析为绝对日期,并输出替换后的文本。 + + 当前模拟时间: {system_time}({weekday}) + + ## 规则 + 1. 找出所有相对日期表达(如"下周六"、"下个月1号"、"后天"、"七夕节后一天"等)。 + 2. 根据模拟时间解析为绝对日期。 + 3. 中国传统节日按农历转公历(七夕、中秋、春节等)。 + 4. 范围表达(如"这周末"、"下周"、"这周")解析为起止日期,格式为 `YYYY-MM-DD to YYYY-MM-DD (星期X-星期X)`。月份表达(如"这个月"、"下个月")仅解析为 `YYYY-MM`。 + 5. 替换后的指令中,将原始相对日期表达替换为上述格式。除日期外不要修改任何内容。 + 6. 
在输出前,逐条验证每个日期的星期是否正确(例如 2026-06-20 实际是星期六,不是星期日)。如有不一致,修正日期或星期。 + + ## 输出格式 + ```json + {{ + "resolved_dates": [["原始短语", "YYYY-MM-DD", "星期X"], ...], + "resolved_instructions": "替换后的完整文本" + }} + ``` + + ## 待解析文本 + {instructions} +english: |- + You are a date resolution assistant. Your task is to identify all **relative date/time expressions** in some text, resolve them to absolute dates, and output the instructions with replacements applied. + + Current simulated time: {system_time} ({weekday}) + + ## Rules + 1. Find all relative date expressions in the instructions. + 2. Relative expressions include but are not limited to: "next Saturday", "next month on the 1st", "this weekend", "the day after tomorrow", "two weeks from now on Saturday", "the day after Qixi Festival", etc. + 3. Resolve the relative dates against the simulated time. Convert Chinese lunar festivals (Qixi, Mid-Autumn, Spring Festival, etc.) to Gregorian dates. + 4. Range expressions (e.g. "this weekend", "next week", "this week") must be resolved to start and end dates: `YYYY-MM-DD to YYYY-MM-DD (DayOfWeek-DayOfWeek)`. Month-level expressions (e.g. "this month", "next month") resolve to just `YYYY-MM`. + 5. In the resolved instructions, replace the original relative expressions with the formats above. Do not modify anything else — keep the original wording, tone, and punctuation exactly the same. + 6. Before outputting, verify each date's day-of-week is correct (e.g. 2026-06-20 is actually a Saturday, not a Sunday). If there is a mismatch, fix the date or the day-of-week. + + ## Output Format + Output a JSON object with two fields: + ```json + {{ + "resolved_dates": [ + ["original phrase", "YYYY-MM-DD", "DayOfWeek"], + ... 
+ ], + "resolved_instructions": "Full instruction text with dates replaced" + }} + ``` + + ## Text to Resolve + {instructions} diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_constraint_extraction_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_constraint_extraction_template.yaml new file mode 100644 index 0000000..3e7bb3e --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_constraint_extraction_template.yaml @@ -0,0 +1,113 @@ +name: harness_constraint_extraction_template +chinese: |- + 你是一个在线旅行社(OTA)任务评估器的约束提取系统。 + + 给定用户的旅行规划指令,提取预订智能体必须满足的所有约束。每条约束都应保留用户自然语言中的意图。 + + ## 什么是约束? + 约束是用户指令中单一、原子化的需求。每条约束都应能针对单个预订操作独立验证。将复合需求拆解为最小的可验证单元。例如:“预订2间房,住3晚” → 一条数量约束(2间房)+ 一条时长约束(3晚)。例如:“上海每晚500元以下最便宜的酒店” → 一条城市约束(上海)+ 一条价格约束(最便宜,每晚500元以下)。 + + ## 规则 + 1. 保留用户的原意——不要过度解读或添加假设。 + 2. 相对日期应在描述中解析为绝对日期(例如“2026-05-09(下周六)入住”)。 + 3. entity_type 必须是以下之一:“hotel”、“flight”、“train”、“attraction”,若约束跨实体类型适用则为 null。 + 4. 每条约束描述必须**自包含**。不要使用“那天”、“同一个”、“上面那家酒店”等指代——请重复相关上下文。评判者将独立评估每条约束,无法看到其他约束。错误示例:“乘坐那天最早的火车” 正确示例:“乘坐下周三(2026-05-07)从天津到杭州的最早一班火车” + + ## 约束类别 + - **date**:任何日期需求(入住日期、出发日期、“下周六”等) + - **price**:预算上限、价格偏好(“每晚500元以下”、“最便宜的选项”) + - **city**:地点需求(目的地城市、出发城市、“市中心附近”) + - **quantity**:房间数、票数、座位数、乘客数 + - **duration**:入住时长、住宿晚数 + - **entity_attribute**:关于实体的具体需求(房型、座位等级、星级、酒店品牌、设施) + - **other**:任何不属于上述类别的可验证需求(例如“预订与上次旅行相同的酒店”、“靠窗座位”、“必须含早餐”、“无烟房”) + + ## 条件性需求 + 如果用户的指令包含条件逻辑(例如“下雨就订X,否则订Y”),将条件直接嵌入约束的描述中。例如:“如果苏州下雨就在杭州订酒店,否则在苏州订” → 两条城市约束:“在杭州订酒店(仅当苏州天气下雨时)”和“在苏州订酒店(仅当苏州天气不下雨时)”。 + + ## 系统时间 + {system_time} + + ## 用户画像 + {user_profile} + + ## 输出格式 + 返回一个JSON对象: + ```json + {{ + "task_id": "", + "constraints": [ + {{ + "id": "c1", + "category": "date", + "entity_type": "hotel", + "description": "Check in on 2026-05-09 (next Saturday)" + }}, + {{ + "id": "c2", + "category": "price", + "entity_type": "hotel", + "description": "Hotel should cost no more than 500 per night" + }}, + ... 
+ ] + }} + ``` + + 仅输出JSON,不要输出任何其他文本。 +english: |- + You are a constraint extractor for an Online Travel Agency (OTA) task evaluator. + + Given a user's travel planning instructions, extract ALL constraints that a booking agent must satisfy. Each constraint should preserve the user's natural language intent. + + ## What is a constraint? + A constraint is a single, atomic requirement from the user's instructions. Each constraint should be independently verifiable against a single booking action. Decompose compound requirements into their smallest verifiable parts. Example: "Book 2 rooms for 3 nights" → one quantity constraint (2 rooms) + one duration constraint (3 nights). Example: "Cheapest hotel in Shanghai under 500/night" → one city constraint (Shanghai) + one price constraint (cheapest, under 500/night). + + ## Rules + 1. Preserve the user's language — do NOT over-interpret or add assumptions. + 2. Relative dates should be resolved to absolute dates in the description (e.g. "Check in on 2026-05-09 (next Saturday)"). + 3. entity_type must be one of: "hotel", "flight", "train", "attraction", or null if the constraint applies across entity types. + 4. Each constraint description must be **self-contained**. Do not use references like "that day", "the same one", "the above hotel" — repeat the relevant context. A judge will evaluate each constraint independently with NO visibility into other constraints. Bad: "Take the earliest train on that day" Good: "Take the earliest train on next Wednesday (2026-05-07) from Tianjin to Hangzhou" + + ## Constraint Categories + - **date**: Any date requirement (check-in dates, departure dates, "next Saturday", etc.) 
+ - **price**: Budget limits, price preferences ("under 500/night", "cheapest option") + - **city**: Location requirements (destination city, departure city, "near downtown") + - **quantity**: Number of rooms, tickets, seats, passengers + - **duration**: Length of stay, number of nights + - **entity_attribute**: Specific requirements about the entity (room type, seat class, star rating, hotel brand, amenities) + - **other**: Any verifiable requirement that doesn't fit the above categories (e.g. "book the same hotel as last trip", "window seat", "must include breakfast", "non-smoking room") + + ## Conditional Requirements + If the user's instructions include conditional logic (e.g. "if rainy, book X; otherwise book Y"), embed the condition directly in the constraint's description. Example: "Book a hotel in Hangzhou IF the weather in Suzhou is rainy, otherwise book in Suzhou" → two city constraints: "Hotel in Hangzhou (only if Suzhou weather is rainy)" and "Hotel in Suzhou (only if Suzhou weather is NOT rainy)". + + ## System Time + {system_time} + + ## User Profile + {user_profile} + + ## Output Format + Return a JSON object: + ```json + {{ + "task_id": "", + "constraints": [ + {{ + "id": "c1", + "category": "date", + "entity_type": "hotel", + "description": "Check in on 2026-05-09 (next Saturday)" + }}, + {{ + "id": "c2", + "category": "price", + "entity_type": "hotel", + "description": "Hotel should cost no more than 500 per night" + }}, + ... + ] + }} + ``` + + Output ONLY the JSON, no other text. 
diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_memory_writer_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_memory_writer_template.yaml new file mode 100644 index 0000000..5280ae1 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_memory_writer_template.yaml @@ -0,0 +1,33 @@ +name: harness_memory_writer_template +chinese: |- + 你是一个旅行预订智能体监控器的简洁记录员。 + + 你将看到来自旅行预订系统的一次工具调用及其响应。你的任务是只提取与评估以下约束相关的事实: + + ## 需要监控的约束 + {constraints_text} + + ## 说明 + - 用简短的总结(1-3句话)概括此工具响应中与上述任一约束相关的事实。 + - 重点关注:看到的价格、城市/地点、日期、可用情况、数量、具体属性(房型、座位等级、评级)。 + - 如果工具响应中没有与约束相关的任何信息,只回复:NOTHING_RELEVANT + - 不要推测或推断。只记录工具响应中明确陈述的内容。 + - 保持简洁。这是一份运行日志,不是完整的逐字记录。 + + 只回复你的总结或 NOTHING_RELEVANT。不要使用JSON,不要使用任何格式。 +english: |- + You are a concise note-taker for a travel booking agent monitor. + + You will see a tool call and its response from a travel booking system. Your job is to extract ONLY the facts that are relevant to evaluating the following constraints: + + ## Constraints to Monitor + {constraints_text} + + ## Instructions + - Write a brief summary (1-3 sentences) of facts from this tool response that relate to ANY of the above constraints. + - Focus on: prices seen, cities/locations, dates, availability, quantities, specific attributes (room types, seat classes, ratings). + - If the tool response contains NO information relevant to the constraints, respond with exactly: NOTHING_RELEVANT + - Do NOT speculate or infer. Only record what the tool response explicitly states. + - Be concise. This is a running log, not a full transcript. + + Respond with ONLY your summary or NOTHING_RELEVANT. No JSON, no formatting. 
diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_soundness_judge_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_soundness_judge_template.yaml new file mode 100644 index 0000000..49d4209 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/harness_soundness_judge_template.yaml @@ -0,0 +1,51 @@ +name: harness_soundness_judge_template +chinese: |- + 你是一个在线旅行社(OTA)预订监控器的约束评估器。 + + 你将获得: + 1. 智能体必须满足的单条约束(从用户指令中提取) + 2. 从先前工具调用中观察到的事实的累积记忆 + 3. 智能体即将执行的写操作(create/cancel/modify)工具调用 + + 根据记忆中已知的信息,判断该写操作工具调用对这条约束是违反、一致,还是尚无法确定。 + + ## 判定结果 + - **violated**:根据现有证据,该工具调用明显违背了这条约束。 + - **consistent**:该工具调用与这条约束相符。 + - **undetermined**:记忆中信息不足以判断,或该约束不适用于这次具体的调用。 + + ## 规则 + - 不要求智能体已完成所有预订。部分预订是可以接受的。 + - 容忍轻微的名称差异和近似匹配。 + - 如果约束包含条件(例如“仅当天气下雨时”),在记忆中查找有关该条件的证据。若无证据,标记为 undetermined。 + + ## 输出格式 + 仅用JSON回复: + ```json + {"reasoning": "简要说明(1-2句话)", "verdict": "consistent/violated/undetermined"} + ``` +english: |- + You are a constraint evaluator for an Online Travel Agency booking monitor. + + You will be given: + 1. A single constraint the agent must satisfy (extracted from user instructions) + 2. An accumulated memory of facts observed from prior tool calls + 3. A write tool call (create/cancel/modify) the agent is about to execute + + Determine whether the write tool call — in light of what's known in memory — violates, is consistent with, or cannot yet be determined against this constraint. + + ## Verdicts + - **violated**: The tool call clearly contradicts this constraint based on available evidence. + - **consistent**: The tool call is compatible with this constraint. + - **undetermined**: Not enough information in memory to judge, or the constraint is not applicable to this specific call. + + ## Rules + - Do NOT require the agent to have completed all bookings. A partial booking is fine. + - Tolerate minor name variations and approximate matches. 
+ - If the constraint contains a condition (e.g. "only if weather is rainy"), check memory for evidence about that condition. If no evidence, mark undetermined. + + ## Output Format + Reply with JSON only: + ```json + {"reasoning": "Brief explanation (1-2 sentences)", "verdict": "consistent/violated/undetermined"} + ``` diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml index 7a630b2..52a1fb3 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/solo_agent_system_prompt.yaml @@ -32,3 +32,7 @@ english: |- - No interaction with users during execution - By default, for logic that requires user confirmation, it is considered that the user has already confirmed - After completing all user requirements, generate '###STOP###' mark to end the conversation + + # Language + - Always respond in English. + - Even tool calls should be in English. The database is in English. diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/soundness_judge_template.yaml b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/soundness_judge_template.yaml new file mode 100644 index 0000000..d5b9e9c --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/prompts/soundness_judge_template.yaml @@ -0,0 +1,72 @@ +name: soundness_judge_template +chinese: |- + 你是一个在线旅行社(OTA)任务的实时验证器。你的职责是判断一个即将执行的写操作(create_*、pay_*、cancel_*、modify_*工具调用)的参数是否与用户的原始指令一致。 + + 你需要关注的是**这一次工具调用本身**的参数是否正确,而不是智能体是否已经完成了所有预订。 + + ## 你的输入 + 1. **用户指令**:用户对旅行社智能体的原始请求 + 2. **工具调用历史**:智能体到目前为止执行的所有工具调用及其返回结果 + 3. 
**当前工具调用**:即将执行的写操作及其参数 + + ## 判断步骤 + + ### 第一步:验证智能体的整体方案 + 在评估当前工具调用之前,先检查智能体的整体方案是否正确: + - 仔细重新阅读用户的原始指令。明确用户的实际需求——正确的目的地、正确的行程安排、正确的方案选择。 + - 如果用户的指令要求在多个选项中做出选择(如"选更便宜的行程"、"去天气更好的城市"),根据工具调用历史中的搜索结果、价格比较、天气查询等来判断哪个选项才是正确的。如果智能体在执行错误的选项,应BLOCK。 + - 如果用户的指令包含条件逻辑(如"天气好就去A,否则去B"),在工具调用历史中找到实际的weather()或其他相关查询结果,确定应该走哪个分支。如果智能体在执行错误的分支,应BLOCK。 + + ### 第二步:验证日期 + 如果用户使用相对日期表达("下周六"、"这个周五"、"后天"、"下个月1号"),根据用户指令中的系统时间独立计算正确日期,并验证其与工具调用参数中的日期是否一致。不要信任智能体的日期推算——请自行验证。 + + ### 第三步:验证本次工具调用的参数 + 检查本次工具调用的具体参数是否符合用户意图: + - **容忍模糊语言和名称变体**:"大约"、"左右"、"差不多"不是硬性边界。接受同一地点的不同叫法(如"峨眉山"/"Emeishan")。如果搜索结果中某选项最接近用户需求,即使不完全匹配也应接受。 + - **尊重状态变更**:如果智能体取消了旧预订并正在重新预订,应根据改签意图评估,而非原始预订详情。 + - **不要检查完整性**:智能体可能会分多次调用来完成预订(如先订成人票,再订老年票)。只要这一次调用的参数正确,就应该ALLOW,即使它只是整个预订的一部分。 + + ## 输出格式 + 用JSON回复: + ```json + {{ + "reason": "简短解释(1-2句话)", + "verdict": "ALLOW" | "BLOCK" + }} + ``` + +english: |- + You are a real-time verifier for an Online Travel Agency (OTA) task. Your job is to judge whether an upcoming write action (create_*, pay_*, cancel_*, or modify_*) is consistent with the user's original instructions. + + Your focus is on whether **this particular tool call's parameters** are correct, not on whether the agent has completed all required bookings. + + ## Your Inputs + 1. **User Instruction**: The user's original request to the travel agent + 2. **Tool Call History**: All tool calls the agent has executed so far and their responses (searches, queries, cancellations, etc.) + 3. **Current Tool Call**: The write operation about to be executed, with its arguments + + ## Judgment Steps + + ### Step 1: Verify the agent's overall plan + Before evaluating the current tool call, check whether the agent's overall plan is correct: + - Re-read the user's original instructions carefully. Identify the user's actual requirements — the correct destination, the correct itinerary, the correct option among alternatives. 
+ - If the user's instructions require choosing among options (e.g., "pick the cheaper itinerary", "go to whichever city has better weather"), determine which option is correct based on the tool call history (search results, price comparisons, weather queries). If the agent is executing the wrong option, BLOCK. + - If the user's instructions contain conditional logic (e.g., "if the weather is good go to A, otherwise go to B"), find the actual weather() or other relevant query results in the tool call history and determine which branch is active. If the agent is following the wrong branch, BLOCK. + + ### Step 2: Verify dates + If the user specifies dates using relative expressions ("next Saturday", "this Friday", "the day after tomorrow", "next month on the 1st"), independently compute the correct date from the system time shown in the user instruction, and verify it matches the date in the tool call's arguments. Do not trust the agent's date arithmetic — verify it yourself. + + ### Step 3: Verify this tool call's parameters + Check that the specific parameters of this tool call match the user's intent: + - **Tolerate vague language and name variants**: "around", "approximately", "preferred" are not hard boundaries. Accept equivalent names for the same place (e.g., "Mount Emei" / "Emeishan"). Accept the closest available option from search results even if not a perfect literal match. + - **Respect state changes**: If the agent cancelled old bookings and is rebooking, evaluate against the rescheduling intent, not the original booking details. + - **Do not check completeness**: The agent may split a booking across multiple calls (e.g., adult tickets first, then senior tickets). ALLOW if this call's parameters are correct, even if it only covers part of the overall request. 
+ + ## Output Format + Reply with JSON only: + ```json + {{ + "reason": "Brief explanation (1-2 sentences)", + "verdict": "ALLOW" | "BLOCK" + }} + ``` diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/run.py b/examples/AgenticBenchmarks/VitaBench/src/vita/run.py index 36f25e9..05e8172 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/run.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/run.py @@ -1,4 +1,22 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/run.py. +Everything is verbatim from the original except for the following changes: + +1. Use full tracebacks in retry/error logs. +2. In solo mode, build the user via ``DummyUser.build(...)`` (resolving + pregenerated messages) instead of the registry constructor. +3. Added the ``_build_csv_path`` helper (timestamp + model name when the CSV + output is a directory). +4. Threaded ``soundness_mode``, ``completeness_mode``, ``solo_user_mode`` and + ``solo_user_file`` / ``solo_user_messages`` through ``run_domain`` → + ``run_tasks`` → ``run_task`` → ``_run_task_internal``. +5. Load the pregenerated solo user messages from ``solo_user_file`` once per run. +6. Build the OTA verifier via ``create_verifier`` and pass it to the + ``Orchestrator``. +7. Append soundness/completeness suffixes to the run name in ``make_run_name``. 
+""" import json +import traceback import multiprocessing import random import threading @@ -26,13 +44,24 @@ from vita.metrics.agent_metrics import compute_metrics from vita.orchestrator.orchestrator import Orchestrator from vita.registry import RegistryInfo, registry -from vita.user.user_simulator import get_global_user_sim_guidelines +from vita.user.user_simulator import get_global_user_sim_guidelines, DummyUser from vita.utils.display import ConsoleDisplay from vita.utils.pydantic_utils import get_pydantic_hash from vita.utils.utils import DATA_DIR, get_commit_hash, get_now, show_dict_diff, global_time from vita.utils.csv_utils import save_results_to_csv +def _build_csv_path(csv_output: str, model_name: str) -> str: + """Build CSV path with timestamp and model name if csv_output is a directory.""" + path = Path(csv_output) + if path.suffix != ".csv": + # Treat as directory: generate filename with timestamp and model name + safe_model = model_name.replace("/", "_") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + path = path / f"{timestamp}_{safe_model}.csv" + return str(path) + + def get_options() -> RegistryInfo: """ Returns options for the simulator. 
@@ -101,8 +130,10 @@ def make_run_name(config: RunConfig) -> str: # Add think mode indicator to the filename if enable_think is True think_suffix = "_think" if config.enable_think else "" - - return f"{get_now()}_{config.domain}_{agent_name}_{user_name}{think_suffix}" + soundness_suffix = f"_sc-{config.soundness_mode}" if config.soundness_mode != "off" else "" + completeness_suffix = "_cc-on" if config.completeness_mode == "on" else "" + + return f"{get_now()}_{config.domain}_{agent_name}_{user_name}{think_suffix}{soundness_suffix}{completeness_suffix}" def run_domain(config: RunConfig) -> Results: @@ -155,6 +186,10 @@ def run_domain(config: RunConfig) -> Results: llm_evaluator=config.llm_evaluator, llm_args_evaluator=config.llm_args_evaluator, language=config.language, + soundness_mode=config.soundness_mode, + completeness_mode=config.completeness_mode, + solo_user_mode=config.solo_user_mode, + solo_user_file=config.solo_user_file, ) metrics = compute_metrics(simulation_results) @@ -162,7 +197,7 @@ def run_domain(config: RunConfig) -> Results: if config.csv_output_file and simulation_results.simulations: try: - csv_output = config.csv_output_file + csv_output = _build_csv_path(config.csv_output_file, config.llm_agent) save_results_to_csv(simulation_results, csv_output, config, metrics) ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]") except Exception as e: @@ -193,6 +228,11 @@ def run_tasks( llm_evaluator: Optional[str] = None, llm_args_evaluator: Optional[dict] = None, language: str = None, + + soundness_mode: str = "off", + completeness_mode: str = "on", + solo_user_mode: str = "live", + solo_user_file: Optional[str] = None, ) -> Results: """ Runs tasks for a given domain. 
@@ -234,6 +274,15 @@ def run_tasks( if max_errors <= 0: raise ValueError("Max errors must be greater than 0") + # Load the pregenerated solo user messages once for the entire run + solo_user_messages: Optional[dict] = None + if solo_user_mode == "file": + if not solo_user_file: + raise ValueError("solo_user_mode='file' requires solo_user_file to be set") + with open(solo_user_file, "r", encoding="utf-8") as _fp: + solo_user_messages = json.load(_fp) + logger.info(f"Loaded {len(solo_user_messages)} pregenerated solo user messages from {solo_user_file}") + random.seed(seed) seeds = [random.randint(0, 1000000) for _ in range(num_trials)] @@ -378,6 +427,10 @@ def _run(task: Task, trial: int, seed: int, progress_str: str) -> dict: llm_evaluator=llm_evaluator, llm_args_evaluator=llm_args_evaluator, language=language, + soundness_mode=soundness_mode, + completeness_mode=completeness_mode, + solo_user_mode=solo_user_mode, + solo_user_messages=solo_user_messages, ) simulation.trial = trial if console_display: @@ -456,6 +509,10 @@ def run_task( llm_evaluator: Optional[str] = None, llm_args_evaluator: Optional[dict] = None, language: str = None, + soundness_mode: str = "off", + completeness_mode: str = "off", + solo_user_mode: str = "live", + solo_user_messages: Optional[dict] = None, ) -> SimulationRun: """ Runs tasks for a given domain. @@ -502,16 +559,20 @@ def run_task( enable_think=enable_think, llm_evaluator=llm_evaluator, llm_args_evaluator=llm_args_evaluator, - language=language + language=language, + soundness_mode=soundness_mode, + completeness_mode=completeness_mode, + solo_user_mode=solo_user_mode, + solo_user_messages=solo_user_messages, ) except Exception as e: if attempt < max_retries: - logger.warning(f"Task {task.id} failed on attempt {attempt + 1}/{max_retries + 1}: {e}. Retrying...") + logger.warning(f"Task {task.id} failed on attempt {attempt + 1}/{max_retries + 1}: {traceback.format_exc()}. 
Retrying...") # Clear global state, prepare for retry _clear_global_state() continue else: - logger.error(f"Task {task.id} failed after {max_retries + 1} attempts. Last error: {e}") + logger.error(f"Task {task.id} failed after {max_retries + 1} attempts. Last error:\n{traceback.format_exc()}") raise e @@ -532,6 +593,10 @@ def _run_task_internal( llm_evaluator: Optional[str] = None, llm_args_evaluator: Optional[dict] = None, language: str = None, + soundness_mode: str = "off", + completeness_mode: str = "off", + solo_user_mode: str = "live", + solo_user_messages: Optional[dict] = None, ) -> SimulationRun: """ Internal implementation of run_task without retry logic. @@ -581,15 +646,32 @@ def _run_task_internal( f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent" ) - UserConstructor = registry.get_user_constructor(user) + if solo_mode: + user = DummyUser.build( + task_id=task.id, + instructions=str(task.instructions), + persona=str(task.user_scenario.user_profile), + llm=llm_user, + llm_args=llm_args_user, + language=language, + solo_user_mode=solo_user_mode, + solo_user_messages=solo_user_messages, + ) + else: + UserConstructor = registry.get_user_constructor(user) + user = UserConstructor( + persona=str(task.user_scenario.user_profile), + instructions=str(task.instructions), + llm=llm_user, + llm_args=llm_args_user, + language=language, + ) - user = UserConstructor( - persona=str(task.user_scenario.user_profile), - instructions=str(task.instructions), - llm=llm_user, - llm_args=llm_args_user, - language=language, - ) + # Build the OTA verifier for supported domains + verifier = None + if (soundness_mode != "off" or completeness_mode != "off") and domain == "ota": + from vita.domains.ota.verifier import create_verifier + verifier = create_verifier(task, llm_model=llm_agent, language=language or "english", soundness_mode=soundness_mode, completeness_mode=completeness_mode, solo_user_message=user.solo_user_message) orchestrator = Orchestrator( 
domain=domain, @@ -601,7 +683,8 @@ def _run_task_internal( max_errors=max_errors, seed=seed, solo_mode=solo_mode, - language=language + language=language, + verifier=verifier, ) simulation = orchestrator.run() @@ -763,6 +846,8 @@ def re_evaluate_simulation(config: RunConfig) -> Results: llm_evaluator=config.llm_evaluator, llm_args_evaluator=config.llm_args_evaluator, language=config.language, + soundness_mode=config.soundness_mode, + completeness_mode=config.completeness_mode, ) # Remove old simulations for the re-run task IDs @@ -894,7 +979,7 @@ def _re_evaluate_single(simulation, task_dict, domain_name, progress_str): if config.csv_output_file and re_eval_results.simulations: try: - csv_output = config.csv_output_file + csv_output = _build_csv_path(config.csv_output_file, config.llm_agent) save_results_to_csv(re_eval_results, csv_output, config, metrics) ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]") except Exception as e: @@ -917,4 +1002,4 @@ def _re_evaluate_single(simulation, task_dict, domain_name, progress_str): logger.info(f"Saving re-evaluation results to: {save_to}") re_eval_results.save(save_to) - return re_eval_results \ No newline at end of file + return re_eval_results diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/README.md b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/README.md new file mode 100644 index 0000000..766dbb0 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/README.md @@ -0,0 +1,94 @@ +# `vita/scripts` + +Offline helper scripts. Run them all from the **repo root**. Each command below +shows only the required args; every flag is documented underneath and is optional +unless marked **required**. + +--- + +### `preresolve_dates.py` + +Rewrites relative date phrases ("next Saturday", "the 1st of next month") in OTA +task instructions into absolute dates with an LLM, writing +`data/vita/domains/ota/resolved_instructions_.json`. 
+ +```bash +python src/vita/scripts/preresolve_dates.py --model <model> +``` + +- `--tasks-file` — input tasks JSON. Default `data/vita/domains/ota/tasks_en.json`. +- `--output` / `-o` — output path. Default derived from model: `resolved_instructions_<model>.json`. +- `--model` — **required**; LLM used for resolution. +- `--language` — `english` or `chinese`. Default `english`. +- `--task-ids` — only resolve these task IDs (space-separated). Default: all tasks. +- `--num-tasks` — only resolve the first N tasks. Default: all tasks. +- `--max-concurrency` — parallel workers. Default `16`. +- `--resume` — flag; skip tasks already present in the output file. + +--- + +### `preextract_completeness.py` + +Extracts per-task completeness constraints (booking counts, cities, routes, dates) +that the completeness checker loads at runtime, writing +`data/vita/domains/ota/completeness_<model>.json`. + +```bash +python src/vita/scripts/preextract_completeness.py --model <model> +``` + +- `--tasks-file` — input tasks JSON. Default `data/vita/domains/ota/tasks_en.json`. +- `--output` / `-o` — output path. Default `data/vita/domains/ota/completeness_<model>.json`. +- `--model` — **required**; LLM used for extraction. +- `--language` — `english` or `chinese`. Default `english`. +- `--task-ids` — only extract these task IDs (space-separated). Default: all tasks. +- `--num-tasks` — only extract the first N tasks. Default: all tasks. +- `--max-concurrency` — parallel workers. Default `1`. +- `--resume` — flag; skip tasks already present in the output file. + +--- + +### `preextract_constraints_harness.py` + +Extracts the natural-language constraints used by the soundness judge harness from +the resolved user-sim messages, writing +`data/vita/domains/ota/harness_constraints_<model>.json`. + +```bash +python src/vita/scripts/preextract_constraints_harness.py --model <model> +``` + +- `--tasks-file` — input tasks JSON. Default `data/vita/domains/ota/tasks_en.json`.
+- `--user-sim-file` — user-sim messages keyed by task_id (use a date-resolved one). Default `data/user_sim_claude-opus-4.6_ota_english_resolved.json`. +- `--output` / `-o` — output path. Default `data/vita/domains/ota/harness_constraints_<model>.json`. +- `--model` — **required**; LLM used for extraction. +- `--language` — `english` or `chinese`. Default `english`. +- `--task-ids` — only extract these task IDs (space-separated). Default: all tasks. +- `--num-tasks` — only extract the first N tasks. Default: all tasks. +- `--max-concurrency` — parallel workers. Default `1`. +- `--resume` — flag; skip tasks already present in the output file. + +--- + +### `pregenerate_solo_messages.py` + +Pre-generates the opening user message for solo-agent runs, writing +`data/user_sim_<llm>_<domain>_<language>.json` (pass it to a run via +`--solo-user-mode=file --solo-user-file=<path>`). + +Depends on `preresolve_dates.py`: pass that script's output via +`--resolved-instructions` so the generated messages are built from absolute-date +instructions instead of the original relative ("next Saturday") phrasing. Run +`preresolve_dates.py` first, then point this script at its JSON. + +```bash +python -m vita.scripts.pregenerate_solo_messages --domain ota --llm <llm> +``` + +- `--domain` — **required**; domain name (`ota`, `delivery`, `instore`). +- `--llm` — **required**; LLM model used for generation. +- `--language` — `english` or `chinese`. Default `english`. +- `--resolved-instructions` — path to `preresolve_dates.py` output; uses absolute-date instructions. Default: none (uses original instructions). +- `--task-ids` — only generate for these task IDs (space-separated). Default: all tasks. +- `--max-concurrency` — parallel workers. Default `1`. +- `--force` — flag; regenerate entries already present in the output file.
diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_completeness.py b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_completeness.py new file mode 100644 index 0000000..f62e02f --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_completeness.py @@ -0,0 +1,183 @@ +""" +Pre-extract completeness constraints for all OTA tasks and save to a JSON file. + +Run this ONCE before test time so the verifier can load pre-computed +completeness constraints instead of calling the LLM at runtime. + +Usage: + python src/vita/scripts/preextract_completeness.py --model + python src/vita/scripts/preextract_completeness.py --model --max-concurrency 16 + python src/vita/scripts/preextract_completeness.py --model --task-ids D0811005 D0811006 + python src/vita/scripts/preextract_completeness.py --model --num-tasks 10 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +from vita.data_model.tasks import Task +from vita.domains.ota.completeness import extract_completeness_constraints + + +def load_tasks(path: str) -> list[dict]: + with open(path) as f: + return json.load(f) + + +def main(): + parser = argparse.ArgumentParser( + description="Pre-extract OTA completeness constraints for all tasks" + ) + parser.add_argument( + "--tasks-file", + default="data/vita/domains/ota/tasks_en.json", + help="Path to tasks JSON file", + ) + parser.add_argument( + "--output", + "-o", + default=None, + help="Output path (default: data/vita/domains/ota/completeness_{model}.json)", + ) + parser.add_argument( + "--task-ids", + nargs="+", + default=None, + help="Only extract specific task IDs", + ) + parser.add_argument( + "--num-tasks", + type=int, + default=None, + help="Only extract first N tasks", + ) + parser.add_argument( + "--model", + required=True, + help="LLM model name", + ) + 
parser.add_argument( + "--language", + default="english", + choices=["english", "chinese"], + help="Prompt language (default: english)", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing output file, skipping already-extracted tasks", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=1, + help="Max parallel extraction tasks (default: 1)", + ) + args = parser.parse_args() + + if args.output is None: + args.output = f"data/vita/domains/ota/completeness_{args.model}.json" + + raw_tasks = load_tasks(args.tasks_file) + print(f"Loaded {len(raw_tasks)} tasks from {args.tasks_file}") + + # Filter + if args.task_ids: + raw_tasks = [t for t in raw_tasks if t.get("id") in args.task_ids] + if not raw_tasks: + print(f"No tasks found matching IDs: {args.task_ids}") + sys.exit(1) + elif args.num_tasks is not None: + raw_tasks = raw_tasks[: args.num_tasks] + + # Load existing results if resuming + existing: dict[str, dict] = {} + if args.resume: + try: + with open(args.output) as f: + existing = json.load(f) + print(f"Resuming: {len(existing)} tasks already extracted") + except FileNotFoundError: + pass + + results: dict[str, dict] = dict(existing) + successes = 0 + failures = 0 + lock = threading.Lock() + + # Filter out already-extracted tasks + pending: list[tuple[int, dict]] = [] + for i, raw in enumerate(raw_tasks): + task_id = raw.get("id", f"unknown_{i}") + if task_id in existing: + print(f"[{i+1}/{len(raw_tasks)}] {task_id} — skipped (already extracted)") + successes += 1 + else: + pending.append((i, raw)) + + total = len(raw_tasks) + + def _extract_one(idx: int, raw: dict) -> None: + nonlocal successes, failures + task_id = raw.get("id", f"unknown_{idx}") + print(f"[{idx+1}/{total}] {task_id} — extracting...", flush=True) + t0 = time.time() + + raw.setdefault("environment", {}) + raw.setdefault("user_scenario", {"user_profile": {}}) + task = Task(**raw) + + try: + constraints = 
extract_completeness_constraints( + task, + llm_model=args.model, + language=args.language, + ) + result = constraints.model_dump(exclude_none=True) + elapsed = time.time() - t0 + n_items = ( + len(constraints.hotel) + + len(constraints.flight) + + len(constraints.train) + + len(constraints.attraction) + + len(constraints.cancel) + + len(constraints.modify) + + len(constraints.conditional) + ) + print(f"[{idx+1}/{total}] {task_id} — OK ({n_items} constraints, {elapsed:.1f}s)") + with lock: + results[task_id] = result + successes += 1 + except Exception as e: + elapsed = time.time() - t0 + print(f"[{idx+1}/{total}] {task_id} — FAILED ({elapsed:.1f}s): {e}") + with lock: + results[task_id] = {"task_id": task_id, "_error": str(e)} + failures += 1 + + # Save after each task so progress isn't lost + with lock: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + max_workers = max(1, args.max_concurrency) + if max_workers == 1: + for idx, raw in pending: + _extract_one(idx, raw) + else: + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(_extract_one, idx, raw): idx for idx, raw in pending} + for fut in as_completed(futures): + fut.result() + + print(f"\nDone: {successes} succeeded, {failures} failed") + print(f"Saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_constraints_harness.py b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_constraints_harness.py new file mode 100644 index 0000000..a5b6d87 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/preextract_constraints_harness.py @@ -0,0 +1,200 @@ +""" +Pre-extract NL constraints for the soundness judge harness. + +Run this ONCE before test time so the harness can load pre-computed +constraints instead of extracting at runtime. 
+ +Usage: + python src/vita/scripts/preextract_constraints_harness.py --model + python src/vita/scripts/preextract_constraints_harness.py --model --tasks-file data/vita/domains/ota/tasks_en.json + python src/vita/scripts/preextract_constraints_harness.py --model --output harness_constraints.json + python src/vita/scripts/preextract_constraints_harness.py --model --task-ids D0811005 D0811006 + python src/vita/scripts/preextract_constraints_harness.py --model --num-tasks 10 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +from vita.config import models as model_configs +from vita.data_model.tasks import Task +from vita.domains.ota.soundness_judge_harness.constraint_extractor import extract_constraints + + +def load_tasks(path: str) -> list[dict]: + with open(path) as f: + return json.load(f) + + +def main(): + parser = argparse.ArgumentParser( + description="Pre-extract NL constraints for the soundness judge harness" + ) + parser.add_argument( + "--tasks-file", + default="data/vita/domains/ota/tasks_en.json", + help="Path to tasks JSON file", + ) + parser.add_argument( + "--user-sim-file", + default="data/user_sim_claude-opus-4.6_ota_english_resolved.json", + help="Path to user simulation messages JSON (keyed by task_id)", + ) + parser.add_argument( + "--output", + "-o", + default=None, + help="Output path (default: data/vita/domains/ota/harness_constraints_{model}.json)", + ) + parser.add_argument( + "--task-ids", + nargs="+", + default=None, + help="Only extract specific task IDs", + ) + parser.add_argument( + "--num-tasks", + type=int, + default=None, + help="Only extract first N tasks", + ) + parser.add_argument( + "--model", + required=True, + help="LLM model name", + ) + parser.add_argument( + "--language", + default="english", + choices=["english", "chinese"], + help="Prompt language (default: english)", + ) + 
parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing output file, skipping already-extracted tasks", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=1, + help="Max parallel extraction tasks (default: 1)", + ) + args = parser.parse_args() + + if args.output is None: + model_slug = args.model.rsplit("/", 1)[-1] + args.output = f"data/vita/domains/ota/harness_constraints_{model_slug}.json" + + raw_tasks = load_tasks(args.tasks_file) + print(f"Loaded {len(raw_tasks)} tasks from {args.tasks_file}") + + # Load user simulation messages + with open(args.user_sim_file) as f: + user_sim_messages: dict[str, str] = json.load(f) + print(f"Loaded {len(user_sim_messages)} user sim messages from {args.user_sim_file}") + + # Filter + if args.task_ids: + raw_tasks = [t for t in raw_tasks if t.get("id") in args.task_ids] + if not raw_tasks: + print(f"No tasks found matching IDs: {args.task_ids}") + sys.exit(1) + elif args.num_tasks is not None: + raw_tasks = raw_tasks[: args.num_tasks] + + # Load existing results if resuming + existing: dict[str, dict] = {} + if args.resume: + try: + with open(args.output) as f: + existing = json.load(f) + print(f"Resuming: {len(existing)} tasks already extracted") + except FileNotFoundError: + pass + + results: dict[str, dict] = dict(existing) + successes = 0 + failures = 0 + lock = threading.Lock() + + # Filter out already-extracted tasks + pending: list[tuple[int, dict]] = [] + for i, raw in enumerate(raw_tasks): + task_id = raw.get("id", f"unknown_{i}") + if task_id in existing: + print(f"[{i+1}/{len(raw_tasks)}] {task_id} — skipped (already extracted)") + successes += 1 + else: + pending.append((i, raw)) + + total = len(raw_tasks) + + def _extract_one(idx: int, raw: dict) -> None: + nonlocal successes, failures + task_id = raw.get("id", f"unknown_{idx}") + print(f"[{idx+1}/{total}] {task_id} — extracting...", flush=True) + t0 = time.time() + + raw.setdefault("environment", {}) + 
raw.setdefault("user_scenario", {"user_profile": {}}) + task = Task(**raw) + + # Get user sim message for this task + user_msg = user_sim_messages.get(task_id) + if user_msg is None: + elapsed = time.time() - t0 + print(f"[{idx+1}/{total}] {task_id} — SKIPPED (no user sim message)") + with lock: + failures += 1 + return + + try: + llm_args = dict(model_configs.get(args.model, model_configs.get("default", {}))) + constraint_set = extract_constraints( + task, + llm_model=args.model, + llm_args=llm_args, + language=args.language, + user_message=user_msg, + ) + result = constraint_set.model_dump(exclude_none=True) + elapsed = time.time() - t0 + n = len(constraint_set.constraints) + print(f"[{idx+1}/{total}] {task_id} — OK ({n} constraints, {elapsed:.1f}s)") + with lock: + results[task_id] = result + successes += 1 + except Exception as e: + elapsed = time.time() - t0 + print(f"[{idx+1}/{total}] {task_id} — FAILED ({elapsed:.1f}s): {e}") + with lock: + results[task_id] = {"task_id": task_id, "_error": str(e)} + failures += 1 + + # Save after each task so progress isn't lost + with lock: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + max_workers = max(1, args.max_concurrency) + if max_workers == 1: + for idx, raw in pending: + _extract_one(idx, raw) + else: + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(_extract_one, idx, raw): idx for idx, raw in pending} + for fut in as_completed(futures): + fut.result() + + print(f"\nDone: {successes} succeeded, {failures} failed") + print(f"Saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/pregenerate_solo_messages.py b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/pregenerate_solo_messages.py new file mode 100644 index 0000000..e31d234 --- /dev/null +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/scripts/pregenerate_solo_messages.py @@ -0,0 +1,153 @@ 
def build_output_path(llm: str, domain: str, language: str, resolved: bool = False) -> Path:
    """Return the JSON path (under ``DATA_DIR``) for pregenerated user messages.

    Args:
        llm: Model name; ``/`` is replaced with ``_`` so it stays a single
            path component.
        domain: Domain name embedded in the file name.
        language: Task language embedded in the file name.
        resolved: When true, a ``_resolved`` suffix is appended so messages
            built from resolved instructions do not overwrite the regular file.
    """
    safe_model = llm.replace("/", "_")
    suffix = "_resolved" if resolved else ""
    return DATA_DIR / f"user_sim_{safe_model}_{domain}_{language}{suffix}.json"


def pregenerate(
    domain: str,
    language: str,
    llm: str,
    llm_args: dict,
    output: Path,
    task_ids: list[str] = None,
    max_concurrency: int = 1,
    force: bool = False,
    resolved_instructions_file: str = None,
):
    """Generate the opening user message for each task and store them in *output*.

    The output file maps task_id -> message string. Entries already present in
    the file are kept and skipped unless *force* is set.

    Args:
        domain: Domain whose task file is loaded via ``get_task_file_path``.
        language: Task language, also passed to the user simulator.
        llm: Model name used by ``DummyUser``.
        llm_args: Extra arguments forwarded to ``DummyUser``.
        output: JSON file to read existing messages from and write results to.
        task_ids: If given, only these tasks are processed.
        max_concurrency: Number of worker threads.
        force: Regenerate entries that already exist in *output*.
        resolved_instructions_file: Optional JSON mapping
            task_id -> {"resolved_instructions": ...}; when given, every
            processed task must have a non-empty entry in it.

    Raises:
        ValueError: If a requested task ID does not exist, a task is missing
            from the resolved-instructions file, or the simulator returns an
            empty message.
    """
    task_path = get_task_file_path(domain, language)
    with open(task_path, "r", encoding="utf-8") as fp:
        raw_tasks = json.load(fp)

    # Validate once (the task list was previously built twice).
    tasks = [Task.model_validate(t) for t in raw_tasks]

    # Decide on the *presence of the file argument*, not on the truthiness of
    # its content: an empty mapping must fail loudly per task instead of
    # silently falling back to the unresolved instructions.
    use_resolved = bool(resolved_instructions_file)
    resolved = {}
    if use_resolved:
        with open(resolved_instructions_file, "r", encoding="utf-8") as fp:
            resolved = json.load(fp)
        logger.info(f"Loaded {len(resolved)} resolved instructions from {resolved_instructions_file}")

    if task_ids is not None:
        wanted = set(task_ids)
        tasks = [t for t in tasks if t.id in wanted]
        # Compare ID sets rather than lengths so duplicated IDs on the command
        # line do not trigger a spurious "not found" error.
        missing = wanted - {t.id for t in tasks}
        if missing:
            raise ValueError(f"Task IDs not found: {missing}")

    # Load existing output file so we can skip already-generated entries
    messages: dict = {}
    if output.exists():
        with open(output, "r", encoding="utf-8") as fp:
            messages = json.load(fp)

    needs_update = [t for t in tasks if force or t.id not in messages]
    logger.info(f"{len(needs_update)} / {len(tasks)} tasks need solo_user_message generation")

    if not needs_update:
        logger.info("Nothing to do.")
        return

    # The orchestrator sends this greeting before the first user turn
    greeting = get_default_first_agent_message(language)
    lock = threading.Lock()

    def _generate(task: Task):
        if use_resolved:
            entry = resolved.get(task.id)
            if entry is None:
                raise ValueError(f"Task {task.id} not found in resolved instructions file")
            instructions = entry.get("resolved_instructions")
            if not instructions:
                raise ValueError(f"Task {task.id} has no resolved_instructions in resolved file")
        else:
            instructions = str(task.instructions)
        dummy = DummyUser(
            instructions=instructions,
            persona=str(task.user_scenario.user_profile),
            llm=llm,
            llm_args=llm_args,
            language=language,
        )
        state = dummy.get_init_state()
        user_msg, _ = dummy.generate_next_message(greeting, state)
        if not user_msg.content:
            # Storing an empty/None message would make the entry unusable
            # later, so fail explicitly instead of crashing on the slice below.
            raise ValueError(f"Empty user message generated for task {task.id}")
        logger.info(f" [{task.id}] -> {user_msg.content[:120]!r}...")
        with lock:
            messages[task.id] = user_msg.content

    logger.info(f"Generating with concurrency={max_concurrency} ...")
    try:
        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            # list() drains the iterator so worker exceptions are re-raised here.
            list(executor.map(_generate, needs_update))
    finally:
        # Persist whatever was generated even if one task failed, so a rerun
        # (without --force) only has to redo the missing entries.
        output.parent.mkdir(parents=True, exist_ok=True)
        with open(output, "w", encoding="utf-8") as fp:
            json.dump(messages, fp, indent=2, ensure_ascii=False)

    logger.info(f"Done. {len(needs_update)} messages written to {output}")


def main():
    """Parse command-line arguments and run :func:`pregenerate`."""
    parser = argparse.ArgumentParser(description="Pregenerate solo user messages for tasks.")
    parser.add_argument("--domain", required=True, help="Domain name (e.g. ota, delivery, instore)")
    parser.add_argument("--language", default="english", choices=["english", "chinese"])
    parser.add_argument("--llm", required=True, help="LLM model name to use for generation")
    parser.add_argument("--task-ids", nargs="+", default=None, help="Only generate for these task IDs.")
    parser.add_argument("--max-concurrency", type=int, default=1, help="Number of tasks to generate in parallel. Default is 1.")
    parser.add_argument("--force", action="store_true", help="Regenerate even if already present in the output file")
    parser.add_argument("--resolved-instructions", default=None, help="Path to resolved instructions JSON (from preresolve_dates.py)")
    args = parser.parse_args()

    # The output name depends on whether resolved instructions are used.
    output = build_output_path(args.llm, args.domain, args.language, resolved=bool(args.resolved_instructions))
    logger.info(f"Output path: {output}")

    pregenerate(
        domain=args.domain,
        language=args.language,
        llm=args.llm,
        llm_args=models.get(args.llm, {}),
        output=output,
        task_ids=args.task_ids,
        max_concurrency=args.max_concurrency,
        force=args.force,
        resolved_instructions_file=args.resolved_instructions,
    )


if __name__ == "__main__":
    main()
with absolute dates using an LLM, given +the simulated environment time. + +Produces a JSON mapping {task_id: resolved_instructions} that can be loaded +at runtime to replace the original instructions. + +Usage: + python src/vita/scripts/preresolve_dates.py --model + python src/vita/scripts/preresolve_dates.py --model --tasks-file data/vita/domains/ota/tasks_en.json + python src/vita/scripts/preresolve_dates.py --model --language english + python src/vita/scripts/preresolve_dates.py --model --task-ids D0812002 70812007 + python src/vita/scripts/preresolve_dates.py --model --resume +""" + +from __future__ import annotations + +import argparse +import json +import re +import sys +import threading +import time +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed + +from vita.config import DEFAULT_LLM_EVALUATOR, models +from vita.data_model.message import SystemMessage, UserMessage +from vita.domains.ota.verifier.utils import _extract_json +from vita.prompts import get_prompts +from vita.utils.llm_utils import generate +from vita.utils.utils import get_weekday + + +WEEKDAY_NAMES = { + 0: "Monday", 1: "Tuesday", 2: "Wednesday", + 3: "Thursday", 4: "Friday", 5: "Saturday", 6: "Sunday", +} + + +def _verify_day(date_str: str, claimed_day: str) -> bool: + """Return True if claimed day-of-week matches for a single date, a range, or a month.""" + # Month-only: YYYY-MM + if re.fullmatch(r"\d{4}-\d{2}", date_str): + return True # nothing to verify for month-level + # Range: YYYY-MM-DD to YYYY-MM-DD + range_match = re.fullmatch(r"(\d{4}-\d{2}-\d{2})\s+to\s+(\d{4}-\d{2}-\d{2})", date_str) + if range_match: + days = [d.strip() for d in claimed_day.split("-")] + if len(days) != 2: + return False + return _verify_single(range_match.group(1), days[0]) and _verify_single(range_match.group(2), days[1]) + # Multi-date with "and": YYYY-MM-DD and YYYY-MM-DD [and ...] 
+ and_match = re.fullmatch(r"(\d{4}-\d{2}-\d{2})(?:\s+and\s+\d{4}-\d{2}-\d{2})+", date_str) + if and_match: + dates = re.findall(r"\d{4}-\d{2}-\d{2}", date_str) + days = [d.strip() for d in claimed_day.split(" and ")] + if len(dates) != len(days): + return False + return all(_verify_single(d, w) for d, w in zip(dates, days)) + # Single date + return _verify_single(date_str, claimed_day) + + +def _verify_single(date_str: str, claimed_day: str) -> bool: + try: + dt = datetime.strptime(date_str, "%Y-%m-%d") + except ValueError: + return False + return WEEKDAY_NAMES[dt.weekday()].lower() == claimed_day.lower().strip() + + +def load_tasks(path: str) -> list[dict]: + with open(path) as f: + return json.load(f) + + +def _build_correction_message(dates: list[dict]) -> str | None: + """Build a correction prompt for any dates that failed verification. + Returns None if all dates are verified.""" + errors = [] + for d in dates: + if not d["verified"]: + date_str = d["date"] + claimed = d["day"] + # Compute actual day(s) + range_match = re.fullmatch(r"(\d{4}-\d{2}-\d{2})\s+to\s+(\d{4}-\d{2}-\d{2})", date_str) + if range_match: + days_claimed = [x.strip() for x in claimed.split("-")] + parts = [] + for ds, dc in zip([range_match.group(1), range_match.group(2)], days_claimed): + try: + actual = WEEKDAY_NAMES[datetime.strptime(ds, "%Y-%m-%d").weekday()] + except ValueError: + actual = "?" + if actual.lower() != dc.lower(): + parts.append(f"{ds} is actually {actual}, not {dc}") + if parts: + errors.append(f'- "{d["phrase"]}": {"; ".join(parts)}') + elif re.fullmatch(r"\d{4}-\d{2}", date_str): + continue # month-level, nothing to verify + else: + try: + actual = WEEKDAY_NAMES[datetime.strptime(date_str, "%Y-%m-%d").weekday()] + except ValueError: + actual = "?" 
+ errors.append(f'- "{d["phrase"]}": {date_str} is actually {actual}, not {claimed}') + if not errors: + return None + return ( + "Some of your resolved dates have incorrect days of the week:\n" + + "\n".join(errors) + + "\n\nPlease fix the errors and output the corrected JSON in the same format." + ) + + +def resolve_dates( + task_id: str, + instructions: str, + system_time: str, + language: str, + llm_model: str, + llm_args: dict, + max_retries: int = 2, +) -> dict: + """Call the LLM to extract relative date phrases, resolve them, and + return the replaced instructions. + + Returns dict with keys "resolved_dates" and "resolved_instructions". + """ + weekday = get_weekday(system_time, language) + + prompts = get_prompts(language) + system_prompt = prompts.date_resolution_template.format( + system_time=system_time, + weekday=weekday, + instructions=instructions, + ) + + messages = [ + SystemMessage(role="system", content=system_prompt), + UserMessage(role="user", content=instructions), + ] + + for attempt in range(1 + max_retries): + response = generate(model=llm_model, messages=messages, enable_think=True, **llm_args) + + raw_content = (response.content or "").strip() + if not raw_content: + raise ValueError(f"Empty response for task {task_id}") + + parsed = _extract_json(raw_content) + if parsed is None: + raise ValueError(f"No JSON found in response for {task_id}: {raw_content[:200]}") + + if not isinstance(parsed, dict): + raise ValueError(f"Expected dict, got {type(parsed).__name__} for {task_id}") + + tuples = parsed.get("resolved_dates", []) + resolved_instructions = parsed.get("resolved_instructions", "") + + if not resolved_instructions: + raise ValueError(f"Missing resolved_instructions for {task_id}") + + dates = [] + for item in tuples: + if isinstance(item, (list, tuple)) and len(item) == 3: + phrase, date_str, day = item + dates.append({ + "phrase": phrase, + "date": date_str, + "day": day, + "verified": _verify_day(date_str, day), + }) + + # Check if 
correction is needed + correction = _build_correction_message(dates) + if correction is None or attempt == max_retries: + break + + # Feed back errors for correction + messages.append({"role": "assistant", "content": raw_content}) + messages.append(UserMessage(role="user", content=correction)) + + return {"system_time": system_time, "resolved_dates": dates, "resolved_instructions": resolved_instructions} + + +def main(): + parser = argparse.ArgumentParser( + description="Pre-resolve relative dates in OTA task instructions" + ) + parser.add_argument( + "--tasks-file", + default="data/vita/domains/ota/tasks_en.json", + help="Path to tasks JSON file", + ) + parser.add_argument( + "--output", + "-o", + default=None, + help="Output path (default: derived from model name)", + ) + parser.add_argument( + "--task-ids", + nargs="+", + default=None, + help="Only resolve specific task IDs", + ) + parser.add_argument( + "--num-tasks", + type=int, + default=None, + help="Only resolve first N tasks", + ) + parser.add_argument( + "--model", + required=True, + help="LLM model name", + ) + parser.add_argument( + "--language", + default="english", + choices=["english", "chinese"], + help="Prompt language (default: english)", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from existing output file, skipping already-resolved tasks", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=16, + help="Max parallel resolution tasks (default: 16)", + ) + args = parser.parse_args() + + # Derive output path from model name if not specified + if args.output is None: + model_slug = args.model.rsplit("/", 1)[-1] + args.output = f"data/vita/domains/ota/resolved_instructions_{model_slug}.json" + + raw_tasks = load_tasks(args.tasks_file) + print(f"Loaded {len(raw_tasks)} tasks from {args.tasks_file}") + + # Filter + if args.task_ids: + raw_tasks = [t for t in raw_tasks if t.get("id") in args.task_ids] + if not raw_tasks: + print(f"No tasks found 
matching IDs: {args.task_ids}") + sys.exit(1) + elif args.num_tasks is not None: + raw_tasks = raw_tasks[: args.num_tasks] + + # Load existing results if resuming + existing: dict[str, str] = {} + if args.resume: + try: + with open(args.output) as f: + existing = json.load(f) + print(f"Resuming: {len(existing)} tasks already resolved") + except FileNotFoundError: + pass + + # LLM config + llm_args = dict(models.get(args.model, models.get("default", {}))) + llm_args["max_tokens"] = max(llm_args.get("max_tokens", 0), 4096) + + results: dict[str, str] = dict(existing) + successes = 0 + failures = 0 + lock = threading.Lock() + + # Filter out already-resolved tasks + pending: list[tuple[int, dict]] = [] + for i, raw in enumerate(raw_tasks): + task_id = raw.get("id", f"unknown_{i}") + if task_id in existing: + print(f"[{i+1}/{len(raw_tasks)}] {task_id} — skipped (already resolved)") + successes += 1 + else: + pending.append((i, raw)) + + total = len(raw_tasks) + + def _resolve_one(idx: int, raw: dict) -> None: + nonlocal successes, failures + task_id = raw.get("id", f"unknown_{idx}") + env = raw.get("environment", {}) + system_time = env.get("time", "") + instructions = raw.get("instructions", "") + + if not system_time: + print(f"[{idx+1}/{total}] {task_id} — SKIPPED (no env time)") + return + if not instructions: + print(f"[{idx+1}/{total}] {task_id} — SKIPPED (no instructions)") + return + + print(f"[{idx+1}/{total}] {task_id} — resolving...", flush=True) + t0 = time.time() + + try: + result = resolve_dates( + task_id=task_id, + instructions=instructions, + system_time=system_time, + language=args.language, + llm_model=args.model, + llm_args=llm_args, + ) + elapsed = time.time() - t0 + n_dates = len(result["resolved_dates"]) + print(f"[{idx+1}/{total}] {task_id} — OK ({n_dates} dates, {elapsed:.1f}s)") + with lock: + results[task_id] = result + successes += 1 + except Exception as e: + elapsed = time.time() - t0 + print(f"[{idx+1}/{total}] {task_id} — FAILED 
({elapsed:.1f}s): {e}") + with lock: + results[task_id] = {"_error": str(e)} + failures += 1 + + # Save after each task + with lock: + with open(args.output, "w") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + max_workers = max(1, args.max_concurrency) + if max_workers == 1: + for idx, raw in pending: + _resolve_one(idx, raw) + else: + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(_resolve_one, idx, raw): idx for idx, raw in pending} + for fut in as_completed(futures): + fut.result() + + print(f"\nDone: {successes} succeeded, {failures} failed") + print(f"Saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/AgenticBenchmarks/VitaBench/src/vita/user/user_simulator.py b/examples/AgenticBenchmarks/VitaBench/src/vita/user/user_simulator.py index d3bb225..bd1c7a9 100644 --- a/examples/AgenticBenchmarks/VitaBench/src/vita/user/user_simulator.py +++ b/examples/AgenticBenchmarks/VitaBench/src/vita/user/user_simulator.py @@ -1,3 +1,13 @@ +"""VitaBench overlay file — modified from the original VitaBench repo +(https://github.com/meituan-longcat/vitabench), at src/vita/user/user_simulator.py. +Everything is verbatim from the original except for the following changes: + +1. Added a ``solo_user_message`` argument to ``DummyUser.__init__``. +2. Added the ``DummyUser.build(...)`` classmethod that resolves the opening + message per ``solo_user_mode`` ('live' vs 'file'). +3. Added ``_fixed_message(...)`` and a short-circuit in ``generate_next_message`` + that returns the pregenerated message without an LLM call. 
# DummyUser methods touched by this overlay hunk; the class header lies outside this chunk.
@classmethod
def build(
    cls,
    task_id: str,
    instructions: str,
    persona: str,
    llm: Optional[str] = None,
    llm_args: Optional[dict] = None,
    language: str = None,
    solo_user_mode: str = "live",
    solo_user_messages: Optional[dict] = None,
) -> "DummyUser":
    """Create a DummyUser whose opening message follows *solo_user_mode*.

    Args:
        task_id: ID of the task being run; key into *solo_user_messages*.
        solo_user_mode: ``'file'`` takes the opening message from
            *solo_user_messages* and raises if the task has no entry; any
            other value leaves the message unset so it is generated by the
            LLM at run time.
        solo_user_messages: Mapping of task_id -> pregenerated message.
            Needed when *solo_user_mode* is ``'file'``.

    Raises:
        ValueError: In ``'file'`` mode when no message exists for *task_id*.
    """
    pregenerated = None
    if solo_user_mode == "file":
        has_entry = solo_user_messages is not None and task_id in solo_user_messages
        if not has_entry:
            raise ValueError(
                f"solo_user_mode='file' but no pregenerated message found for task '{task_id}'. "
                "Run the pregeneration script first or switch to solo_user_mode='live'."
            )
        pregenerated = solo_user_messages[task_id]

    return cls(
        instructions=instructions,
        persona=persona,
        llm=llm,
        llm_args=llm_args,
        language=language,
        solo_user_message=pregenerated,
    )


def generate_next_message(
    self, message: ValidUserInputMessage, state: UserState
) -> Tuple[UserMessage, UserState]:
    """Answer with the fixed pregenerated message if one is set, else use the LLM path."""
    if self.solo_user_message is None:
        return self._generate_next_message(message, state)
    return self._fixed_message(message, state)


def _fixed_message(
    self, message: ValidUserInputMessage, state: UserState
) -> Tuple[UserMessage, UserState]:
    """Record the incoming message and reply with the pregenerated text (no LLM call)."""
    # A MultiToolMessage contributes each of its tool messages; anything
    # else is recorded as a single entry.
    incoming = message.tool_messages if isinstance(message, MultiToolMessage) else [message]
    state.messages.extend(incoming)
    reply = UserMessage(role="user", content=self.solo_user_message, cost=0.0)
    state.messages.append(reply)
    return reply, state
def evaluator_extracter(content: str) -> list[dict]:
    """
    Extract the last valid JSON object/array from LLM output.

    The LLM may emit reasoning text with stray brackets (e.g. "[10]") or
    multiple JSON blocks (reconsidering). We find all candidates and take
    the last one that parses.

    Strategy, in order:
      1. fenced code blocks (``` or ```json), last parseable one wins;
      2. balanced top-level ``{...}`` / ``[...]`` spans, last parseable one wins;
      3. ``repair_json`` on the whole remaining text.

    Raises:
        json.JSONDecodeError: If even the repaired text is not valid JSON.
    """
    # Keep only the text after the last closing reasoning tag. The separator
    # must be non-empty: str.split("") raises ValueError.
    text = content.split("</think>")[-1].strip()

    # 1. Try all ```json fences, take the last valid one
    for fenced in reversed(re.findall(r"```(?:json)?(.*?)```", text, re.DOTALL)):
        fenced = fenced.strip()
        if not fenced:
            continue
        try:
            return json.loads(fenced)
        except json.JSONDecodeError:
            continue

    # 2. Find balanced top-level [ ] and { } blocks, take the last valid one
    candidates: list[str] = []
    depth = 0
    start = -1
    in_string = False
    escape = False
    open_char = None
    for i, ch in enumerate(text):
        if escape:
            # Character following a backslash inside a string: skip it.
            escape = False
            continue
        if in_string:
            if ch == '\\':
                escape = True
            elif ch == '"':
                in_string = False
            # Brackets inside string literals never count.
            continue
        if ch == '"':
            in_string = True
            continue
        if depth == 0 and ch in ('{', '['):
            start = i
            open_char = ch
            depth = 1
        elif depth > 0:
            # Only brackets of the opening kind are counted; well-formed JSON
            # keeps the other kind balanced within the span.
            close_char = '}' if open_char == '{' else ']'
            if ch == open_char:
                depth += 1
            elif ch == close_char:
                depth -= 1
                if depth == 0 and start != -1:
                    candidates.append(text[start:i + 1])
                    start = -1

    for candidate in reversed(candidates):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue

    # 3. Last resort: pass entire content to repair_json
    good_json_string = repair_json(text)
    return json.loads(good_json_string)