diff --git a/README.md b/README.md index ff06560..7217662 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,11 @@ The AIOS-Agent SDK is designed for agent users and developers, enabling them to 📝 See [here](https://docs.aios.foundation/getting-started/installation). Below are some useful commands to use +- [List available LLMs](./cerebrum/commands/list_available_llms.py) + ```bash + list-available-llms + ``` + - [List agents from agenthub](./cerebrum/commands/list_agenthub_agents.py) ```bash list-agenthub-agents diff --git a/benchmarks/agents/autogen.py b/benchmarks/agents/autogen.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/agents/pure_llm.py b/benchmarks/agents/cot/cot.py similarity index 99% rename from benchmarks/agents/pure_llm.py rename to benchmarks/agents/cot/cot.py index 3f5c01d..1d2cfd9 100644 --- a/benchmarks/agents/pure_llm.py +++ b/benchmarks/agents/cot/cot.py @@ -2,7 +2,7 @@ from litellm import completion -class PureLLM: +class CoT: def __init__(self, on_aios: bool = True): self.agent_name = "llm" self.on_aios = on_aios diff --git a/benchmarks/agents/interpreter.py b/benchmarks/agents/interpreter.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/agents/metagpt.py b/benchmarks/agents/metagpt.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/agents/nano_manus/agent.py b/benchmarks/agents/nano_manus/agent.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/agents/owl/agent.py b/benchmarks/agents/owl/agent.py new file mode 100644 index 0000000..14012b1 --- /dev/null +++ b/benchmarks/agents/owl/agent.py @@ -0,0 +1,7 @@ + +class OWLAgent: + def __init__(self): + pass + + def run_gaia(self): + pass diff --git a/benchmarks/agents/owl/role_playing.py b/benchmarks/agents/owl/role_playing.py new file mode 100644 index 0000000..cb2a1a3 --- /dev/null +++ b/benchmarks/agents/owl/role_playing.py @@ -0,0 +1,1079 @@ +from typing import Dict, List, Optional, Tuple + + +from 
camel.agents import ChatAgent +from camel.responses import ChatAgentResponse +from camel.messages.base import BaseMessage +from camel.societies import RolePlaying +from camel.logger import get_logger + + +from copy import deepcopy + +logger = get_logger(__name__) + + +class OwlRolePlaying(RolePlaying): + def __init__(self, **kwargs): + self.user_role_name = kwargs.get("user_role_name", "user") + self.assistant_role_name = kwargs.get("assistant_role_name", "assistant") + + self.output_language = kwargs.get("output_language", None) + + self.user_agent_kwargs: dict = kwargs.get("user_agent_kwargs", {}) + self.assistant_agent_kwargs: dict = kwargs.get("assistant_agent_kwargs", {}) + + self.output_language = kwargs.get("output_language", None) + + super().__init__(**kwargs) + + init_user_sys_msg, init_assistant_sys_msg = self._construct_gaia_sys_msgs() + + self.assistant_agent: ChatAgent + self.user_agent: ChatAgent + self.assistant_sys_msg: Optional[BaseMessage] + self.user_sys_msg: Optional[BaseMessage] + + # self.is_reasoning_task = self._judge_if_reasoning_task(self.task_prompt) + + # if self.is_reasoning_task: + # logger.info("The task is judged as a reasoning or coding task. The assistant agent will use the reasoning model O3-MINI.") + # else: + # logger.info("The assistant agent will use the default model.") + + self._init_agents( + init_assistant_sys_msg, + init_user_sys_msg, + assistant_agent_kwargs=self.assistant_agent_kwargs, + user_agent_kwargs=self.user_agent_kwargs, + output_language=self.output_language, + # is_reasoning_task=self.is_reasoning_task + ) + + def _init_agents( + self, + init_assistant_sys_msg: BaseMessage, + init_user_sys_msg: BaseMessage, + assistant_agent_kwargs: Optional[Dict] = None, + user_agent_kwargs: Optional[Dict] = None, + output_language: Optional[str] = None, + is_reasoning_task: bool = False, + ) -> None: + r"""Initialize assistant and user agents with their system messages. 
+ + Args: + init_assistant_sys_msg (BaseMessage): Assistant agent's initial + system message. + init_user_sys_msg (BaseMessage): User agent's initial system + message. + assistant_agent_kwargs (Dict, optional): Additional arguments to + pass to the assistant agent. (default: :obj:`None`) + user_agent_kwargs (Dict, optional): Additional arguments to + pass to the user agent. (default: :obj:`None`) + output_language (str, optional): The language to be output by the + agents. (default: :obj:`None`) + """ + if self.model is not None: + if assistant_agent_kwargs is None: + assistant_agent_kwargs = {"model": self.model} + elif "model" not in assistant_agent_kwargs: + assistant_agent_kwargs.update(dict(model=self.model)) + if user_agent_kwargs is None: + user_agent_kwargs = {"model": self.model} + elif "model" not in user_agent_kwargs: + user_agent_kwargs.update(dict(model=self.model)) + + # # If the task is a reasoning task, the assistant agent should use the reasoning model O3-MINI + # if is_reasoning_task: + # assistant_agent_kwargs['model'] = ModelFactory.create( + # model_platform=ModelPlatformType.OPENAI, + # model_type=ModelType.O3_MINI, + # ) + + self.assistant_agent = ChatAgent( + init_assistant_sys_msg, + output_language=output_language, + **(assistant_agent_kwargs or {}), + ) + self.assistant_sys_msg = self.assistant_agent.system_message + + self.user_agent = ChatAgent( + init_user_sys_msg, + output_language=output_language, + **(user_agent_kwargs or {}), + ) + self.user_sys_msg = self.user_agent.system_message + + # def _judge_if_reasoning_task(self, question: str) -> bool: + # r"""Judge if the question is a reasoning task.""" + + # LLM = OpenAIModel(model_type=ModelType.O3_MINI) + # prompt = f""" + # Please judge whether the following question is a reasoning or coding task, which can be solved by reasoning without leveraging external resources, or is suitable for writing code to solve the task. 
+ # If it is a reasoning or coding task, please return only "yes". + # If it is not a reasoning or coding task, please return only "no". + # Note: + # - If the question required some world knowledge to answer the question, please carefully judge it, because the model's own knowledge is often unreliable. + # - If it is suitable for writing codes (e.g. process excel files, write simulation codes, etc.), in most cases, it can be considered as a coding task. + # Question: {question} + # """ + # messages = [{"role": "user", "content": prompt}] + # resp = LLM.run(messages) + # if 'yes' in resp.choices[0].message.content.lower(): + # return True + # else: + # return False + + def _construct_gaia_sys_msgs(self): + user_system_prompt = f""" +===== RULES OF USER ===== +Never forget you are a user and I am a assistant. Never flip roles! You will always instruct me. We share a common interest in collaborating to successfully complete a task. +I must help you to complete a difficult task. +You must instruct me based on my expertise and your needs to solve the task step by step. The format of your instruction is: `Instruction: [YOUR INSTRUCTION]`, where "Instruction" describes a sub-task or question. +You must give me one instruction at a time. +I must write a response that appropriately solves the requested instruction. +You should instruct me not ask me questions. + +Please note that the task may be very complicated. Do not attempt to solve the task by single step. You must instruct me to find the answer step by step. +Here are some tips that will help you to give more valuable instructions about our task to me: + +- I have various tools to use, such as search toolkit, web browser simulation toolkit, document relevant toolkit, code execution toolkit, etc. Thus, You must think how human will solve the task step-by-step, and give me instructions just like that. 
For example, one may first use google search to get some initial information and the target url, then retrieve the content of the url, or do some web browser interaction to find the answer. +- Although the task is complex, the answer does exist. If you can't find the answer using the current scheme, try to re-plan and use other ways to find the answer, e.g. using other tools or methods that can achieve similar results. +- Always remind me to verify my final answer about the overall task. This work can be done by using multiple tools(e.g., screenshots, webpage analysis, etc.), or something else. +- If I have written code, please remind me to run the code and get the result. +- Search results typically do not provide precise answers. It is not likely to find the answer directly using search toolkit only, the search query should be concise and focuses on finding sources rather than direct answers, as it always need to use other tools to further process the url, e.g. interact with the webpage, extract webpage content, etc. +- If the question mentions youtube video, in most cases you have to process the content of the mentioned video. +- For downloading files, you can either use the web browser simulation toolkit or write codes (for example, the github content can be downloaded via https://raw.githubusercontent.com/...). +- Flexibly write codes to solve some problems, such as excel relevant tasks. + + +Now, here is the overall task: {self.task_prompt}. Never forget our task! + +Now you must start to instruct me to solve the task step-by-step. Do not add anything else other than your instruction! +Keep giving me instructions until you think the task is completed. +When the task is completed, you must only reply with a single word <TASK_DONE>. +Never say <TASK_DONE> unless my responses have solved your task. + """ + + assistant_system_prompt = f""" +===== RULES OF ASSISTANT ===== +Never forget you are a assistant and I am a user. Never flip roles! Never instruct me!
You have to utilize your available tools to solve the task I assigned. +We share a common interest in collaborating to successfully complete a complex task. +You must help me to complete the task. + +Here is our overall task: {self.task_prompt}. Never forget our task! + +I must instruct you based on your expertise and my needs to complete the task. An instruction is typically a sub-task or question. + +You must leverage your available tools, try your best to solve the problem, and explain your solutions. +Unless I say the task is completed, you should always start with: +Solution: [YOUR_SOLUTION] +[YOUR_SOLUTION] should be specific, including detailed explanations and provide preferable detailed implementations and examples and lists for task-solving. + +Please note that our overall task may be very complicated. Here are some tips that may help you solve the task: + +- If one way fails to provide an answer, try other ways or methods. The answer does exists. +- If the search snippet is unhelpful but the URL comes from an authoritative source, try visit the website for more details. +- When looking for specific numerical values (e.g., dollar amounts), prioritize reliable sources and avoid relying only on search snippets. +- When solving tasks that require web searches, check Wikipedia first before exploring other websites. +- When trying to solve math problems, you can try to write python code and use sympy library to solve the problem. +- Always verify the accuracy of your final answers! Try cross-checking the answers by other ways. (e.g., screenshots, webpage analysis, etc.). +- Do not be overly confident in your own knowledge. Searching can provide a broader perspective and help validate existing knowledge. +- After writing codes, do not forget to run the code and get the result. If it encounters an error, try to debug it. Also, bear in mind that the code execution environment does not support interactive input. 
+- When a tool fails to run, or the code does not run correctly, never assume that it returns the correct result and continue to reason based on the assumption, because the assumed result cannot lead you to the correct answer. The right way is to think about the reason for the error and try again. +- Search results typically do not provide precise answers. It is not likely to find the answer directly using search toolkit only, the search query should be concise and focuses on finding sources rather than direct answers, as it always need to use other tools to further process the url, e.g. interact with the webpage, extract webpage content, etc. +- For downloading files, you can either use the web browser simulation toolkit or write codes. + + + """ + + user_sys_msg = BaseMessage.make_user_message( + role_name=self.user_role_name, content=user_system_prompt + ) + + assistant_sys_msg = BaseMessage.make_assistant_message( + role_name=self.assistant_role_name, content=assistant_system_prompt + ) + + return user_sys_msg, assistant_sys_msg + + def step( + self, assistant_msg: BaseMessage + ) -> Tuple[ChatAgentResponse, ChatAgentResponse]: + user_response = self.user_agent.step(assistant_msg) + if user_response.terminated or user_response.msgs is None: + return ( + ChatAgentResponse(msgs=[], terminated=False, info={}), + ChatAgentResponse( + msgs=[], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + user_msg = self._reduce_message_options(user_response.msgs) + + modified_user_msg = deepcopy(user_msg) + + if "TASK_DONE" not in user_msg.content: + modified_user_msg.content += f"""\n + Here are auxiliary information about the overall task, which may help you understand the intent of the current task: + + {self.task_prompt} + + If there are available tools and you want to call them, never say 'I will ...', but first call the tool and reply based on tool call's result, and tell me which tool you have called. 
+ """ + + else: + # The task is done, and the assistant agent need to give the final answer about the original task + modified_user_msg.content += f"""\n + Now please make a final answer of the original task based on our conversation : {self.task_prompt} + """ + + # process assistant's response + assistant_response = self.assistant_agent.step(modified_user_msg) + if assistant_response.terminated or assistant_response.msgs is None: + return ( + ChatAgentResponse( + msgs=[], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], terminated=False, info=user_response.info + ), + ) + assistant_msg = self._reduce_message_options(assistant_response.msgs) + + modified_assistant_msg = deepcopy(assistant_msg) + if "TASK_DONE" not in user_msg.content: + modified_assistant_msg.content += f"""\n + Provide me with the next instruction and input (if needed) based on my response and our current task: {self.task_prompt} + Before producing the final answer, please check whether I have rechecked the final answer using different toolkit as much as possible. If not, please remind me to do that. + If I have written codes, remind me to run the codes. + If you think our task is done, reply with `TASK_DONE` to end our conversation. 
+ """ + + # return the modified messages + return ( + ChatAgentResponse( + msgs=[modified_assistant_msg], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[modified_user_msg], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + + +class OwlGAIARolePlaying(OwlRolePlaying): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def step( + self, assistant_msg: BaseMessage + ) -> Tuple[ChatAgentResponse, ChatAgentResponse]: + user_response = self.user_agent.step(assistant_msg) + if user_response.terminated or user_response.msgs is None: + return ( + ChatAgentResponse(msgs=[], terminated=False, info={}), + ChatAgentResponse( + msgs=[], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + user_msg = self._reduce_message_options(user_response.msgs) + + modified_user_msg = deepcopy(user_msg) + + if "TASK_DONE" not in user_msg.content: + modified_user_msg.content += f"""\n + Here are auxiliary information about the overall task, which may help you understand the intent of the current task: + + {self.task_prompt} + + If there are available tools and you want to call them, never say 'I will ...', but first call the tool and reply based on tool call's result, and tell me which tool you have called. + """ + + else: + # The task is done, and the assistant agent need to give the final answer about the original task + modified_user_msg.content += f"""\n + Now please make a final answer of the original task based on our conversation : {self.task_prompt} + Please pay special attention to the format in which the answer is presented. + You should first analyze the answer format required by the question and then output the final answer that meets the format requirements. + Your response should include the following content: + - `analysis`: enclosed by <analysis> </analysis>, a detailed analysis of the reasoning result. + - `final_answer`: enclosed by <final_answer> </final_answer>, the final answer to the question.
+ Here are some hint about the final answer: + + Your final answer must be output exactly in the format specified by the question. It should be a number OR as few words as possible OR a comma separated list of numbers and/or strings: + - If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. + - If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. + - If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. + + """ + + # process assistant's response + assistant_response = self.assistant_agent.step(modified_user_msg) + if assistant_response.terminated or assistant_response.msgs is None: + return ( + ChatAgentResponse( + msgs=[], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], terminated=False, info=user_response.info + ), + ) + assistant_msg = self._reduce_message_options(assistant_response.msgs) + + modified_assistant_msg = deepcopy(assistant_msg) + if "TASK_DONE" not in user_msg.content: + modified_assistant_msg.content += f"""\n + Provide me with the next instruction and input (if needed) based on my response and our current task: {self.task_prompt} + Before producing the final answer, please check whether I have rechecked the final answer using different toolkit as much as possible. If not, please remind me to do that. + If I have written codes, remind me to run the codes. + If you think our task is done, reply with `TASK_DONE` to end our conversation. 
+ """ + + # return the modified messages + return ( + ChatAgentResponse( + msgs=[modified_assistant_msg], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[modified_user_msg], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + + +def run_society( + society: OwlRolePlaying, + round_limit: int = 15, +) -> Tuple[str, List[dict], dict]: + overall_completion_token_count = 0 + overall_prompt_token_count = 0 + + chat_history = [] + init_prompt = """ + Now please give me instructions to solve over overall task step by step. If the task requires some specific knowledge, please instruct me to use tools to complete the task. + """ + input_msg = society.init_chat(init_prompt) + for _round in range(round_limit): + assistant_response, user_response = society.step(input_msg) + # Check if usage info is available before accessing it + if assistant_response.info.get("usage") and user_response.info.get("usage"): + overall_completion_token_count += assistant_response.info["usage"].get( + "completion_tokens", 0 + ) + user_response.info["usage"].get("completion_tokens", 0) + overall_prompt_token_count += assistant_response.info["usage"].get( + "prompt_tokens", 0 + ) + user_response.info["usage"].get("prompt_tokens", 0) + + # convert tool call to dict + tool_call_records: List[dict] = [] + if assistant_response.info.get("tool_calls"): + for tool_call in assistant_response.info["tool_calls"]: + tool_call_records.append(tool_call.as_dict()) + + _data = { + "user": user_response.msg.content + if hasattr(user_response, "msg") and user_response.msg + else "", + "assistant": assistant_response.msg.content + if hasattr(assistant_response, "msg") and assistant_response.msg + else "", + "tool_calls": tool_call_records, + } + + chat_history.append(_data) + logger.info( + f"Round #{_round} user_response:\n {user_response.msgs[0].content if user_response.msgs and len(user_response.msgs) > 0 else ''}" + ) + 
logger.info( + f"Round #{_round} assistant_response:\n {assistant_response.msgs[0].content if assistant_response.msgs and len(assistant_response.msgs) > 0 else ''}" + ) + + if ( + assistant_response.terminated + or user_response.terminated + or "TASK_DONE" in user_response.msg.content + ): + break + + input_msg = assistant_response.msg + + answer = chat_history[-1]["assistant"] + token_info = { + "completion_token_count": overall_completion_token_count, + "prompt_token_count": overall_prompt_token_count, + } + + return answer, chat_history, token_info + + +import logging +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from camel.agents import ( + ChatAgent, + CriticAgent, + TaskPlannerAgent, + TaskSpecifyAgent, +) +from camel.generators import SystemMessageGenerator +from camel.human import Human +from camel.messages import BaseMessage +from camel.models import BaseModelBackend +from camel.prompts import TextPrompt +from camel.responses import ChatAgentResponse +from camel.types import RoleType, TaskType + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class RolePlaying: + r"""Role playing between two agents. + + Args: + assistant_role_name (str): The name of the role played by the + assistant. + user_role_name (str): The name of the role played by the user. + critic_role_name (str, optional): The name of the role played by the + critic. Role name with :obj:`"human"` will set critic as a + :obj:`Human` agent, else will create a :obj:`CriticAgent`. + (default: :obj:`"critic"`) + task_prompt (str, optional): A prompt for the task to be performed. + (default: :obj:`""`) + with_task_specify (bool, optional): Whether to use a task specify + agent. (default: :obj:`True`) + with_task_planner (bool, optional): Whether to use a task planner + agent. (default: :obj:`False`) + with_critic_in_the_loop (bool, optional): Whether to include a critic + in the loop. 
(default: :obj:`False`) + critic_criteria (str, optional): Critic criteria for the critic agent. + If not specified, set the criteria to improve task performance. + model (BaseModelBackend, optional): The model backend to use for + generating responses. If specified, it will override the model in + all agents if not specified in agent-specific kwargs. (default: + :obj:`OpenAIModel` with `GPT_4O_MINI`) + task_type (TaskType, optional): The type of task to perform. + (default: :obj:`TaskType.AI_SOCIETY`) + assistant_agent_kwargs (Dict, optional): Additional arguments to pass + to the assistant agent. (default: :obj:`None`) + user_agent_kwargs (Dict, optional): Additional arguments to pass to + the user agent. (default: :obj:`None`) + task_specify_agent_kwargs (Dict, optional): Additional arguments to + pass to the task specify agent. (default: :obj:`None`) + task_planner_agent_kwargs (Dict, optional): Additional arguments to + pass to the task planner agent. (default: :obj:`None`) + critic_kwargs (Dict, optional): Additional arguments to pass to the + critic. (default: :obj:`None`) + sys_msg_generator_kwargs (Dict, optional): Additional arguments to + pass to the system message generator. (default: :obj:`None`) + extend_sys_msg_meta_dicts (List[Dict], optional): A list of dicts to + extend the system message meta dicts with. (default: :obj:`None`) + extend_task_specify_meta_dict (Dict, optional): A dict to extend the + task specify meta dict with. (default: :obj:`None`) + output_language (str, optional): The language to be output by the + agents. 
(default: :obj:`None`) + """ + + def __init__( + self, + assistant_role_name: str, + user_role_name: str, + *, + critic_role_name: str = "critic", + task_prompt: str = "", + with_task_specify: bool = True, + with_task_planner: bool = False, + with_critic_in_the_loop: bool = False, + critic_criteria: Optional[str] = None, + model: Optional[BaseModelBackend] = None, + task_type: TaskType = TaskType.AI_SOCIETY, + assistant_agent_kwargs: Optional[Dict] = None, + user_agent_kwargs: Optional[Dict] = None, + task_specify_agent_kwargs: Optional[Dict] = None, + task_planner_agent_kwargs: Optional[Dict] = None, + critic_kwargs: Optional[Dict] = None, + sys_msg_generator_kwargs: Optional[Dict] = None, + extend_sys_msg_meta_dicts: Optional[List[Dict]] = None, + extend_task_specify_meta_dict: Optional[Dict] = None, + output_language: Optional[str] = None, + ) -> None: + if model is not None: + logger.warning( + "Model provided globally is set for all agents if not" + " already specified in agent_kwargs." 
+ ) + + self.with_task_specify = with_task_specify + self.with_task_planner = with_task_planner + self.with_critic_in_the_loop = with_critic_in_the_loop + self.model = model + self.task_type = task_type + self.task_prompt = task_prompt + + self.specified_task_prompt: Optional[TextPrompt] = None + self._init_specified_task_prompt( + assistant_role_name, + user_role_name, + task_specify_agent_kwargs=task_specify_agent_kwargs, + extend_task_specify_meta_dict=extend_task_specify_meta_dict, + output_language=output_language, + ) + + self.planned_task_prompt: Optional[TextPrompt] = None + self._init_planned_task_prompt( + task_planner_agent_kwargs=task_planner_agent_kwargs, + output_language=output_language, + ) + + sys_msg_generator = SystemMessageGenerator( + task_type=self.task_type, + **(sys_msg_generator_kwargs or {}), + ) + + ( + init_assistant_sys_msg, + init_user_sys_msg, + sys_msg_meta_dicts, + ) = self._get_sys_message_info( + assistant_role_name, + user_role_name, + sys_msg_generator, + extend_sys_msg_meta_dicts=extend_sys_msg_meta_dicts, + ) + + self.assistant_agent: ChatAgent + self.user_agent: ChatAgent + self.assistant_sys_msg: Optional[BaseMessage] + self.user_sys_msg: Optional[BaseMessage] + self._init_agents( + init_assistant_sys_msg, + init_user_sys_msg, + assistant_agent_kwargs=assistant_agent_kwargs, + user_agent_kwargs=user_agent_kwargs, + output_language=output_language, + ) + self.critic: Optional[Union[CriticAgent, Human]] = None + self.critic_sys_msg: Optional[BaseMessage] = None + self._init_critic( + sys_msg_generator, + sys_msg_meta_dicts, + critic_role_name, + critic_criteria=critic_criteria, + critic_kwargs=critic_kwargs, + ) + + def _init_specified_task_prompt( + self, + assistant_role_name: str, + user_role_name: str, + task_specify_agent_kwargs: Optional[Dict] = None, + extend_task_specify_meta_dict: Optional[Dict] = None, + output_language: Optional[str] = None, + ) -> None: + r"""Use a task specify agent to generate a specified task 
prompt. + Generated specified task prompt will be used to replace original + task prompt. If there is no task specify agent, specified task + prompt will not be generated. + + Args: + assistant_role_name (str): The name of the role played by the + assistant. + user_role_name (str): The name of the role played by the user. + task_specify_agent_kwargs (Dict, optional): Additional arguments + to pass to the task specify agent. (default: :obj:`None`) + extend_task_specify_meta_dict (Dict, optional): A dict to extend + the task specify meta dict with. (default: :obj:`None`) + output_language (str, optional): The language to be output by the + agents. (default: :obj:`None`) + """ + if self.with_task_specify: + task_specify_meta_dict = dict() + if self.task_type in [TaskType.AI_SOCIETY, TaskType.MISALIGNMENT]: + task_specify_meta_dict.update( + dict( + assistant_role=assistant_role_name, + user_role=user_role_name, + ) + ) + task_specify_meta_dict.update(extend_task_specify_meta_dict or {}) + if self.model is not None: + if task_specify_agent_kwargs is None: + task_specify_agent_kwargs = {'model': self.model} + elif 'model' not in task_specify_agent_kwargs: + task_specify_agent_kwargs.update(dict(model=self.model)) + task_specify_agent = TaskSpecifyAgent( + task_type=self.task_type, + output_language=output_language, + **(task_specify_agent_kwargs or {}), + ) + self.specified_task_prompt = task_specify_agent.run( + self.task_prompt, + meta_dict=task_specify_meta_dict, + ) + self.task_prompt = self.specified_task_prompt + + def _init_planned_task_prompt( + self, + task_planner_agent_kwargs: Optional[Dict] = None, + output_language: Optional[str] = None, + ) -> None: + r"""Use a task plan agent to append a planned task prompt to task + prompt. The planned task prompt is generated based on the task + prompt, which can be original task prompt or specified task prompt + if available. If there is no task plan agent, planned task prompt + will not be generated. 
+ + Args: + task_planner_agent_kwargs (Dict, optional): Additional arguments + to pass to the task planner agent. (default: :obj:`None`) + output_language (str, optional): The language to be output by the + agents. (default: :obj:`None`) + """ + if self.with_task_planner: + if self.model is not None: + if task_planner_agent_kwargs is None: + task_planner_agent_kwargs = {'model': self.model} + elif 'model' not in task_planner_agent_kwargs: + task_planner_agent_kwargs.update(dict(model=self.model)) + task_planner_agent = TaskPlannerAgent( + output_language=output_language, + **(task_planner_agent_kwargs or {}), + ) + self.planned_task_prompt = task_planner_agent.run(self.task_prompt) + self.task_prompt = ( + f"{self.task_prompt}\n" f"{self.planned_task_prompt}" + ) + else: + self.planned_task_prompt = None + + def _get_sys_message_info( + self, + assistant_role_name: str, + user_role_name: str, + sys_msg_generator: SystemMessageGenerator, + extend_sys_msg_meta_dicts: Optional[List[Dict]] = None, + ) -> Tuple[BaseMessage, BaseMessage, List[Dict]]: + r"""Get initial assistant and user system message with a list of + system message meta dicts. + + Args: + assistant_role_name (str): The name of the role played by the + assistant. + user_role_name (str): The name of the role played by the user. + sys_msg_generator (SystemMessageGenerator): A system message + generator for agents. + extend_sys_msg_meta_dicts (List[Dict], optional): A list of dicts + to extend the system message meta dicts with. + (default: :obj:`None`) + + Returns: + Tuple[BaseMessage, BaseMessage, List[Dict]]: A tuple containing a + `BaseMessage` representing the assistant's initial system + message, a `BaseMessage` representing the user's initial system + message, and a list of system message meta dicts. 
+ """ + sys_msg_meta_dicts = [dict(task=self.task_prompt) for _ in range(2)] + if extend_sys_msg_meta_dicts is None and self.task_type in [ + TaskType.AI_SOCIETY, + TaskType.MISALIGNMENT, + ]: + extend_sys_msg_meta_dicts = [ + dict( + assistant_role=assistant_role_name, + user_role=user_role_name, + ) + for _ in range(2) + ] + + if extend_sys_msg_meta_dicts is not None: + sys_msg_meta_dicts = [ + {**sys_msg_meta_dict, **extend_sys_msg_meta_dict} + for sys_msg_meta_dict, extend_sys_msg_meta_dict in zip( + sys_msg_meta_dicts, extend_sys_msg_meta_dicts + ) + ] + + init_assistant_sys_msg, init_user_sys_msg = ( + sys_msg_generator.from_dicts( + meta_dicts=sys_msg_meta_dicts, + role_tuples=[ + (assistant_role_name, RoleType.ASSISTANT), + (user_role_name, RoleType.USER), + ], + ) + ) + return init_assistant_sys_msg, init_user_sys_msg, sys_msg_meta_dicts + + def _init_agents( + self, + init_assistant_sys_msg: BaseMessage, + init_user_sys_msg: BaseMessage, + assistant_agent_kwargs: Optional[Dict] = None, + user_agent_kwargs: Optional[Dict] = None, + output_language: Optional[str] = None, + ) -> None: + r"""Initialize assistant and user agents with their system messages. + + Args: + init_assistant_sys_msg (BaseMessage): Assistant agent's initial + system message. + init_user_sys_msg (BaseMessage): User agent's initial system + message. + assistant_agent_kwargs (Dict, optional): Additional arguments to + pass to the assistant agent. (default: :obj:`None`) + user_agent_kwargs (Dict, optional): Additional arguments to + pass to the user agent. (default: :obj:`None`) + output_language (str, optional): The language to be output by the + agents. 
(default: :obj:`None`) + """ + if self.model is not None: + if assistant_agent_kwargs is None: + assistant_agent_kwargs = {'model': self.model} + elif 'model' not in assistant_agent_kwargs: + assistant_agent_kwargs.update(dict(model=self.model)) + if user_agent_kwargs is None: + user_agent_kwargs = {'model': self.model} + elif 'model' not in user_agent_kwargs: + user_agent_kwargs.update(dict(model=self.model)) + + self.assistant_agent = ChatAgent( + init_assistant_sys_msg, + output_language=output_language, + **(assistant_agent_kwargs or {}), + ) + self.assistant_sys_msg = self.assistant_agent.system_message + + self.user_agent = ChatAgent( + init_user_sys_msg, + output_language=output_language, + **(user_agent_kwargs or {}), + ) + self.user_sys_msg = self.user_agent.system_message + + def _init_critic( + self, + sys_msg_generator: SystemMessageGenerator, + sys_msg_meta_dicts: List[Dict], + critic_role_name: str, + critic_criteria: Optional[str] = None, + critic_kwargs: Optional[Dict] = None, + ) -> None: + r"""Initialize critic agent. If critic role name is :obj:`"human"`, + create a :obj:`Human` critic agent. Else, create a :obj:`CriticAgent` + critic agent with specified critic criteria. If the critic criteria + is not specified, set it to improve task performance. + + Args: + sys_msg_generator (SystemMessageGenerator): A system message + generator for agents. + sys_msg_meta_dicts (list): A list of system message meta dicts. + critic_role_name (str): The name of the role played by the critic. + critic_criteria (str, optional): Critic criteria for the + critic agent. If not specified, set the criteria to + improve task performance. (default: :obj:`None`) + critic_kwargs (Dict, optional): Additional arguments to + pass to the critic. 
(default: :obj:`None`) + """ + if self.with_critic_in_the_loop: + if critic_role_name.lower() == "human": + self.critic = Human(**(critic_kwargs or {})) + else: + critic_criteria = ( + critic_criteria or "improving the task performance" + ) + critic_msg_meta_dict = dict( + critic_role=critic_role_name, + criteria=critic_criteria, + **sys_msg_meta_dicts[0], + ) + self.critic_sys_msg = sys_msg_generator.from_dict( + critic_msg_meta_dict, + role_tuple=(critic_role_name, RoleType.CRITIC), + ) + if self.model is not None: + if critic_kwargs is None: + critic_kwargs = {'model': self.model} + elif 'model' not in critic_kwargs: + critic_kwargs.update(dict(model=self.model)) + self.critic = CriticAgent( + self.critic_sys_msg, + **(critic_kwargs or {}), + ) + + def _reduce_message_options( + self, + messages: Sequence[BaseMessage], + ) -> BaseMessage: + r"""Processes a sequence of chat messages, returning the processed + message. If multiple messages are provided and + `with_critic_in_the_loop` is `False`, raises a `ValueError`. + If no messages are provided, a `ValueError` will be raised. + + Args: + messages (Sequence[BaseMessage]): A sequence of `BaseMessage` + objects to process. + + Returns: + BaseMessage: A single `BaseMessage` representing the processed + message. + """ + if len(messages) == 0: + raise ValueError("No messages to process.") + if len(messages) > 1 and not self.with_critic_in_the_loop: + raise ValueError( + "Got than one message to process. " + f"Num of messages: {len(messages)}." + ) + elif self.with_critic_in_the_loop and self.critic is not None: + critic_response = self.critic.reduce_step(messages) + processed_msg = critic_response.msg + else: + processed_msg = messages[0] + + return processed_msg + + def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage: + r"""Initializes the chat by resetting both of the assistant and user + agents. Returns an initial message for the role-playing session. 
+ + Args: + init_msg_content (str, optional): A user-specified initial message. + Will be sent to the role-playing session as the initial + message. (default: :obj:`None`) + + Returns: + BaseMessage: A single `BaseMessage` representing the initial + message. + """ + self.assistant_agent.reset() + self.user_agent.reset() + default_init_msg_content = ( + "Now start to give me instructions one by one. " + "Only reply with Instruction and Input." + ) + if init_msg_content is None: + init_msg_content = default_init_msg_content + + # Initialize a message sent by the assistant + init_msg = BaseMessage.make_assistant_message( + role_name=getattr(self.assistant_sys_msg, 'role_name', None) + or "assistant", + content=init_msg_content, + ) + + return init_msg + + async def ainit_chat( + self, init_msg_content: Optional[str] = None + ) -> BaseMessage: + r"""Asynchronously initializes the chat by resetting both of the + assistant and user agents. Returns an initial message for the + role-playing session. + + Args: + init_msg_content (str, optional): A user-specified initial message. + Will be sent to the role-playing session as the initial + message. (default: :obj:`None`) + + Returns: + BaseMessage: A single `BaseMessage` representing the initial + message. + """ + # Currently, reset() is synchronous, but if it becomes async in the + # future, we can await it here + self.assistant_agent.reset() + self.user_agent.reset() + default_init_msg_content = ( + "Now start to give me instructions one by one. " + "Only reply with Instruction and Input." 
+ ) + if init_msg_content is None: + init_msg_content = default_init_msg_content + + # Initialize a message sent by the assistant + init_msg = BaseMessage.make_assistant_message( + role_name=getattr(self.assistant_sys_msg, 'role_name', None) + or "assistant", + content=init_msg_content, + ) + + return init_msg + + def step( + self, + assistant_msg: BaseMessage, + ) -> Tuple[ChatAgentResponse, ChatAgentResponse]: + r"""Advances the conversation by taking a message from the assistant, + processing it using the user agent, and then processing the resulting + message using the assistant agent. Returns a tuple containing the + resulting assistant message, whether the assistant agent terminated + the conversation, and any additional assistant information, as well as + a tuple containing the resulting user message, whether the user agent + terminated the conversation, and any additional user information. + + Args: + assistant_msg: A `BaseMessage` representing the message from the + assistant. + + Returns: + Tuple[ChatAgentResponse, ChatAgentResponse]: A tuple containing two + ChatAgentResponse: the first struct contains the resulting + assistant message, whether the assistant agent terminated the + conversation, and any additional assistant information; the + second struct contains the resulting user message, whether the + user agent terminated the conversation, and any additional user + information. + """ + user_response = self.user_agent.step(assistant_msg) + if user_response.terminated or user_response.msgs is None: + return ( + ChatAgentResponse(msgs=[], terminated=False, info={}), + ChatAgentResponse( + msgs=[], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + user_msg = self._reduce_message_options(user_response.msgs) + + # To prevent recording the same memory more than once (once in chat + # step and once in role play), and the model generates only one + # response when multi-response support is enabled. 
+ if ( + 'n' in self.user_agent.model_backend.model_config_dict.keys() + and self.user_agent.model_backend.model_config_dict['n'] > 1 + ): + self.user_agent.record_message(user_msg) + + assistant_response = self.assistant_agent.step(user_msg) + if assistant_response.terminated or assistant_response.msgs is None: + return ( + ChatAgentResponse( + msgs=[], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], terminated=False, info=user_response.info + ), + ) + assistant_msg = self._reduce_message_options(assistant_response.msgs) + + # To prevent recording the same memory more than once (once in chat + # step and once in role play), and the model generates only one + # response when multi-response support is enabled. + if ( + 'n' in self.assistant_agent.model_backend.model_config_dict.keys() + and self.assistant_agent.model_backend.model_config_dict['n'] > 1 + ): + self.assistant_agent.record_message(assistant_msg) + + return ( + ChatAgentResponse( + msgs=[assistant_msg], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + + async def astep( + self, + assistant_msg: BaseMessage, + ) -> Tuple[ChatAgentResponse, ChatAgentResponse]: + r"""Asynchronously advances the conversation by taking a message from + the assistant, processing it using the user agent, and then processing + the resulting message using the assistant agent. Returns a tuple + containing the resulting assistant message, whether the assistant + agent terminated the conversation, and any additional assistant + information, as well as a tuple containing the resulting user message, + whether the user agent terminated the conversation, and any additional + user information. + + Args: + assistant_msg: A `BaseMessage` representing the message from the + assistant. 
+ + Returns: + Tuple[ChatAgentResponse, ChatAgentResponse]: A tuple containing two + ChatAgentResponse: the first struct contains the resulting + assistant message, whether the assistant agent terminated the + conversation, and any additional assistant information; the + second struct contains the resulting user message, whether the + user agent terminated the conversation, and any additional user + information. + """ + user_response = await self.user_agent.astep(assistant_msg) + if user_response.terminated or user_response.msgs is None: + return ( + ChatAgentResponse(msgs=[], terminated=False, info={}), + ChatAgentResponse( + msgs=[], + terminated=user_response.terminated, + info=user_response.info, + ), + ) + user_msg = self._reduce_message_options(user_response.msgs) + + # To prevent recording the same memory more than once (once in chat + # step and once in role play), and the model generates only one + # response when multi-response support is enabled. + if ( + 'n' in self.user_agent.model_backend.model_config_dict.keys() + and self.user_agent.model_backend.model_config_dict['n'] > 1 + ): + self.user_agent.record_message(user_msg) + + assistant_response = await self.assistant_agent.astep(user_msg) + if assistant_response.terminated or assistant_response.msgs is None: + return ( + ChatAgentResponse( + msgs=[], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], terminated=False, info=user_response.info + ), + ) + assistant_msg = self._reduce_message_options(assistant_response.msgs) + + # To prevent recording the same memory more than once (once in chat + # step and once in role play), and the model generates only one + # response when multi-response support is enabled. 
+ if ( + 'n' in self.assistant_agent.model_backend.model_config_dict.keys() + and self.assistant_agent.model_backend.model_config_dict['n'] > 1 + ): + self.assistant_agent.record_message(assistant_msg) + + return ( + ChatAgentResponse( + msgs=[assistant_msg], + terminated=assistant_response.terminated, + info=assistant_response.info, + ), + ChatAgentResponse( + msgs=[user_msg], + terminated=user_response.terminated, + info=user_response.info, + ), + ) \ No newline at end of file diff --git a/benchmarks/experiment_core.py b/benchmarks/experiment_core.py index 9b8a373..568ba60 100644 --- a/benchmarks/experiment_core.py +++ b/benchmarks/experiment_core.py @@ -1,14 +1,14 @@ from typing import Any, Callable -from pydantic.v1 import BaseModel +from pydantic import BaseModel from tqdm import tqdm -from .agents.pure_llm import PureLLM +from .agents.react import ReActAgent AGENT_TYPE_MAPPING_AIOS = { - "swe:llm": PureLLM, - "humaneval:llm": PureLLM, - "gaia:llm": PureLLM, + "swe:react": ReActAgent, + "humaneval:react": ReActAgent, + "gaia:react": ReActAgent, } diff --git a/benchmarks/gaia/inference.py b/benchmarks/gaia/inference.py index 46656be..cdc0de3 100644 --- a/benchmarks/gaia/inference.py +++ b/benchmarks/gaia/inference.py @@ -10,12 +10,14 @@ def write_output_func(result_list: List, output_file: str): with open(output_file, "w", encoding="utf-8") as file: json.dump(result_list, file, ensure_ascii=False, indent=4) - logger.log(f"Write results num: {len(result_list)}", level="info") + # logger.log(f"Write results num: {len(result_list)}", level="info") def process_one_func(data, meta_data: MetaData): - agent: ExperimentAgent = AGENT_TYPE_MAPPING_AIOS[meta_data.agent_type](meta_data.on_aios) - result = agent.run_gaia(data["Question"]) + agent = AGENT_TYPE_MAPPING_AIOS[meta_data.agent_type](meta_data.on_aios) + + # breakpoint() + result = agent.run_gaia(**data) match = re.search(r'FINAL ANSWER: (.+)', result) if match: @@ -43,7 +45,7 @@ def process_one_func(data, 
meta_data: MetaData): dataset=dataset, agent_type=agent_type, output_file=main_args.output_file, - on_aios=main_args.on_aios, + on_aios=main_args.on_aios # max_num=main_args.max_num, # aios_args=vars(global_args), ) diff --git a/benchmarks/gaia/run_evaluation.py b/benchmarks/gaia/run_evaluation.py index db68277..c92e1b7 100644 --- a/benchmarks/gaia/run_evaluation.py +++ b/benchmarks/gaia/run_evaluation.py @@ -7,28 +7,32 @@ def run_evaluation(input_file: str, output_file: str, data_name: str, split: str): dataset = load_dataset(data_name, "2023_all", split=split) - with open(input_file, "r", encoding="utf-8") as file: - predictions = json.load(file) - - right_num = 0 - error_predictions = [] - for prediction, data in tqdm(zip(predictions, dataset)): - if prediction["result"] == data["Final answer"]: - right_num += 1 - else: - error_predictions.append({ - "task_id": data["task_id"], - "error_answer": prediction["result"], - "right_answer": data["Final answer"], - }) - - with open(output_file, "w", encoding="utf-8") as file: - json.dump(error_predictions, file, ensure_ascii=False, indent=4) - - print(f"Total num: {len(predictions)} \n" - f" Right num: {right_num} \n" - f" Right Rate: {right_num/len(predictions)}" - , level="info") + # with open(input_file, "r", encoding="utf-8") as file: + # predictions = json.load(file) + + # right_num = 0 + # error_predictions = [] + # for prediction, data in tqdm(zip(predictions, dataset)): + # if prediction["result"] == data["Final answer"]: + # right_num += 1 + # else: + # error_predictions.append({ + # "task_id": data["task_id"], + # "error_answer": prediction["result"], + # "right_answer": data["Final answer"], + # }) + + # with open(output_file, "w", encoding="utf-8") as file: + # json.dump(error_predictions, file, ensure_ascii=False, indent=4) + + for data in tqdm(dataset): + answer = data["Final answer"] + breakpoint() + + # print(f"Total num: {len(predictions)} \n" + # f" Right num: {right_num} \n" + # f" Right Rate: 
{right_num/len(predictions)}" + # , level="info") if __name__ == '__main__': diff --git a/benchmarks/gaia/run_exp.sh b/benchmarks/gaia/run_exp.sh index 0a40057..03e21e8 100644 --- a/benchmarks/gaia/run_exp.sh +++ b/benchmarks/gaia/run_exp.sh @@ -3,9 +3,16 @@ python -m benchmarks.gaia.inference \ --data_name gaia-benchmark/GAIA \ --split validation \ - --output_file benchmarks/gaia/llm_eval_prediction.json \ + --output_file benchmarks/gaia/react_eval_prediction.json \ --on_aios \ - --agent_type llm + --agent_type react + +python -m benchmarks.agents.react \ + --data_name gaia-benchmark/GAIA \ + --split validation \ + --output_file benchmarks/gaia/react_eval_prediction.json \ + --on_aios \ + --agent_type react # Step 2: Run the evaluation script # python -m benchmarks.gaia.inference \ diff --git a/benchmarks/swebench/inference.py b/benchmarks/swebench/inference.py index f36d912..bf5e44e 100644 --- a/benchmarks/swebench/inference.py +++ b/benchmarks/swebench/inference.py @@ -5,7 +5,7 @@ from datasets import load_dataset -from ..agents.pure_llm import PureLLM +from ..agents.react import PureLLM from ..experiment_core import MetaData, AGENT_TYPE_MAPPING_AIOS, run_inference from ..utils import get_parser diff --git a/benchmarks/utils.py b/benchmarks/utils.py index e7451da..dfdae70 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -4,7 +4,7 @@ def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--agent_type", type=str, default="interpreter") parser.add_argument("--data_name", type=str, default="gaia-benchmark/GAIA") - parser.add_argument("--split", type=str, default="test") + parser.add_argument("--split", type=str, default="validation") parser.add_argument("--output_file", type=str, default="prediction.json") parser.add_argument("--on_aios", action="store_true") parser.add_argument("--max_num", type=int, default=None) diff --git a/cerebrum/commands/list_available_llms.py b/cerebrum/commands/list_available_llms.py new file mode 100644 
index 0000000..119559b --- /dev/null +++ b/cerebrum/commands/list_available_llms.py @@ -0,0 +1,58 @@ +from cerebrum.llm.apis import list_available_llms + +from rich.console import Console +from rich.table import Table +from rich.panel import Panel +from rich.text import Text +from rich.box import ROUNDED + +import sys + +def list_agenthub_agents(): + console = Console() + + with console.status("[bold green]Listing available LLMs..."): + llms = list_available_llms() + + if not llms: + console.print(Panel("[bold yellow]No LLMs found", title="LLM List")) + return + + # Create a table with row separators and rounded borders + table = Table( + title="Available Agents in AgentHub", + box=ROUNDED, + show_header=True, + header_style="bold white on blue", + show_lines=True, # This adds horizontal lines between rows + ) + + # Add columns to the table with adjusted widths + table.add_column("Name", style="cyan bold", no_wrap=True) + table.add_column("Backend", style="green", width=40, overflow="fold") + table.add_column("Hostname", style="blue", no_wrap=True) + + # Add rows to the table + for llm in llms: + name = llm.get("name", "N/A") + backend = llm.get("backend", "N/A") + hostname = llm.get("hostname", "N/A") + + table.add_row(name, backend, hostname) + + # Print the table + console.print("\n") # Add some space before the table + console.print(table) + + # Print summary + summary = Text() + summary.append(f"\nTotal LLMs available: ", style="bold") + summary.append(f"{len(llms)}", style="bold green") + console.print(summary) + console.print("\n") # Add some space after the summary + +def main(): + list_agenthub_agents() + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end of file diff --git a/cerebrum/example/agents/academic_agent/agent.py b/cerebrum/example/agents/academic_agent/agent.py index da5400a..4d2228f 100644 --- a/cerebrum/example/agents/academic_agent/agent.py +++ b/cerebrum/example/agents/academic_agent/agent.py @@ -159,7 +159,7 @@ def run(self, 
task_input): else: selected_tools = None - breakpoint() + # breakpoint() if action_type == "call_tool": response = llm_call_tool( diff --git a/cerebrum/example/agents/academic_agent/config.json b/cerebrum/example/agents/academic_agent/config.json index 9eeabc0..53e04d8 100644 --- a/cerebrum/example/agents/academic_agent/config.json +++ b/cerebrum/example/agents/academic_agent/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "1.1.5", + "version": "1.1.7", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/autogen_demo_agent/config.json b/cerebrum/example/agents/autogen_demo_agent/config.json index acd8a76..d9bb712 100644 --- a/cerebrum/example/agents/autogen_demo_agent/config.json +++ b/cerebrum/example/agents/autogen_demo_agent/config.json @@ -4,7 +4,7 @@ "tools": [], "meta": { "author": "autogen", - "version": "0.0.3", + "version": "0.0.4", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/browser_use_agent/agent.py b/cerebrum/example/agents/browser_use_agent/agent.py new file mode 100644 index 0000000..4df98f9 --- /dev/null +++ b/cerebrum/example/agents/browser_use_agent/agent.py @@ -0,0 +1,681 @@ +from typing import List, Dict, Any, Optional, Literal, Tuple, Union +import asyncio +import datetime +import io +import json +import os +import random +import re +import shutil +import time +import urllib.parse + +from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output, llm_chat_with_tool_call_output +from cerebrum.utils import _parse_json_output +from cerebrum.utils.browser import BaseBrowser, _reload_image + +AVAILABLE_ACTIONS_PROMPT = """ +1. `fill_input_id(identifier: Union[str, int], text: str)`: Fill an input +field (e.g. search box) with the given text and press Enter. +2. `click_id(identifier: Union[str, int])`: Click an element with the given ID. +3. `hover_id(identifier: Union[str, int])`: Hover over an element with the +given ID. +4. 
`download_file_id(identifier: Union[str, int])`: Download a file with the +given ID. It returns the path to the downloaded file. If the file is +successfully downloaded, you can stop the simulation and report the path to +the downloaded file for further processing. +5. `scroll_to_bottom()`: Scroll to the bottom of the page. +6. `scroll_to_top()`: Scroll to the top of the page. +7. `scroll_up()`: Scroll up the page. It is suitable when you want to see the +elements above the current viewport. +8. `scroll_down()`: Scroll down the page. It is suitable when you want to see +the elements below the current viewport. If the webpage does not change, It +means that the webpage has scrolled to the bottom. +9. `back()`: Navigate back to the previous page. This is useful when you want +to go back to the previous page, as current page is not useful. +10. `stop()`: Stop the action process, because the task is completed or failed +(impossible to find the answer). In this situation, you should provide your +answer in your output. +11. `get_url()`: Get the current URL of the current page. +12. `find_text_on_page(search_text: str)`: Find the next given text on the +current whole page, and scroll the page to the targeted text. It is equivalent +to pressing Ctrl + F and searching for the text, and is powerful when you want +to fast-check whether the current page contains some specific text. +13. `visit_page(url: str)`: Go to the specific url page. +14. `click_blank_area()`: Click a blank area of the page to unfocus the +current element. It is useful when you have clicked an element but it cannot +unfocus itself (e.g. Menu bar) to automatically render the updated webpage. +15. `ask_question_about_video(question: str)`: Ask a question about the +current webpage which contains video, e.g. youtube websites. 
+""" + +ACTION_WITH_FEEDBACK_LIST = [ + 'ask_question_about_video', + 'download_file_id', + 'find_text_on_page', +] + +class BrowserUseAgent: + r"""A class for browsing the web and interacting with web pages. + + This class provides methods for browsing the web and interacting with web + pages. + """ + + def __init__( + self, + headless: bool = False, + cache_dir: Optional[str] = None, + channel: Literal["chrome", "msedge", "chromium"] = "chromium", + history_window: int = 5, + ): + r"""Initialize the BrowserToolkit instance. + + Args: + headless (bool): Whether to run the browser in headless mode. + cache_dir (Union[str, None]): The directory to store cache files. + channel (Literal["chrome", "msedge", "chromium"]): The browser + channel to use. Must be one of "chrome", "msedge", or + "chromium". + history_window (int): The window size for storing the history of + actions. + """ + self.name = "browser_use_agent" + self.browser = BaseBrowser( + headless=headless, cache_dir=cache_dir, channel=channel + ) + self.history_window = history_window + self.history: list = [] + self.web_agent, self.planning_agent = self._initialize_agent() + + def _reset(self): + self.web_agent.reset() + self.planning_agent.reset() + self.history = [] + os.makedirs(self.browser.cache_dir, exist_ok=True) + + def _initialize_agent(self): + r"""Initialize the agent.""" + class WebAgent: + def __init__(self, system_prompt: str): + self.name = "web_agent" + self.system_prompt = system_prompt + self.messages = [ + {"role": "system", "content": system_prompt} + ] + + def step(self, message, response_format=None, tools=None): + self.messages.append(message) + llms = [ + { + "name": "gpt-4o", + "backend": "openai", + } + ] + if tools: + response = llm_chat_with_tool_call_output( + agent_name=self.name, + messages=self.messages, + llms=llms, + tools=tools + )["response"] + elif response_format: + response = llm_chat_with_json_output( + agent_name=self.name, + messages=self.messages, + llms=llms, + 
response_format=response_format + )["response"] + else: + response = llm_chat( + agent_name=self.name, + messages=self.messages, + llms=llms, + )["response"] + return response + + def reset(self): + self.messages = [ + {"role": "system", "content": self.system_prompt} + ] + + class PlanningAgent: + def __init__(self, system_prompt: str): + self.name = "planning_agent" + self.system_prompt = system_prompt + self.messages = [ + {"role": "system", "content": system_prompt} + ] + + def step(self, message, response_format=None, tools=None): + self.messages.append(message) + llms = [ + { + "name": "gpt-4o", + "backend": "openai", + } + ] + if tools: + response = llm_chat_with_tool_call_output( + agent_name=self.name, + messages=self.messages, + llms=llms, + tools=tools + )["response"] + elif response_format: + response = llm_chat_with_json_output( + agent_name=self.name, + messages=self.messages, + llms=llms, + response_format=response_format + )["response"] + else: + response = llm_chat( + agent_name=self.name, + messages=self.messages, + llms=llms, + )["response"] + return response + + def reset(self): + self.messages = [ + {"role": "system", "content": self.system_prompt} + ] + + system_prompt = """ +You are a helpful web agent that can assist users in browsing the web. +Given a high-level task, you can leverage predefined browser tools to help +users achieve their goals. + """ + + web_agent = WebAgent(system_prompt) + + planning_system_prompt = """ +You are a helpful planning agent that can assist users in planning complex +tasks which need multi-step browser interaction. 
+ """ + + planning_agent = PlanningAgent(planning_system_prompt) + + return web_agent, planning_agent + + def convert_message(self, message, img=None): + import base64 + from io import BytesIO + + if img is not None: + # Convert PIL Image to base64 + buffered = BytesIO() + img.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return { + "role": "user", + "content": [ + { + "type": "text", + "text": message, + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{img_str}" + } + } + ] + } + else: + return { + "role": "user", + "content": message + } + + def _observe( + self, task_prompt: str, detailed_plan: Optional[str] = None + ) -> Tuple[str, str, str]: + r"""Let agent observe the current environment, and get the next action.""" + + detailed_plan_prompt = "" + + if detailed_plan is not None: + detailed_plan_prompt = f""" +Here is a plan about how to solve the task step-by-step which you must follow: +{detailed_plan} + """ + + observe_prompt = f""" +Please act as a web agent to help me complete the following high-level task: +{task_prompt} +Now, I have made screenshot (only the current viewport, not the full webpage) +based on the current browser state, and marked interactive elements in the +webpage. +Please carefully examine the requirements of the task, and current state of +the browser, and provide the next appropriate action to take. + +{detailed_plan_prompt} + +Here are the current available browser functions you can use: +{AVAILABLE_ACTIONS_PROMPT} + +Here are the latest {self.history_window} trajectory (at most) you have taken: + +{self.history[-self.history_window:]} + + +Your output should be in json format, including the following fields: +- `observation`: The detailed image description about the current viewport. Do +not over-confident about the correctness of the history actions. You should +always check the current viewport to make sure the correctness of the next +action. 
+- `reasoning`: The reasoning about the next action you want to take, and the +possible obstacles you may encounter, and how to solve them. Do not forget to +check the history actions to avoid the same mistakes. +- `action_code`: The action code you want to take. It is only one step action +code, without any other texts (such as annotation) + +Here is two example of the output: +```json +{{ + "observation": [IMAGE_DESCRIPTION], + "reasoning": [YOUR_REASONING], + "action_code": "fill_input_id([ID], [TEXT])" +}} + +{{ + "observation": "The current page is a CAPTCHA verification page on Amazon. It asks the user to ..", + "reasoning": "To proceed with the task of searching for products, I need to complete..", + "action_code": "fill_input_id(3, 'AUXPMR')" +}} + +Here are some tips for you: +- Never forget the overall question: **{task_prompt}** +- Maybe after a certain operation (e.g. click_id), the page content has not +changed. You can check whether the action step is successful by looking at the +`success` of the action step in the history. If successful, it means that the +page content is indeed the same after the click. You need to try other methods. +- If using one way to solve the problem is not successful, try other ways. +Make sure your provided ID is correct! +- Some cases are very complex and need to be achieve by an iterative process. +You can use the `back()` function to go back to the previous page to try other +methods. +- There are many links on the page, which may be useful for solving the +problem. You can use the `click_id()` function to click on the link to see if +it is useful. +- Always keep in mind that your action must be based on the ID shown in the +current image or viewport, not the ID shown in the history. +- Do not use `stop()` lightly. Always remind yourself that the image only +shows a part of the full page. 
If you cannot find the answer, try to use +functions like `scroll_up()` and `scroll_down()` to check the full content of +the webpage before doing anything else, because the answer or next key step +may be hidden in the content below. +- If the webpage needs human verification, you must avoid processing it. +Please use `back()` to go back to the previous page, and try other ways. +- If you have tried everything and still cannot resolve the issue, please stop +the simulation, and report issues you have encountered. +- Check the history actions carefully, detect whether you have repeatedly made +the same actions or not. +- When dealing with wikipedia revision history related tasks, you need to +think about the solution flexibly. First, adjust the browsing history +displayed on a single page to the maximum, and then make use of the +find_text_on_page function. This is extremely useful which can quickly locate +the text you want to find and skip massive amount of useless information. +- Flexibly use interactive elements like slide down selection bar to filter +out the information you need. Sometimes they are extremely useful. 
+""" + + som_screenshot, _ = self.browser.get_som_screenshot(save_image=True) + img = _reload_image(som_screenshot) + message = self.convert_message(observe_prompt, img) + self.web_agent.reset() + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "react", + "schema": { + "type": "object", + "properties": { + "observation": {"type": "string"}, + "reasoning": {"type": "string"}, + "action_code": {"type": "string"} + }, + "required": ["observation", "reasoning", "action_code"] + }, + "strict": True + } + } + + response = self.web_agent.step(message, response_format=response_format)["response_message"] + response = _parse_json_output(response) + observation_result: str = response.get("observation", "") + reasoning_result: str = response.get("reasoning", "") + action_code: str = response.get("action_code", "") + + if action_code and "(" in action_code and ")" not in action_code: + action_match = re.search( + r'"action_code"\s*:\s*[`"]([^`"]*\([^)]*\))[`"]', response + ) + if action_match: + action_code = action_match.group(1) + else: + print( + f"Incomplete action_code detected: {action_code}" + ) + if action_code.startswith("fill_input_id("): + parts = action_code.split(",", 1) + if len(parts) > 1: + id_part = ( + parts[0].replace("fill_input_id(", "").strip() + ) + action_code = f"fill_input_id({id_part}, 'Please fill the text here.')" + + action_code = action_code.replace("`", "").strip() + + return observation_result, reasoning_result, action_code + + def _act(self, action_code: str) -> Tuple[bool, str]: + r"""Let agent act based on the given action code. + Args: + action_code (str): The action code to act. + + Returns: + Tuple[bool, str]: A tuple containing a boolean indicating whether + the action was successful, and the information to be returned. 
+ """ + + def _check_if_with_feedback(action_code: str) -> bool: + r"""Check if the action code needs feedback.""" + + for action_with_feedback in ACTION_WITH_FEEDBACK_LIST: + if action_with_feedback in action_code: + return True + + return False + + def _fix_action_code(action_code: str) -> str: + r"""Fix potential missing quotes in action code""" + + match = re.match(r'(\w+)\((.*)\)', action_code) + if not match: + return action_code + + func_name, args_str = match.groups() + + args = [] + current_arg = "" + in_quotes = False + quote_char = None + + for char in args_str: + if char in ['"', "'"]: + if not in_quotes: + in_quotes = True + quote_char = char + current_arg += char + elif char == quote_char: + in_quotes = False + quote_char = None + current_arg += char + else: + current_arg += char + elif char == ',' and not in_quotes: + args.append(current_arg.strip()) + current_arg = "" + else: + current_arg += char + + if current_arg: + args.append(current_arg.strip()) + + fixed_args = [] + for arg in args: + if ( + (arg.startswith('"') and arg.endswith('"')) + or (arg.startswith("'") and arg.endswith("'")) + or re.match(r'^-?\d+(\.\d+)?$', arg) + or re.match(r'^-?\d+\.?\d*[eE][-+]?\d+$', arg) + or re.match(r'^0[xX][0-9a-fA-F]+$', arg) + ): + fixed_args.append(arg) + else: + fixed_args.append(f"'{arg}'") + + return f"{func_name}({', '.join(fixed_args)})" + + action_code = _fix_action_code(action_code) + prefix = "self.browser." + code = f"{prefix}{action_code}" + + try: + if _check_if_with_feedback(action_code): + # execute code, and get the executed result + result = eval(code) + time.sleep(1) + return True, result + else: + exec(code) + time.sleep(1) + return True, "Action was successful." + + except Exception as e: + time.sleep(1) + return ( + False, + f"Error while executing the action {action_code}: {e}. 
" + f"If timeout, please recheck whether you have provided the " + f"correct identifier.", + ) + + def _get_final_answer(self, task_prompt: str) -> str: + r"""Get the final answer based on the task prompt and current browser state. + It is used when the agent thinks that the task can be completed without any further action, and answer can be directly found in the current viewport. + """ + + prompt = f""" +We are solving a complex web task which needs multi-step browser interaction. After the multi-step observation, reasoning and acting with web browser, we think that the task is currently solved. +Here are all trajectory we have taken: +{self.history} +Please find the final answer, or give valuable insights and founds (e.g. if previous actions contain downloading files, your output should include the path of the downloaded file) about the overall task: {task_prompt} + """ + + message = { + "role": "user", + "content": prompt + } + + response = self.web_agent.step(message) + return response + + def _make_reflection(self, task_prompt: str) -> str: + r"""Make a reflection about the current state and the task prompt.""" + + reflection_prompt = f""" +Now we are working on a complex task that requires multi-step browser interaction. The task is: {task_prompt} +To achieve this goal, we have made a series of observations, reasonings, and actions. We have also made a reflection on previous states. + +Here are the global available browser functions we can use: +{AVAILABLE_ACTIONS_PROMPT} + +Here are the latest {self.history_window} trajectory (at most) we have taken: +{self.history[-self.history_window:]} + +The image provided is the current state of the browser, where we have marked interactive elements. +Please carefully examine the requirements of the task, and the current state of the browser, and then make reflections on the previous steps, thinking about whether they are helpful or not, and why, offering detailed feedback and suggestions for the next steps. 
+Your output should be in json format, including the following fields: +- `reflection`: The reflection about the previous steps, thinking about whether they are helpful or not, and why, offering detailed feedback. +- `suggestion`: The suggestion for the next steps, offering detailed suggestions, including the common solutions to the overall task based on the current state of the browser. + """ + som_image, _ = self.browser.get_som_screenshot() + img = _reload_image(som_image) + + message = self.convert_message(reflection_prompt, img) + + response = self.web_agent.step(message) + + return response + + def _task_planning(self, task_prompt: str, start_url: str) -> str: + r"""Plan the task based on the given task prompt.""" + + planning_prompt = f""" +{task_prompt} +According to the problem above, if we use browser interaction, what is the general process of the interaction after visiting the webpage `{start_url}`? + +Please note that it can be viewed as Partially Observable MDP. Do not over-confident about your plan. +Please first restate the task in detail, and then provide a detailed plan to solve the task. +""" + + message = self.convert_message(planning_prompt) + + response = self.planning_agent.step(message) + return response["response_message"] + + def _task_replanning( + self, task_prompt: str, detailed_plan: str + ) -> Tuple[bool, str]: + r"""Replan the task based on the given task prompt. + + Args: + task_prompt (str): The original task prompt. + detailed_plan (str): The detailed plan to replan. + + Returns: + Tuple[bool, str]: A tuple containing a boolean indicating whether the task needs to be replanned, and the replanned schema. + """ + + replanning_prompt = f""" +We are using browser interaction to solve a complex task which needs multi-step actions. +Here are the overall task: +{task_prompt} + +In order to solve the task, we made a detailed plan previously. 
Here is the detailed plan: +{detailed_plan} + +According to the task above, we have made a series of observations, reasonings, and actions. Here are the latest {self.history_window} trajectory (at most) we have taken: +{self.history[-self.history_window:]} + +However, the task is not completed yet. As the task is partially observable, we may need to replan the task based on the current state of the browser if necessary. +Now please carefully examine the current task planning schema, and our history actions, and then judge whether the task needs to be fundamentally replanned. If so, please provide a detailed replanned schema (including the restated overall task). + +Your output should be in json format, including the following fields: +- `if_need_replan`: bool, A boolean value indicating whether the task needs to be fundamentally replanned. +- `replanned_schema`: str, The replanned schema for the task, which should not be changed too much compared with the original one. If the task does not need to be replanned, the value should be an empty string. +""" + self.planning_agent.reset() + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "react", + "schema": { + "type": "object", + "properties": { + "if_need_replan": {"type": "boolean"}, + "replanned_schema": {"type": "string"} + }, + "required": ["if_need_replan", "replanned_schema"] + } + } + } + + response = self.planning_agent.step(replanning_prompt, response_format=response_format) + + resp_dict = _parse_json_output(response) + if_need_replan = resp_dict.get("if_need_replan", False) + replanned_schema = resp_dict.get("replanned_schema", "") + + if if_need_replan: + return True, replanned_schema + else: + return False, replanned_schema + + def run( + self, task_input: str, start_url: str, round_limit: int = 50 + ) -> str: + r"""A powerful toolkit which can simulate the browser interaction to solve the task which needs multi-step actions. 
+ + Args: + task_prompt (str): The task prompt to solve. + start_url (str): The start URL to visit. + round_limit (int): The round limit to solve the task. + (default: :obj:`12`). + + Returns: + str: The simulation result to the task. + """ + + self._reset() + task_completed = False + + detailed_plan = self._task_planning(task_input, start_url) + print(f"Detailed plan: {detailed_plan}") + + self.browser.init() + self.browser.visit_page(start_url) + + rounds = 0 + + while rounds < round_limit: + observation, reasoning, action_code = self._observe( + task_input, detailed_plan + ) + print(f"Observation: {observation}") + print(f"Reasoning: {reasoning}") + print(f"Action code: {action_code}") + + if "stop" in action_code: + task_completed = True + trajectory_info = { + "round": rounds, + "observation": observation, + "thought": reasoning, + "action": action_code, + "action_if_success": True, + "info": None, + "current_url": self.browser.get_url(), + } + self.history.append(trajectory_info) + break + + else: + success, info = self._act(action_code) + if not success: + print(f"Error while executing the action: {info}") + + trajectory_info = { + "round": rounds, + "observation": observation, + "thought": reasoning, + "action": action_code, + "action_if_success": success, + "info": info, + "current_url": self.browser.get_url(), + } + self.history.append(trajectory_info) + + rounds += 1 + + if not task_completed: + simulation_result = f""" + The task is not completed within the round limit. Please check the last round {self.history_window} information to see if there is any useful information: + {self.history[-self.history_window:]} + """ + + else: + simulation_result = self._get_final_answer(task_input) + + self.browser.close() + return { + "agent_name": self.name, + "rounds": rounds, + "result": simulation_result + } + + +def main(): + agent = BrowserUseAgent() + task_input = "What is the densest material on the moon?" 
+ start_url = "https://www.wikipedia.org" + agent.run(task_input, start_url) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/cerebrum/example/agents/browser_use_agent/config.json b/cerebrum/example/agents/browser_use_agent/config.json new file mode 100644 index 0000000..a200456 --- /dev/null +++ b/cerebrum/example/agents/browser_use_agent/config.json @@ -0,0 +1,17 @@ +{ + "name": "browser_use_agent", + "description": [ + "You are a browser use agent. You can automate the browser to obtain information. " + ], + "tools": [ + ], + "meta": { + "author": "example", + "version": "0.0.1", + "license": "CC0" + }, + "build": { + "entry": "agent.py", + "module": "BrowserUseAgent" + } +} diff --git a/cerebrum/example/agents/browser_use_agent/meta_requirements.txt b/cerebrum/example/agents/browser_use_agent/meta_requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/cerebrum/example/agents/browser_use_agent/page_script.js b/cerebrum/example/agents/browser_use_agent/page_script.js new file mode 100644 index 0000000..8318dae --- /dev/null +++ b/cerebrum/example/agents/browser_use_agent/page_script.js @@ -0,0 +1,376 @@ +var MultimodalWebSurfer = MultimodalWebSurfer || (function() { + let nextLabel = 10; + + let roleMapping = { + "a": "link", + "area": "link", + "button": "button", + "input, type=button": "button", + "input, type=checkbox": "checkbox", + "input, type=email": "textbox", + "input, type=number": "spinbutton", + "input, type=radio": "radio", + "input, type=range": "slider", + "input, type=reset": "button", + "input, type=search": "searchbox", + "input, type=submit": "button", + "input, type=tel": "textbox", + "input, type=text": "textbox", + "input, type=url": "textbox", + "search": "search", + "select": "combobox", + "option": "option", + "textarea": "textbox" + }; + + let getCursor = function(elm) { + return window.getComputedStyle(elm)["cursor"]; + }; + + let getInteractiveElements = function() { + + let results = 
[] + let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"]; + let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"]; + + // Get the main interactive elements + let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])"); + for (let i=0; i -1) { + results.push(nodeList[i]); + } + } + } + + // Any element that changes the cursor to something implying interactivity + nodeList = document.querySelectorAll("*"); + for (let i=0; i= 0) { + continue; + } + + // Move up to the first instance of this cursor change + parent = node.parentNode; + while (parent && getCursor(parent) == cursor) { + node = parent; + parent = node.parentNode; + } + + // Add the node if it is new + if (results.indexOf(node) == -1) { + results.push(node); + } + } + + return results; + }; + + let labelElements = function(elements) { + for (let i=0; i= 1; + + let record = { + "tag_name": ariaRole[1], + "role": ariaRole[0], + "aria-name": ariaName, + "v-scrollable": vScrollable, + "rects": [] + }; + + for (const rect of rects) { + let x = rect.left + rect.width/2; + let y = rect.top + rect.height/2; + if (isTopmost(elements[i], x, y)) { + record["rects"].push(JSON.parse(JSON.stringify(rect))); + } + } + + if (record["rects"].length > 0) { + results[key] = record; + } + } + return results; + }; + + let getVisualViewport = function() { + let vv = window.visualViewport; + let de = document.documentElement; + return { + "height": vv ? vv.height : 0, + "width": vv ? vv.width : 0, + "offsetLeft": vv ? vv.offsetLeft : 0, + "offsetTop": vv ? vv.offsetTop : 0, + "pageLeft": vv ? vv.pageLeft : 0, + "pageTop": vv ? 
vv.pageTop : 0, + "scale": vv ? vv.scale : 0, + "clientWidth": de ? de.clientWidth : 0, + "clientHeight": de ? de.clientHeight : 0, + "scrollWidth": de ? de.scrollWidth : 0, + "scrollHeight": de ? de.scrollHeight : 0 + }; + }; + + let _getMetaTags = function() { + let meta = document.querySelectorAll("meta"); + let results = {}; + for (let i = 0; i { + addValue(information, propName, childInfo); + }); + } + + } else if (child.hasAttribute('itemprop')) { + const itemProp = child.getAttribute('itemprop'); + itemProp.split(' ').forEach(propName => { + if (propName === 'url') { + addValue(information, propName, child.href); + } else { + addValue(information, propName, sanitize(child.getAttribute("content") || child.content || child.textContent || child.src || "")); + } + }); + traverseItem(child, information); + } else { + traverseItem(child, information); + } + } + } + + const microdata = []; + + document.querySelectorAll("[itemscope]").forEach(function(elem, i) { + const itemType = elem.getAttribute('itemtype'); + const information = { + itemType: itemType + }; + traverseItem(elem, information); + microdata.push(information); + }); + + return microdata; + }; + + let getPageMetadata = function() { + let jsonld = _getJsonLd(); + let metaTags = _getMetaTags(); + let microdata = _getMicrodata(); + let results = {} + if (jsonld.length > 0) { + try { + results["jsonld"] = JSON.parse(jsonld); + } + catch (e) { + results["jsonld"] = jsonld; + } + } + if (microdata.length > 0) { + results["microdata"] = microdata; + } + for (let key in metaTags) { + if (metaTags.hasOwnProperty(key)) { + results["meta_tags"] = metaTags; + break; + } + } + return results; + }; + + return { + getInteractiveRects: getInteractiveRects, + getVisualViewport: getVisualViewport, + getFocusedElementId: getFocusedElementId, + getPageMetadata: getPageMetadata, + }; + })(); \ No newline at end of file diff --git a/cerebrum/example/agents/calculator_agent/agent.py 
b/cerebrum/example/agents/calculator_agent/agent.py new file mode 100644 index 0000000..352de47 --- /dev/null +++ b/cerebrum/example/agents/calculator_agent/agent.py @@ -0,0 +1,62 @@ +from cerebrum.tool.mcp_tool import MCPPool, MCPClient +from typing import List, Dict, Any +import asyncio + +from dotenv import load_dotenv + +load_dotenv() + +from cerebrum.llm.apis import llm_chat_with_tool_call_output + +class CalculatorAgent: + def __init__(self): + self.mcp_pool = MCPPool() + self.description = "Use calculator for precise numerical calculations." + + async def initialize(self): + calculator_client = MCPClient.from_smithery( + pkg_name="@githejie/mcp-server-calculator", + description="Use calculator for precise numerical calculations.", + ) + self.mcp_pool.add_mcp_client("calculator", calculator_client) + await self.mcp_pool.start() + + def get_tool_information(self) -> List[Dict[str, Any]]: + """Get all tool information for this worker""" + return self.tool_information + + def get_tool_hints(self) -> str: + """Get formatted tool hints for this worker""" + hints = "" + for tool_info in self.tool_information: + hint = tool_info['hint'] + hints += f"- {hint}\n" + return hints + + async def run(self, task_input: str) -> Dict[str, Any]: + """Execute shell commands using the code-executor MCP""" + # This is a placeholder - implement actual shell command execution here + tool_information = await self.get_all_tool_information() + tool_hints = self.get_tool_hints(tool_information) + tool_schemas = self.get_all_tool_schemas(tool_information) + + tool_calls = llm_chat_with_tool_call_output( + model="gpt-4o", + messages=[{"content": task_input, "role": "user"}], + tool_schemas=tool_schemas, + )["response"]["tool_calls"] + + result = "" + + for tool_call in tool_calls: + tool_name = tool_call["name"] + tool_args = tool_call["args"] + tool_result = await self.mcp_pool.clients[tool_name].execute(tool_args) + result += tool_result + + return result + + async def cleanup(self): + 
"""Cleanup resources""" + await self.mcp_pool.stop() + diff --git a/cerebrum/example/agents/calculator_agent/config.json b/cerebrum/example/agents/calculator_agent/config.json new file mode 100644 index 0000000..32feda0 --- /dev/null +++ b/cerebrum/example/agents/calculator_agent/config.json @@ -0,0 +1,17 @@ +{ + "name": "calculator_agent", + "description": [ + "You are a calculator agent. You can use the calculator to calculate the result. " + ], + "tools": [ + ], + "meta": { + "author": "example", + "version": "0.0.1", + "license": "CC0" + }, + "build": { + "entry": "agent.py", + "module": "CalculatorAgent" + } +} diff --git a/cerebrum/example/agents/calculator_agent/meta_requirements.txt b/cerebrum/example/agents/calculator_agent/meta_requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/cerebrum/example/agents/cocktail_mixlogist/config.json b/cerebrum/example/agents/cocktail_mixlogist/config.json index f8ad7e9..1aef377 100644 --- a/cerebrum/example/agents/cocktail_mixlogist/config.json +++ b/cerebrum/example/agents/cocktail_mixlogist/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/code_executor/agent.py b/cerebrum/example/agents/code_executor/agent.py new file mode 100644 index 0000000..31a9e95 --- /dev/null +++ b/cerebrum/example/agents/code_executor/agent.py @@ -0,0 +1,58 @@ +from cerebrum.tool.mcp_tool import MCPPool, MCPClient +from typing import List, Dict, Any +import asyncio +from cerebrum.llm.apis import llm_chat_with_tool_call_output + +class CodeExecutor: + def __init__(self): + self.mcp_pool = MCPPool() + self.description = "Execute shell commands, analyze code, and manage files seamlessly" + + + async def initialize(self): + code_executor_client = MCPClient.from_smithery( + pkg_name="@auchenberg/claude-code-mcp", + description="Execute shell commands, analyze code, and manage files seamlessly", + ) + 
self.mcp_pool.add_mcp_client("code-executor", code_executor_client) + await self.mcp_pool.start() + + def get_tool_information(self) -> List[Dict[str, Any]]: + """Get all tool information for this worker""" + return self.tool_information + + def get_tool_hints(self) -> str: + """Get formatted tool hints for this worker""" + hints = "" + for tool_info in self.tool_information: + hint = tool_info['hint'] + hints += f"- {hint}\n" + return hints + + async def run(self, task_input: str) -> Dict[str, Any]: + """Execute shell commands using the code-executor MCP""" + # Implement shell command execution logic using code_executor_client + # This is a placeholder - implement actual shell command execution here + tool_information = await self.get_all_tool_information() + tool_hints = self.get_tool_hints(tool_information) + tool_schemas = self.get_all_tool_schemas(tool_information) + + tool_calls = llm_chat_with_tool_call_output( + model="gpt-4o", + messages=[{"content": task_input, "role": "user"}], + tool_schemas=tool_schemas, + )["response"]["tool_calls"] + + result = "" + + for tool_call in tool_calls: + tool_name = tool_call["name"] + tool_args = tool_call["args"] + tool_result = await self.mcp_pool.clients[tool_name].execute(tool_args) + result += tool_result + + return result + + async def cleanup(self): + """Cleanup resources""" + await self.mcp_pool.stop() diff --git a/cerebrum/example/agents/code_executor/config.json b/cerebrum/example/agents/code_executor/config.json new file mode 100644 index 0000000..18d574f --- /dev/null +++ b/cerebrum/example/agents/code_executor/config.json @@ -0,0 +1,17 @@ +{ + "name": "code_executor", + "description": [ + "You are a code executor. You can execute code and return the result. 
" + ], + "tools": [ + ], + "meta": { + "author": "example", + "version": "0.0.1", + "license": "CC0" + }, + "build": { + "entry": "agent.py", + "module": "CodeExecutor" + } +} diff --git a/cerebrum/example/agents/code_executor/meta_requirements.txt b/cerebrum/example/agents/code_executor/meta_requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/cerebrum/example/agents/creation_agent/config.json b/cerebrum/example/agents/creation_agent/config.json index 9c4ddf1..31a89ee 100644 --- a/cerebrum/example/agents/creation_agent/config.json +++ b/cerebrum/example/agents/creation_agent/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/demo_agent/config.json b/cerebrum/example/agents/demo_agent/config.json index 5157bd3..5874ce6 100644 --- a/cerebrum/example/agents/demo_agent/config.json +++ b/cerebrum/example/agents/demo_agent/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "demo_author", - "version": "0.0.1", + "version": "0.0.5", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/festival_card_designer/config.json b/cerebrum/example/agents/festival_card_designer/config.json index 5bba2d0..d913206 100644 --- a/cerebrum/example/agents/festival_card_designer/config.json +++ b/cerebrum/example/agents/festival_card_designer/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/language_tutor/config.json b/cerebrum/example/agents/language_tutor/config.json index 014c79b..24254f2 100644 --- a/cerebrum/example/agents/language_tutor/config.json +++ b/cerebrum/example/agents/language_tutor/config.json @@ -8,7 +8,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/logo_creator/config.json 
b/cerebrum/example/agents/logo_creator/config.json index d0c4ffc..75da211 100644 --- a/cerebrum/example/agents/logo_creator/config.json +++ b/cerebrum/example/agents/logo_creator/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/math_agent/config.json b/cerebrum/example/agents/math_agent/config.json index 4c897de..218bb43 100644 --- a/cerebrum/example/agents/math_agent/config.json +++ b/cerebrum/example/agents/math_agent/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/mcp_browser_use/agent.py b/cerebrum/example/agents/mcp_browser_use/agent.py deleted file mode 100644 index 75333ec..0000000 --- a/cerebrum/example/agents/mcp_browser_use/agent.py +++ /dev/null @@ -1,48 +0,0 @@ -from cerebrum.tool.mcp_tool import mcp_pool -from cerebrum.llm.apis import llm_chat_with_tool_call_output - -import asyncio -# print(mcp.name) - -async def main(): - await mcp_pool.start() - - # playwright_client = mcp_pool.get_mcp_client("playwright") - clients = mcp_pool.get_all_mcp_clients() - - tool_hints = [await client.hint() for client in clients] - - tool_schemas = [await client.tool_schemas() for client in clients] - - messages = [ - {"role": "user", "content": "search for elon musk's twitter account"}, - ] - - breakpoint() - - print(tool_hints) - - find_tools = {} - for client in clients: - for tool in await client.get_available_tools(): - find_tools[tool.name] = client.call_tool(tool.name) - - breakpoint() - - response = llm_chat_with_tool_call_output( - agent_name="computer_use_agent", - messages=messages, - tools=tool_schemas, - ) - - tool_calls = response["response"]["tool_calls"] - - breakpoint() - - for tool_call in tool_calls: - tool_result = await find_tools[tool_call["name"]](**tool_call["parameters"]) - breakpoint() - 
print(tool_result) - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/cerebrum/example/agents/meme_creator/config.json b/cerebrum/example/agents/meme_creator/config.json index bf331c4..95ed02d 100644 --- a/cerebrum/example/agents/meme_creator/config.json +++ b/cerebrum/example/agents/meme_creator/config.json @@ -8,7 +8,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/music_composer/config.json b/cerebrum/example/agents/music_composer/config.json index e2a77cf..d15cfae 100644 --- a/cerebrum/example/agents/music_composer/config.json +++ b/cerebrum/example/agents/music_composer/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/react/config.json b/cerebrum/example/agents/react/config.json new file mode 100644 index 0000000..9098215 --- /dev/null +++ b/cerebrum/example/agents/react/config.json @@ -0,0 +1,17 @@ +{ + "name": "react", + "description": [ + "You are a react agent. You can use the browser to search for information and execute code. 
" + ], + "tools": [ + ], + "meta": { + "author": "example", + "version": "0.0.1", + "license": "CC0" + }, + "build": { + "entry": "agent.py", + "module": "ReActAgent" + } +} diff --git a/cerebrum/example/agents/react/meta_requirements.txt b/cerebrum/example/agents/react/meta_requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/cerebrum/example/agents/react/react.py b/cerebrum/example/agents/react/react.py new file mode 100644 index 0000000..cc8f2de --- /dev/null +++ b/cerebrum/example/agents/react/react.py @@ -0,0 +1,277 @@ +from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output, llm_chat_with_tool_call_output + +from litellm import completion + +from benchmarks.utils import get_parser + +from datasets import load_dataset + +from dotenv import load_dotenv + +from typing import List, Dict, Any + +import asyncio + +import json + +import uuid + +from cerebrum.example.agents.browser_use_agent.agent import BrowserUseAgent +from cerebrum.example.agents.code_executor.agent import CodeExecutor +from cerebrum.example.agents.calculator_agent.agent import CalculatorAgent + +from cerebrum.utils import _parse_json_output + +from cerebrum.interface import AutoTool + +load_dotenv() + +class ReActAgent: + def __init__(self, on_aios: bool = True): + self.agent_name = "react" + self.on_aios = on_aios + self.max_steps = 20 + self.history_window = 10 + self.history = [] + self.workers = { + "browser_use_agent": BrowserUseAgent(), + # "arxiv_search": AutoTool.from_preloaded("example/arxiv") + } + + def run(self, task_input: str): + + llms = [ + { + "name": "gpt-4o", + "backend": "openai" + } + ] + + WORKER_PROMPTS = """ +1. `self.workers["browser_use_agent"].run(task_input: str, start_url: str)`: Call the browser use agent to solve the given task and start from the given url. + """ + + system_prompt = f"""# Task Orchestration Instructions + +You are an orchestrator agent responsible for coordinating specialized workers to solve complex tasks. 
Your goal is to break down the main task into subtasks and assign them to appropriate workers. + +## Main Task +{task_input} + +## Available Workers +{WORKER_PROMPTS} + +## Your Responsibilities: +1. **Analyze the task** and break it down into logical subtasks +2. **Assign subtasks** to appropriate workers from your available list +3. **Coordinate the workflow** by processing each worker's output +4. **Synthesize results** into a comprehensive solution +5. **Verify completeness** before finalizing + +## Solution Requirements: +- Provide detailed explanations for each step +- Include specific implementations and examples where appropriate +- Ensure your solution directly addresses the original task +- If one approach fails, try alternative methods + +## Completion Protocol: +- Before submitting your final answer, double-check that you've fully completed the task +- Verify your solution against the original requirements +- When you're confident the task is complete, format your answer as: + ``` + [brief one-line summary of the task result] + ``` +- Only use the FINAL_ANSWER tag when the task is truly complete + +## Problem-Solving Tips: +- Try multiple approaches if your first method fails +- For web searches, check Wikipedia first before exploring other sources +- When searching, use advanced filters when appropriate (date, location, etc.) 
+- For math problems, consider using Python with the sympy library +- Always verify your answers through cross-checking +- Don't rely solely on your knowledge - use available tools +- When executing code, debug any errors rather than assuming correct results +- Search results rarely provide complete answers - use them to find sources for further analysis +- For file downloads, use web browser simulation or write appropriate code""" + + messages = [ + {"content": system_prompt, "role": "system"} + ] + + final_answer = "" + + # breakpoint() + rounds = 0 + + while rounds < self.max_steps: + step_instructions = f""" +## Step-by-Step Execution Protocol: +Here are the latest {self.history_window} trajectory (at most) you have taken: + +{self.history[-self.history_window:]} + + +Your output should be in json format, including the following fields: +- `observation`: Do not over-confident about the correctness of the history actions. You should +always check the current state to make sure the correctness of the next action. +- `reasoning`: The reasoning about the next action you want to take, and the possible obstacles you may encounter, +and how to solve them. Do not forget to check the history actions to avoid the same mistakes. +- `worker_name`: The name of the worker you want to use. It is only one step action +without any other texts (such as annotation) +- `worker_params`: The parameters for the worker. It is a dictionary containing the parameters for the worker. + +When you have recognized the task is completed, you need to output + +Here is two example of the output: +```json +{{ + "observation": "I have obtained the answer FeO is the densest iron oxide on the Moon...", + "reasoning": "Since I have already obtained the answer, I do not need to call any worker... 
", + "worker_name": None, + "worker_params": None +}} + +{{ + "observation": "At current stage, I have already opened the amazon website...", + "reasoning": "To proceed with the task of searching for iphone products, I need to complete...", + "worker_name": "browser_use_agent", + "worker_params": {{ + "task_prompt": "search for the product 'iphone' on amazon", + "start_url": "https://www.amazon.com" + }} +}} """ + messages.append({"content": step_instructions, "role": "user"}) + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "orchestration", + "schema": { + "type": "object", + "properties": { + "observation": {"type": "string"}, + "reasoning": {"type": "string"}, + "worker_name": {"type": "string"}, + "worker_params": {"type": "object"} + }, + "required": ["observation", "reasoning", "worker_name", "worker_params"] + } + } + } + + response = llm_chat_with_json_output( + agent_name=self.agent_name, + messages=messages, + llms=llms, + response_format=response_format + ) + + step_response = response["response"]["response_message"] + + resp_dict = _parse_json_output(step_response) + observation = resp_dict.get("observation", "") + reasoning = resp_dict.get("reasoning", "") + worker_name = resp_dict.get("worker_name", None) + worker_params = resp_dict.get("worker_params", {}) + + if worker_name is None: + if worker_params is {}: + trajectory_info = { + "round": rounds, + "observation": observation, + "thought": reasoning, + "called_worker": worker_name, + "called_worker_params": worker_params, + "info": None + } + final_answer = self.get_final_answer(task_input) + + break + + else: + + worker_response = self.workers[worker_name].run(**worker_params) + + result = worker_response["result"] + + trajectory_info = { + "round": rounds, + "observation": observation, + "thought": reasoning, + "called_worker": worker_name, + "called_worker_params": worker_params, + "info": result + } + + print(trajectory_info) + + self.history.append(trajectory_info) 
+ + rounds += 1 + + return { + "agent_name": self.agent_name, + "result": final_answer, + "rounds": rounds + } + + def get_final_answer(self, task_input: str): + r"""Get the final answer based on the task prompt and current state. + It is used when the agent thinks that the task can be completed without any further action, and answer can be directly found in the current state. + """ + system_prompt = """ +You are an extractor agent. Your job is to extract the final answer from the history and the task prompt. +""" + prompt = f""" +You are solving a complex task which needs multi-step interaction with different workers. After the multi-step observation, reasoning and acting taken by different workers, you thinkthe task is currently solved. +Here are all trajectory we have taken: +{self.history} +Please find the final answer, or give valuable insights and founds (e.g. if previous actions contain downloading files, your output should include the path of the downloaded file) about the overall task: {task_input} + """ + + messages = [ + {"content": system_prompt, "role": "system"}, + {"content": prompt, "role": "user"} + ] + + llms = [ + { + "name": "gpt-4o-mini", + "backend": "openai" + } + ] + + response = llm_chat( + agent_name=self.agent_name, + messages=messages, + llms=llms + ) + + return response["response"]["response_message"] + +def main(): + agent = ReActAgent() + + data = { + "Question": """ +What is the temperature difference between Edison and New York today? +""", + "Tools": "1. Web browser, 2. 
Calculator" + } + + main_parser = get_parser() + main_args = main_parser.parse_args() + dataset = load_dataset(main_args.data_name, "2023_all", split=main_args.split) + + dataset = dataset["Question"] + + # for idx, question in enumerate(dataset): + # result = agent.run(question) + # print(result) + result = agent.run(data["Question"]) + print(result) + +if __name__ == "__main__": + + main() + diff --git a/cerebrum/example/agents/story_teller/config.json b/cerebrum/example/agents/story_teller/config.json index 2a6d5f8..6c38d99 100644 --- a/cerebrum/example/agents/story_teller/config.json +++ b/cerebrum/example/agents/story_teller/config.json @@ -9,7 +9,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/tech_support_agent/config.json b/cerebrum/example/agents/tech_support_agent/config.json index 93a687d..4c0e1e2 100644 --- a/cerebrum/example/agents/tech_support_agent/config.json +++ b/cerebrum/example/agents/tech_support_agent/config.json @@ -8,7 +8,7 @@ ], "meta": { "author": "example", - "version": "0.0.1", + "version": "0.0.2", "license": "CC0" }, "build": { diff --git a/cerebrum/example/agents/test_agent/config.json b/cerebrum/example/agents/test_agent/config.json index abf837f..9e96fa5 100644 --- a/cerebrum/example/agents/test_agent/config.json +++ b/cerebrum/example/agents/test_agent/config.json @@ -6,7 +6,7 @@ "tools": [], "meta": { "author": "example", - "version": "0.0.3", + "version": "0.0.4", "license": "CC0" }, "build": { diff --git a/cerebrum/llm/apis.py b/cerebrum/llm/apis.py index 312aeb5..d49064f 100644 --- a/cerebrum/llm/apis.py +++ b/cerebrum/llm/apis.py @@ -8,6 +8,15 @@ aios_kernel_url = config.get_kernel_url() +import requests + +def list_available_llms(): + """ + List all available LLMs. + """ + response = requests.get(f"{aios_kernel_url}/core/llms/list") + return response.json() + class LLMQuery(Query): """ Query class for LLM operations. 
diff --git a/cerebrum/tool/mcp_tool/__init__.py b/cerebrum/tool/mcp_tool/__init__.py index 54be5ee..6dab1a2 100644 --- a/cerebrum/tool/mcp_tool/__init__.py +++ b/cerebrum/tool/mcp_tool/__init__.py @@ -1,7 +1,9 @@ from .mcp_client import MCPClient from .pool import MCPPool -mcp_pool = MCPPool() +__all__ = ["MCPClient", "MCPPool"] + +# mcp_pool = MCPPool() # if os.getenv("BRAVE_API_KEY"): # TOOLS.add_mcp_client( @@ -72,12 +74,12 @@ # ), # ) -mcp_pool.add_mcp_client( - "playwright", - MCPClient.from_npx( - "@playwright/mcp@latest", - suffix_args=[ - "--headless" - ], - ), -) +# mcp_pool.add_mcp_client( +# "playwright", +# MCPClient.from_npx( +# "@playwright/mcp@latest", +# suffix_args=[ +# "--headless" +# ], +# ), +# ) diff --git a/cerebrum/tool/mcp_tool/mcp_client.py b/cerebrum/tool/mcp_tool/mcp_client.py index fe56760..d45495f 100644 --- a/cerebrum/tool/mcp_tool/mcp_client.py +++ b/cerebrum/tool/mcp_tool/mcp_client.py @@ -16,17 +16,24 @@ class MCPClient(BaseMCPClient): """ @classmethod - def from_smithery(cls, pkg_name: str, suffix_args: List[str] = []): + def from_smithery( + cls, + pkg_name: str, + description: str = "", + suffix_args: List[str] = [], + env: Dict[str, str] = None, + ): server_params = StdioServerParameters( command="npx", args=["-y", "@smithery/cli@latest", "run", pkg_name, *suffix_args], + env=env, ) # CONSOLE.log(f"Use MCP: {pkg_name} from smithery") - return cls(pkg_name, server_params) + return cls(pkg_name, description, server_params) @classmethod def from_npx( - cls, pkg_name: str, prefix_args: List[str] = [], suffix_args: List[str] = [] + cls, pkg_name: str, description: str = "", prefix_args: List[str] = [], suffix_args: List[str] = [] ): server_params = StdioServerParameters( command="npx", @@ -34,7 +41,7 @@ def from_npx( # **extra_args ) print(f"Use MCP: {pkg_name} from npx") - return cls(pkg_name, server_params) + return cls(pkg_name, description, server_params) @classmethod def from_docker( @@ -59,9 +66,10 @@ def from_docker( # 
CONSOLE.log(f"Use MCP: {image_name} from docker") return cls(image_name, server_params) - def __init__(self, name: str, server_params: StdioServerParameters): + def __init__(self, name: str, description: str = "", server_params: StdioServerParameters = None): """Initialize the MCP client with server parameters""" self.__name = name + self.__description = description self.server_params = server_params self.session = None self.read = None @@ -72,6 +80,10 @@ def __init__(self, name: str, server_params: StdioServerParameters): @property def name(self) -> str: return self.__name + + @property + def description(self) -> str: + return self.__description async def __aenter__(self): await self.connect() @@ -136,20 +148,62 @@ async def callable(*args, **kwargs): return response.content[0].text return callable - - async def hint(self) -> str: + + async def get_tool_hints_by_name(self, tool_name: str = None) -> str: tools = await self.get_available_tools() - hint = "" for tool in tools: - hint += f"- {tool.name}: {tool.description}\n" - return hint - - async def tool_schemas(self) -> List[dict]: + if tool.name == tool_name: + return f"- {tool.name}: {tool.description}\n" + return "" + + async def get_tool_schemas_by_name(self, tool_name: str = None) -> List[dict]: + tools = await self.get_available_tools() + for tool in tools: + if tool.name == tool_name: + schema = tool.inputSchema + if "$schema" in schema: + schema.pop("$schema") + return schema + return [] + + async def get_all_tool_hints(self) -> List[str]: + tools = await self.get_available_tools() + hints = [] + for tool in tools: + hints.append(f"{tool.name}: {tool.description}\n") + return hints + + async def get_all_tool_information(self) -> List[str]: + tools = await self.get_available_tools() + tool_information = [] + for tool in tools: + schema = tool.inputSchema + openai_tool_schema = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": schema, + } + } + + 
if "$schema" in schema: + schema.pop("$schema") + tool_information.append({ + "name": tool.name, + "description": tool.description, + "hint": f"{tool.name}: {tool.description}", + "schema": openai_tool_schema, + }) + return tool_information + + async def get_all_tool_schemas(self) -> List[dict]: openai_tools = [] tools = await self.get_available_tools() for tool in tools: schema = tool.inputSchema - schema.pop("$schema") + if "$schema" in schema: + schema.pop("$schema") openai_tools.append( { "type": "function", diff --git a/cerebrum/tool/mcp_tool/type.py b/cerebrum/tool/mcp_tool/type.py index 2dc38de..ac51681 100644 --- a/cerebrum/tool/mcp_tool/type.py +++ b/cerebrum/tool/mcp_tool/type.py @@ -15,6 +15,12 @@ async def connect(self, exit_stack: AsyncExitStack): def name(self) -> str: """The name of the MCP client""" pass + + @property + @abstractmethod + def description(self) -> str: + """The description of the MCP client""" + pass @abstractmethod async def get_available_tools(self) -> List[Tool]: @@ -36,17 +42,39 @@ async def call_tool(self, tool_name: str) -> Callable[..., Awaitable[str]]: A callable async function that executes the specified tool """ pass - + + @abstractmethod - async def hint(self) -> str: + async def get_tool_hints_by_name(self, tool_name: str = None) -> str: """ - Retrieve a hint for the MCP server. + Retrieve a hint for a specific tool from the MCP server. + """ + pass + + @abstractmethod + async def get_all_tool_hints(self) -> str: + """ + Retrieve hints for all tools from the MCP server. """ pass @abstractmethod - async def tool_schemas(self) -> List[dict]: + async def get_tool_schemas_by_name(self, tool_name: str = None) -> List[dict]: """ - Retrieve a list of tool schemas from the MCP server. + Retrieve schemas for a specific tool from the MCP server. """ - pass \ No newline at end of file + pass + + @abstractmethod + async def get_all_tool_schemas(self) -> List[dict]: + """ + Retrieve schemas for all tools from the MCP server. 
+ """ + pass + + @abstractmethod + async def get_all_tool_information(self) -> List[dict]: + """ + Retrieve information for all tools from the MCP server. + """ + pass diff --git a/cerebrum/utils/__init__.py b/cerebrum/utils/__init__.py index c228b8f..e69de29 100644 --- a/cerebrum/utils/__init__.py +++ b/cerebrum/utils/__init__.py @@ -1,21 +0,0 @@ -import random -import os -from typing import Optional - - -def generator_tool_call_id(): - """generate tool call id - """ - return str(random.randint(0, 1000)) - -def get_from_env(env_key: str, default: Optional[str] = None) -> str: - """Get a value from an environment variable.""" - if env_key in os.environ and os.environ[env_key]: - return os.environ[env_key] - elif default is not None: - return default - else: - raise ValueError( - f"Did not find {env_key}, please add an environment variable" - f" `{env_key}` which contains it. " - ) \ No newline at end of file diff --git a/cerebrum/utils/browser.py b/cerebrum/utils/browser.py new file mode 100644 index 0000000..4608c21 --- /dev/null +++ b/cerebrum/utils/browser.py @@ -0,0 +1,836 @@ +import os +import time +import shutil +import datetime +import io +import urllib.parse +from typing import Optional, Union, Literal, Tuple, List, Dict, Any, BinaryIO +from PIL import Image +from playwright.sync_api import sync_playwright +from typing import TypedDict +import random +from typing import cast +from copy import deepcopy + +from PIL import Image, ImageDraw, ImageFont + +TOP_NO_LABEL_ZONE = 20 + + +class DOMRectangle(TypedDict): + x: Union[int, float] + y: Union[int, float] + width: Union[int, float] + height: Union[int, float] + top: Union[int, float] + right: Union[int, float] + bottom: Union[int, float] + left: Union[int, float] + + +class VisualViewport(TypedDict): + height: Union[int, float] + width: Union[int, float] + offsetLeft: Union[int, float] + offsetTop: Union[int, float] + pageLeft: Union[int, float] + pageTop: Union[int, float] + scale: Union[int, float] + 
clientWidth: Union[int, float] + clientHeight: Union[int, float] + scrollWidth: Union[int, float] + scrollHeight: Union[int, float] + + +class InteractiveRegion(TypedDict): + tag_name: str + role: str + aria_name: str + v_scrollable: bool + rects: List[DOMRectangle] + + +def _get_str(d: Any, k: str) -> str: + r"""Safely retrieve a string value from a dictionary.""" + if k not in d: + raise KeyError(f"Missing required key: '{k}'") + val = d[k] + if isinstance(val, str): + return val + raise TypeError( + f"Expected a string for key '{k}', " f"but got {type(val).__name__}" + ) + + +def _get_number(d: Any, k: str) -> Union[int, float]: + r"""Safely retrieve a number (int or float) from a dictionary""" + val = d[k] + if isinstance(val, (int, float)): + return val + raise TypeError( + f"Expected a number (int/float) for key " + f"'{k}', but got {type(val).__name__}" + ) + + +def _get_bool(d: Any, k: str) -> bool: + r"""Safely retrieve a boolean value from a dictionary.""" + val = d[k] + if isinstance(val, bool): + return val + raise TypeError( + f"Expected a boolean for key '{k}', " f"but got {type(val).__name__}" + ) + +def _reload_image(image: Image.Image): + buffer = io.BytesIO() + image.save(buffer, format="PNG") + buffer.seek(0) + return Image.open(buffer) + + +def dom_rectangle_from_dict(rect: Dict[str, Any]) -> DOMRectangle: + r"""Create a DOMRectangle object from a dictionary.""" + return DOMRectangle( + x=_get_number(rect, "x"), + y=_get_number(rect, "y"), + width=_get_number(rect, "width"), + height=_get_number(rect, "height"), + top=_get_number(rect, "top"), + right=_get_number(rect, "right"), + bottom=_get_number(rect, "bottom"), + left=_get_number(rect, "left"), + ) + + +def interactive_region_from_dict(region: Dict[str, Any]) -> InteractiveRegion: + r"""Create an :class:`InteractiveRegion` object from a dictionary.""" + typed_rects: List[DOMRectangle] = [] + for rect in region["rects"]: + typed_rects.append(dom_rectangle_from_dict(rect)) + + return 
InteractiveRegion( + tag_name=_get_str(region, "tag_name"), + role=_get_str(region, "role"), + aria_name=_get_str(region, "aria-name"), + v_scrollable=_get_bool(region, "v-scrollable"), + rects=typed_rects, + ) + + +def visual_viewport_from_dict(viewport: Dict[str, Any]) -> VisualViewport: + r"""Create a :class:`VisualViewport` object from a dictionary.""" + return VisualViewport( + height=_get_number(viewport, "height"), + width=_get_number(viewport, "width"), + offsetLeft=_get_number(viewport, "offsetLeft"), + offsetTop=_get_number(viewport, "offsetTop"), + pageLeft=_get_number(viewport, "pageLeft"), + pageTop=_get_number(viewport, "pageTop"), + scale=_get_number(viewport, "scale"), + clientWidth=_get_number(viewport, "clientWidth"), + clientHeight=_get_number(viewport, "clientHeight"), + scrollWidth=_get_number(viewport, "scrollWidth"), + scrollHeight=_get_number(viewport, "scrollHeight"), + ) + + +def add_set_of_mark( + screenshot: Union[bytes, Image.Image, io.BufferedIOBase], + ROIs: Dict[str, InteractiveRegion], +) -> Tuple[Image.Image, List[str], List[str], List[str]]: + if isinstance(screenshot, Image.Image): + return _add_set_of_mark(screenshot, ROIs) + + if isinstance(screenshot, bytes): + screenshot = io.BytesIO(screenshot) + + image = Image.open(cast(BinaryIO, screenshot)) + comp, visible_rects, rects_above, rects_below = _add_set_of_mark( + image, ROIs + ) + image.close() + return comp, visible_rects, rects_above, rects_below + + +def _add_set_of_mark( + screenshot: Image.Image, ROIs: Dict[str, InteractiveRegion] +) -> Tuple[Image.Image, List[str], List[str], List[str]]: + r"""Add a set of marks to the screenshot. + + Args: + screenshot (Image.Image): The screenshot to add marks to. + ROIs (Dict[str, InteractiveRegion]): The regions to add marks to. 
+ + Returns: + Tuple[Image.Image, List[str], List[str], List[str]]: A tuple + containing the screenshot with marked ROIs, ROIs fully within the + images, ROIs located above the visible area, and ROIs located below + the visible area. + """ + visible_rects: List[str] = list() + rects_above: List[str] = list() # Scroll up to see + rects_below: List[str] = list() # Scroll down to see + + fnt = ImageFont.load_default(14) + base = screenshot.convert("L").convert("RGBA") + overlay = Image.new("RGBA", base.size) + + draw = ImageDraw.Draw(overlay) + for r in ROIs: + for rect in ROIs[r]["rects"]: + # Empty rectangles + if not rect or rect["width"] == 0 or rect["height"] == 0: + continue + + # TODO: add scroll left and right? + horizontal_center = (rect["right"] + rect["left"]) / 2.0 + vertical_center = (rect["top"] + rect["bottom"]) / 2.0 + is_within_horizon = 0 <= horizontal_center < base.size[0] + is_above_viewport = vertical_center < 0 + is_below_viewport = vertical_center >= base.size[1] + + if is_within_horizon: + if is_above_viewport: + rects_above.append(r) + elif is_below_viewport: + rects_below.append(r) + else: # Fully visible + visible_rects.append(r) + _draw_roi(draw, int(r), fnt, rect) + + comp = Image.alpha_composite(base, overlay) + overlay.close() + return comp, visible_rects, rects_above, rects_below + + +def _draw_roi( + draw: ImageDraw.ImageDraw, + idx: int, + font: ImageFont.FreeTypeFont | ImageFont.ImageFont, + rect: DOMRectangle, +) -> None: + r"""Draw a ROI on the image. + + Args: + draw (ImageDraw.ImageDraw): The draw object. + idx (int): The index of the ROI. + font (ImageFont.FreeTypeFont | ImageFont.ImageFont): The font. + rect (DOMRectangle): The DOM rectangle. 
+ """ + color = _get_random_color(idx) + text_color = _get_text_color(color) + + roi = ((rect["left"], rect["top"]), (rect["right"], rect["bottom"])) + + label_location = (rect["right"], rect["top"]) + label_anchor = "rb" + + if label_location[1] <= TOP_NO_LABEL_ZONE: + label_location = (rect["right"], rect["bottom"]) + label_anchor = "rt" + + draw.rectangle( + roi, outline=color, fill=(color[0], color[1], color[2], 48), width=2 + ) + + bbox = draw.textbbox( + label_location, + str(idx), + font=font, + anchor=label_anchor, + align="center", + ) + bbox = (bbox[0] - 3, bbox[1] - 3, bbox[2] + 3, bbox[3] + 3) + draw.rectangle(bbox, fill=color) + + draw.text( + label_location, + str(idx), + fill=text_color, + font=font, + anchor=label_anchor, + align="center", + ) + + +def _get_text_color( + bg_color: Tuple[int, int, int, int], +) -> Tuple[int, int, int, int]: + r"""Determine the ideal text color (black or white) for contrast. + + Args: + bg_color: The background color (R, G, B, A). + + Returns: + A tuple representing black or white color for text. + """ + luminance = bg_color[0] * 0.3 + bg_color[1] * 0.59 + bg_color[2] * 0.11 + return (0, 0, 0, 255) if luminance > 120 else (255, 255, 255, 255) + + +def _get_random_color(identifier: int) -> Tuple[int, int, int, int]: + r"""Generate a consistent random RGBA color based on the identifier. + + Args: + identifier: The ID used as a seed to ensure color consistency. + + Returns: + A tuple representing (R, G, B, A) values. + """ + rnd = random.Random(int(identifier)) + r = rnd.randint(0, 255) + g = rnd.randint(125, 255) + b = rnd.randint(0, 50) + color = [r, g, b] + # TODO: check why shuffle is needed? + rnd.shuffle(color) + color.append(255) + return cast(Tuple[int, int, int, int], tuple(color)) + +class BaseBrowser: + def __init__( + self, + headless=True, + cache_dir: Optional[str] = None, + channel: Literal["chrome", "msedge", "chromium"] = "chromium", + ): + r"""Initialize the WebBrowser instance. 
+ + Args: + headless (bool): Whether to run the browser in headless mode. + cache_dir (Union[str, None]): The directory to store cache files. + channel (Literal["chrome", "msedge", "chromium"]): The browser + channel to use. Must be one of "chrome", "msedge", or + "chromium". + + Returns: + None + """ + + self.history: list = [] + self.headless = headless + self.channel = channel + self._ensure_browser_installed() + self.playwright = sync_playwright().start() + self.page_history: list = [] # stores the history of visited pages + + # Set the cache directory + self.cache_dir = "tmp/" if cache_dir is None else cache_dir + os.makedirs(self.cache_dir, exist_ok=True) + + # Load the page script + abs_dir_path = os.path.dirname(os.path.abspath(__file__)) + page_script_path = os.path.join(abs_dir_path, "page_script.js") + + try: + with open(page_script_path, "r", encoding='utf-8') as f: + self.page_script = f.read() + f.close() + except FileNotFoundError: + raise FileNotFoundError( + f"Page script file not found at path: {page_script_path}" + ) + + def init(self) -> None: + r"""Initialize the browser.""" + # Launch the browser, if headless is False, the browser will display + self.browser = self.playwright.chromium.launch( + headless=self.headless, channel=self.channel + ) + # Create a new context + self.context = self.browser.new_context(accept_downloads=True) + # Create a new page + self.page = self.context.new_page() + + def clean_cache(self) -> None: + r"""Delete the cache directory and its contents.""" + if os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + + def _wait_for_load(self, timeout: int = 20) -> None: + r"""Wait for a certain amount of time for the page to load.""" + timeout_ms = timeout * 1000 + + self.page.wait_for_load_state("load", timeout=timeout_ms) + + # TODO: check if this is needed + time.sleep(2) + + def click_blank_area(self) -> None: + r"""Click a blank area of the page to unfocus the current element.""" + self.page.mouse.click(0, 
0) + self._wait_for_load() + + def visit_page(self, url: str) -> None: + r"""Visit a page with the given URL.""" + + self.page.goto(url) + self._wait_for_load() + self.page_url = url + + # def ask_question_about_video(self, question: str) -> str: + # r"""Ask a question about the video on the current page, + # such as YouTube video. + + # Args: + # question (str): The question to ask. + + # Returns: + # str: The answer to the question. + # """ + # video_analyzer = VideoAnalysisToolkit() + # result = video_analyzer.ask_question_about_video( + # self.page_url, question + # ) + # return result + + # @retry_on_error() + def get_screenshot( + self, save_image: bool = False + ) -> Tuple[Image.Image, Union[str, None]]: + r"""Get a screenshot of the current page. + + Args: + save_image (bool): Whether to save the image to the cache + directory. + + Returns: + Tuple[Image.Image, str]: A tuple containing the screenshot + image and the path to the image file if saved, otherwise + :obj:`None`. + """ + + image_data = self.page.screenshot(timeout=60000) + image = Image.open(io.BytesIO(image_data)) + + file_path = None + if save_image: + # Get url name to form a file name + # Use urlparser for a safer extraction the url name + parsed_url = urllib.parse.urlparse(self.page_url) + url_name = os.path.basename(str(parsed_url.path)) or "index" + + for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|', '.']: + url_name = url_name.replace(char, "_") + + # Get formatted time: mmddhhmmss + timestamp = datetime.datetime.now().strftime("%m%d%H%M%S") + file_path = os.path.join( + self.cache_dir, f"{url_name}_{timestamp}.png" + ) + with open(file_path, "wb") as f: + image.save(f, "PNG") + f.close() + + return image, file_path + + def capture_full_page_screenshots( + self, scroll_ratio: float = 0.8 + ) -> List[str]: + r"""Capture full page screenshots by scrolling the page with a buffer + zone. + + Args: + scroll_ratio (float): The ratio of viewport height to scroll each + step. 
(default: :obj:`0.8`) + + Returns: + List[str]: A list of paths to the screenshot files. + """ + screenshots = [] + scroll_height = self.page.evaluate("document.body.scrollHeight") + assert self.page.viewport_size is not None + viewport_height = self.page.viewport_size["height"] + current_scroll = 0 + screenshot_index = 1 + + max_height = scroll_height - viewport_height + scroll_step = int(viewport_height * scroll_ratio) + + last_height = 0 + + while True: + # print( + # f"Current scroll: {current_scroll}, max_height: " + # f"{max_height}, step: {scroll_step}" + # ) + + _, file_path = self.get_screenshot(save_image=True) + screenshots.append(file_path) + + self.page.evaluate(f"window.scrollBy(0, {scroll_step})") + # Allow time for content to load + time.sleep(0.5) + + current_scroll = self.page.evaluate("window.scrollY") + # Break if there is no significant scroll + if abs(current_scroll - last_height) < viewport_height * 0.1: + break + + last_height = current_scroll + screenshot_index += 1 + + return screenshots + + def get_visual_viewport(self) -> VisualViewport: + r"""Get the visual viewport of the current page. + + Returns: + VisualViewport: The visual viewport of the current page. + """ + try: + self.page.evaluate(self.page_script) + except Exception as e: + pass + # print(f"Error evaluating page script: {e}") + + return visual_viewport_from_dict( + self.page.evaluate("MultimodalWebSurfer.getVisualViewport();") + ) + + def get_interactive_elements(self) -> Dict[str, InteractiveRegion]: + r"""Get the interactive elements of the current page. + + Returns: + Dict[str, InteractiveRegion]: A dictionary of interactive elements. 
+ """ + try: + self.page.evaluate(self.page_script) + except Exception as e: + print(f"Error evaluating page script: {e}") + + result = cast( + Dict[str, Dict[str, Any]], + self.page.evaluate("MultimodalWebSurfer.getInteractiveRects();"), + ) + + typed_results: Dict[str, InteractiveRegion] = {} + for k in result: + typed_results[k] = interactive_region_from_dict(result[k]) + + return typed_results # type: ignore[return-value] + + def get_som_screenshot( + self, + save_image: bool = False, + ) -> Tuple[Image.Image, Union[str, None]]: + r"""Get a screenshot of the current viewport with interactive elements + marked. + + Args: + save_image (bool): Whether to save the image to the cache + directory. + + Returns: + Tuple[Image.Image, str]: A tuple containing the screenshot image + and the path to the image file. + """ + + self._wait_for_load() + screenshot, _ = self.get_screenshot(save_image=save_image) + rects = self.get_interactive_elements() + + file_path = None + comp, visible_rects, rects_above, rects_below = add_set_of_mark( + screenshot, + rects, # type: ignore[arg-type] + ) + if save_image: + parsed_url = urllib.parse.urlparse(self.page_url) + url_name = os.path.basename(str(parsed_url.path)) or "index" + for char in ['\\', '/', ':', '*', '?', '"', '<', '>', '|', '.']: + url_name = url_name.replace(char, "_") + timestamp = datetime.datetime.now().strftime("%m%d%H%M%S") + file_path = os.path.join( + self.cache_dir, f"{url_name}_{timestamp}.png" + ) + with open(file_path, "wb") as f: + comp.save(f, "PNG") + f.close() + + return comp, file_path + + def scroll_up(self) -> None: + r"""Scroll up the page.""" + self.page.keyboard.press("PageUp") + + def scroll_down(self) -> None: + r"""Scroll down the page.""" + self.page.keyboard.press("PageDown") + + def get_url(self) -> str: + r"""Get the URL of the current page.""" + return self.page.url + + def click_id(self, identifier: Union[str, int]) -> None: + r"""Click an element with the given identifier.""" + if 
isinstance(identifier, int): + identifier = str(identifier) + target = self.page.locator(f"[__elementId='{identifier}']") + + try: + target.wait_for(timeout=5000) + except (TimeoutError, Exception) as e: + print(f"Error during click operation: {e}") + raise ValueError("No such element.") from None + + target.scroll_into_view_if_needed() + + new_page = None + try: + with self.page.expect_event("popup", timeout=1000) as page_info: + box = cast(Dict[str, Union[int, float]], target.bounding_box()) + self.page.mouse.click( + box["x"] + box["width"] / 2, box["y"] + box["height"] / 2 + ) + new_page = page_info.value + + # If a new page is opened, switch to it + if new_page: + self.page_history.append(deepcopy(self.page.url)) + self.page = new_page + + except (TimeoutError, Exception) as e: + print(f"Error during click operation: {e}") + pass + + self._wait_for_load() + + def extract_url_content(self) -> str: + r"""Extract the content of the current page.""" + content = self.page.content() + return content + + def download_file_id(self, identifier: Union[str, int]) -> str: + r"""Download a file with the given selector. + + Args: + identifier (str): The identifier of the file to download. + file_path (str): The path to save the downloaded file. + + Returns: + str: The result of the action. + """ + + if isinstance(identifier, int): + identifier = str(identifier) + try: + target = self.page.locator(f"[__elementId='{identifier}']") + except (TimeoutError, Exception) as e: + print(f"Error during download operation: {e}") + print( + f"Element with identifier '{identifier}' not found." + ) + return f"Element with identifier '{identifier}' not found." 
+ + target.scroll_into_view_if_needed() + + file_path = os.path.join(self.cache_dir) + self._wait_for_load() + + try: + with self.page.expect_download() as download_info: + target.click() + download = download_info.value + file_name = download.suggested_filename + + file_path = os.path.join(file_path, file_name) + download.save_as(file_path) + + return f"Downloaded file to path '{file_path}'." + + except (TimeoutError, Exception) as e: + print(f"Error during download operation: {e}") + return f"Failed to download file with identifier '{identifier}'." + + def fill_input_id(self, identifier: Union[str, int], text: str) -> str: + r"""Fill an input field with the given text, and then press Enter. + + Args: + identifier (str): The identifier of the input field. + text (str): The text to fill. + + Returns: + str: The result of the action. + """ + if isinstance(identifier, int): + identifier = str(identifier) + + try: + target = self.page.locator(f"[__elementId='{identifier}']") + except (TimeoutError, Exception) as e: + print(f"Error during fill operation: {e}") + print( + f"Element with identifier '{identifier}' not found." + ) + return f"Element with identifier '{identifier}' not found." + + target.scroll_into_view_if_needed() + target.focus() + try: + target.fill(text) + except (TimeoutError, Exception) as e: + print(f"Error during fill operation: {e}") + target.press_sequentially(text) + + target.press("Enter") + self._wait_for_load() + return ( + f"Filled input field '{identifier}' with text '{text}' " + f"and pressed Enter." + ) + + def scroll_to_bottom(self) -> str: + self.page.evaluate("window.scrollTo(0, document.body.scrollHeight);") + self._wait_for_load() + return "Scrolled to the bottom of the page." + + def scroll_to_top(self) -> str: + self.page.evaluate("window.scrollTo(0, 0);") + self._wait_for_load() + return "Scrolled to the top of the page." 
def hover_id(self, identifier: Union[str, int]) -> str:
    r"""Hover the mouse over the element tagged with ``identifier``.

    Args:
        identifier (Union[str, int]): Value of the element's
            ``__elementId`` attribute. Integers are converted to strings.

    Returns:
        str: A sentence stating that the element was hovered, or that it
            could not be found.
    """
    if isinstance(identifier, int):
        identifier = str(identifier)
    try:
        target = self.page.locator(f"[__elementId='{identifier}']")
    except Exception as e:
        # NOTE(review): creating a locator probably does not query the page,
        # so this branch may never run for a missing element -- confirm.
        print(f"Error during hover operation: {e}")
        print(f"Element with identifier '{identifier}' not found.")
        return f"Element with identifier '{identifier}' not found."

    target.scroll_into_view_if_needed()
    target.hover()
    self._wait_for_load()
    return f"Hovered over element with identifier '{identifier}'."


def find_text_on_page(self, search_text: str) -> str:
    r"""Find the next occurrence of ``search_text`` and scroll to it.

    Comparable to pressing Ctrl + F and searching for the text. The page's
    ``window.find`` is tried first; if it reports no hit, the first element
    whose ``innerText`` contains the text is scrolled into view and
    highlighted with a yellow background and a red border.

    Args:
        search_text (str): The text to look for. It may contain quotes,
            backslashes or newlines.

    Returns:
        str: A sentence stating whether the text was found.
    """
    # ruff: noqa: E501
    # The text is handed to the page as an argument instead of being pasted
    # into the script source, so special characters in it can neither break
    # the script nor inject code into the page.
    script = """
    (text) => {
        let found = window.find(text);
        if (!found) {
            let elements = document.querySelectorAll("*:not(script):not(style)");
            for (let el of elements) {
                if (el.innerText && el.innerText.includes(text)) {
                    el.scrollIntoView({behavior: "smooth", block: "center"});
                    el.style.backgroundColor = "yellow";
                    el.style.border = '2px solid red';
                    return true;
                }
            }
            return false;
        }
        return true;
    }
    """
    found = self.page.evaluate(script, search_text)
    self._wait_for_load()
    if found:
        return f"Found text '{search_text}' on the page."
    return f"Text '{search_text}' not found on the page."
def back(self):
    r"""Navigate back to the previous page.

    Uses the browser's own back navigation first and falls back to the
    URLs stored in ``self.page_history`` when the URL did not change.
    """
    page_url_before = self.page.url
    self.page.go_back()
    page_url_after = self.page.url

    # Going back ended on the blank page: reload the page we started from.
    if page_url_after == "about:blank":
        self.visit_page(page_url_before)

    # The URL did not change, so try the most recently recorded URL instead.
    if page_url_before == page_url_after and self.page_history:
        self.visit_page(self.page_history.pop())

    time.sleep(1)
    self._wait_for_load()


def close(self):
    r"""Close the browser held in ``self.browser``."""
    self.browser.close()


# ruff: noqa: E501
def show_interactive_elements(self):
    r"""Outline simple interactive elements on the current page in red.

    Evaluates the stored page script, then draws a red border around links,
    buttons, form controls, focusable and content-editable elements.
    """
    self.page.evaluate(self.page_script)
    self.page.evaluate("""
    () => {
        document.querySelectorAll('a, button, input, select, textarea, [tabindex]:not([tabindex="-1"]), [contenteditable="true"]').forEach(el => {
            el.style.border = '2px solid red';
        });
        }
    """)


# @retry_on_error()
def get_webpage_content(self) -> str:
    r"""Return the current page's HTML converted by ``html2text``.

    Returns:
        str: The converted page content, after waiting for the page to load.
    """
    # Imported here so html2text is only needed when this method is called.
    from html2text import html2text

    self._wait_for_load()
    html_content = self.page.content()
    return html2text(html_content)


def _ensure_browser_installed(self) -> None:
    r"""Ensure the browser for ``self.channel`` can be launched.

    Tries to launch and close a browser. If that fails for any reason, runs
    ``playwright install <channel>`` and, on Linux, also
    ``playwright install-deps <channel>``.

    Raises:
        RuntimeError: If an installation command exits with a non-zero
            status. The message carries the command's stderr as text.
    """
    import platform
    import subprocess
    import sys

    try:
        from playwright.sync_api import sync_playwright

        with sync_playwright() as p:
            browser = p.chromium.launch(channel=self.channel)
            browser.close()
    except Exception:
        # Any launch failure is treated as "browser missing" (best effort).
        print("Installing Chromium browser...")
        playwright_cmd = [sys.executable, "-m", "playwright"]
        try:
            subprocess.run(
                [*playwright_cmd, "install", self.channel],
                check=True,
                capture_output=True,
            )
            if platform.system().lower() == "linux":
                subprocess.run(
                    [*playwright_cmd, "install-deps", self.channel],
                    check=True,
                    capture_output=True,
                )
            print("Chromium browser installation completed")
        except subprocess.CalledProcessError as e:
            # stderr is captured as bytes; decode it so the message is
            # readable instead of a b'...' repr, and keep the cause chained.
            stderr = e.stderr
            if isinstance(stderr, bytes):
                stderr = stderr.decode(errors="replace")
            raise RuntimeError(f"Failed to install browser: {stderr}") from e
# diff --git a/cerebrum/utils/utils.py b/cerebrum/utils/utils.py
# new file mode 100644, index 0000000..13cabb5
import random
import os
from typing import Optional, Dict, Any
import re
import json


def generator_tool_call_id():
    """Generate a tool call id.

    Returns:
        str: A random integer between 0 and 1000 (inclusive), as a string.
            With such a small range, ids are not guaranteed to be unique.
    """
    return str(random.randint(0, 1000))


def get_from_env(env_key: str, default: Optional[str] = None) -> str:
    """Get a value from an environment variable.

    Args:
        env_key (str): Name of the environment variable to read.
        default (Optional[str]): Value returned when the variable is unset
            or set to an empty string.

    Returns:
        str: The non-empty value of the variable, otherwise ``default``.

    Raises:
        ValueError: If the variable is unset or empty and ``default`` is
            ``None``.
    """
    value = os.environ.get(env_key)
    if value:
        return value
    if default is not None:
        return default
    raise ValueError(
        f"Did not find {env_key}, please add an environment variable"
        f" `{env_key}` which contains it. "
    )


def _parse_json_output(text: str) -> Dict[str, Any]:
    r"""Extract JSON output from a string.

    The following strategies are tried in order:

    1. Unwrap a Markdown code fence and/or a triple-double-quoted section
       (an optional ``json`` tag, in any letter case, is skipped), then
       parse the remaining text with ``json.loads``.
    2. Replace backtick-quoted tokens with double-quoted ones and parse
       again.
    3. Collect flat ``"key": value`` pairs (booleans, strings, numbers and
       empty strings) with regular expressions.

    Args:
        text (str): Raw text expected to contain a JSON document.

    Returns:
        Dict[str, Any]: The parsed value (whatever ``json.loads`` produces
            for strategies 1 and 2), the dictionary of extracted fields for
            strategy 3, or an empty dictionary when nothing could be
            extracted.
    """
    # The tag is matched case-insensitively so that fences such as
    # "```JSON" do not leave the tag inside the payload, which would force
    # the lossy regex fallback below.
    wrapper_flags = re.DOTALL | re.IGNORECASE

    markdown_pattern = r'```(?:json)?\s*(.*?)\s*```'
    markdown_match = re.search(markdown_pattern, text, wrapper_flags)
    if markdown_match:
        text = markdown_match.group(1).strip()

    triple_quotes_pattern = r'"""(?:json)?\s*(.*?)\s*"""'
    triple_quotes_match = re.search(triple_quotes_pattern, text, wrapper_flags)
    if triple_quotes_match:
        text = triple_quotes_match.group(1).strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Second attempt: treat `token` as "token" when it sits where a JSON
    # key or value would end.
    fixed_text = re.sub(
        r'`([^`]*?)`(?=\s*[:,\[\]{}]|$)', r'"\1"', text
    )
    try:
        return json.loads(fixed_text)
    except json.JSONDecodeError:
        pass

    # Last resort: pull out flat fields one type at a time. Later passes
    # overwrite earlier ones for the same key, so the order matters.
    result = {}
    try:
        bool_pattern = r'"(\w+)"\s*:\s*(true|false)'
        for match in re.finditer(bool_pattern, text, re.IGNORECASE):
            key, value = match.groups()
            result[key] = value.lower() == "true"

        str_pattern = r'"(\w+)"\s*:\s*"([^"]*)"'
        for match in re.finditer(str_pattern, text):
            key, value = match.groups()
            result[key] = value

        num_pattern = r'"(\w+)"\s*:\s*(-?\d+(?:\.\d+)?)'
        for match in re.finditer(num_pattern, text):
            key, value = match.groups()
            try:
                result[key] = int(value)
            except ValueError:
                result[key] = float(value)

        empty_str_pattern = r'"(\w+)"\s*:\s*""'
        for match in re.finditer(empty_str_pattern, text):
            key = match.group(1)
            result[key] = ""

        if result:
            return result

        print(f"Failed to parse JSON output: {text}")
        return {}
    except Exception as e:
        print(f"Error while extracting fields from JSON: {e}")
        return {}


# Non-Python hunks that follow this module in the original patch, kept
# verbatim as comments:
#
# diff --git a/pyproject.toml b/pyproject.toml
# index 9a7609c..49b5617 100644
# --- a/pyproject.toml
# +++ b/pyproject.toml
# @@ -41,6 +41,7 @@ exclude = [
#  files = ["requirements.txt"]
#
#  [project.scripts]
# +list-available-llms = "cerebrum.commands.list_available_llms:main"
#  run-agent = "cerebrum.commands.run_agent:main"
#  download-agent = "cerebrum.commands.download_agent:main"
#  upload-agent = "cerebrum.commands.upload_agent:main"
#
# diff --git a/requirements.txt b/requirements.txt
# index c15824d..bd7fe65 100644
# --- a/requirements.txt
# +++ b/requirements.txt
# @@ -1,3 +1,5 @@
#  requests
#  platformdirs
# -pydantic
# \ No newline at end of file
# +pydantic
# +mcp
# +datasets
# \ No newline at end of file