diff --git a/docs/wayflowcore/source/core/changelog.rst b/docs/wayflowcore/source/core/changelog.rst index 57eb8bfa..6a19f41e 100644 --- a/docs/wayflowcore/source/core/changelog.rst +++ b/docs/wayflowcore/source/core/changelog.rst @@ -17,6 +17,21 @@ WayFlow 26.1.2 New features ^^^^^^^^^^^^ +* **Configurable retry policies for MCP tools** + + Added ``retry_policy`` to ``MCPTool`` and ``MCPToolBox`` to retry transient + missing-tool responses and MCP tool execution failures, including WayFlow + Agent Spec plugin serialization support. + + For usage details, see :doc:`the MCP tools guide `. + + +WayFlow 26.1.2 +-------------- + +New features +^^^^^^^^^^^^ + * **Configurable retry policies for remote components** Added the ``RetryPolicy`` object to configure retries, backoff, diff --git a/docs/wayflowcore/source/core/code_examples/howto_mcp.py b/docs/wayflowcore/source/core/code_examples/howto_mcp.py index 6748668c..4b9c4bea 100644 --- a/docs/wayflowcore/source/core/code_examples/howto_mcp.py +++ b/docs/wayflowcore/source/core/code_examples/howto_mcp.py @@ -82,11 +82,13 @@ def start_mcp_server() -> str: from wayflowcore.agent import Agent from wayflowcore.mcp import MCPTool, MCPToolBox, SSETransport, authless_mcp_enabled from wayflowcore.flow import Flow +from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.steps import ToolExecutionStep mcp_server_url = f"http://localhost:8080/sse" # change to your own URL # We will see below how to connect a specific tool to an assistant, e.g. MCP_TOOL_NAME = "get_user_session" +MCP_TOOLBOX_TOOL_FILTER = ["get_user_session", "get_payslips"] # And see how to build an agent that can answer questions, e.g. USER_QUERY = "What was the payment date of the last payslip for the current user?" # .. end-##_Imports_for_this_guide @@ -101,6 +103,8 @@ def start_mcp_server() -> str: mcp_server_url: str # docs-skiprow USER_QUERY: str # docs-skiprow (llm, mcp_server_url, USER_QUERY, MCP_TOOL_NAME) = _update_globals(["llm_small", "sse_mcp_server", "mcp_user_query", "mcp_example_tool_name"]) # docs-skiprow # type: ignore +if MCP_TOOL_NAME != "get_user_session": # docs-skiprow + MCP_TOOLBOX_TOOL_FILTER = [MCP_TOOL_NAME] # docs-skiprow # .. start-##_Connecting_an_agent_to_the_MCP_server mcp_client = SSETransport(url=mcp_server_url) @@ -112,6 +116,22 @@ def start_mcp_server() -> str: tools=[mcp_toolbox] ) # .. end-##_Connecting_an_agent_to_the_MCP_server +# .. start-##_Configuring_toolbox_retry_policy +mcp_toolbox_with_retries = MCPToolBox( + client_transport=mcp_client, + tool_filter=MCP_TOOLBOX_TOOL_FILTER, + retry_policy=RetryPolicy( + max_attempts=3, + initial_retry_delay=0.25, + max_retry_delay=2.0, + ), +) + +assistant = Agent( + llm=llm, + tools=[mcp_toolbox_with_retries] +) +# .. end-##_Configuring_toolbox_retry_policy from wayflowcore.agentspec import AgentSpecExporter # docs-skiprow serialized_assistant = AgentSpecExporter().to_json(assistant) # docs-skiprow @@ -156,9 +176,17 @@ def run_agent_in_command_line(assistant: Agent): name=MCP_TOOL_NAME, client_transport=mcp_client ) +# .. start-##_Configuring_direct_tool_retry_policy +with authless_mcp_enabled(): + mcp_tool_with_retries = MCPTool( + name=MCP_TOOL_NAME, + client_transport=mcp_client, + retry_policy=RetryPolicy(max_attempts=3), + ) +# .. end-##_Configuring_direct_tool_retry_policy assistant = Flow.from_steps([ - ToolExecutionStep(name="mcp_tool_step", tool=mcp_tool) + ToolExecutionStep(name="mcp_tool_step", tool=mcp_tool_with_retries) ]) # .. end-##_Connecting_a_flow_to_the_MCP_server from wayflowcore.agentspec import AgentSpecExporter, AgentSpecLoader # docs-skiprow diff --git a/docs/wayflowcore/source/core/howtoguides/howto_mcp.rst b/docs/wayflowcore/source/core/howtoguides/howto_mcp.rst index 19b02ded..7035bd59 100644 --- a/docs/wayflowcore/source/core/howtoguides/howto_mcp.rst +++ b/docs/wayflowcore/source/core/howtoguides/howto_mcp.rst @@ -106,6 +106,31 @@ Here you will use the toolbox (see the section on Flows to see how to use the `` Specify the :ref:`transport ` to use to handle the connection to the server and create the toolbox. You can then equip an agent with the toolbox similarly to tools. +If the MCP server composes tools from external MCP servers and can temporarily +return a partial tool list during health or remount events, declare the expected +tools with ``tool_filter`` and configure ``retry_policy``. +When an expected tool is missing from a successful ``list_tools`` response, +WayFlow retries the tool-list resolution before failing. The same policy is +propagated to generated ``MCPTool`` instances for transient tool execution +failures. + +.. literalinclude:: ../code_examples/howto_mcp.py + :language: python + :start-after: # .. start-##_Configuring_toolbox_retry_policy + :end-before: # .. end-##_Configuring_toolbox_retry_policy + +.. note:: + Tool-list missing-tool retry only applies when WayFlow knows which tools are + expected, for example through ``tool_filter`` or a direct ``MCPTool(name=...)``. + If ``tool_filter`` is ``None``, WayFlow cannot determine whether a successful + tool-list response is incomplete. + +.. note:: + The MCP tool retry policy is separate from transport-level retry. It handles + successful tool-list responses that are missing expected tools and transient + tool execution failures. Configure ``retry_policy`` on the transport when you + need lower-level HTTP client retry or request timeout behavior. + .. note:: ``authless_mcp_enabled()`` disables authorization for local/testing only—do not use in production. Keep it scoped around the code that creates MCP tools or toolboxes. @@ -153,6 +178,13 @@ Create the flow using the MCP tool: Here you specify the client transport as with the MCP ToolBox, as well as the name of the specific tool you want to use. Additionally, you can override the tool description (exposed by the MCP server) by specifying the ``description`` parameter. +You can also pass ``retry_policy`` to retry direct tool resolution and transient +tool execution failures. + +.. literalinclude:: ../code_examples/howto_mcp.py + :language: python + :start-after: # .. start-##_Configuring_direct_tool_retry_policy + :end-before: # .. end-##_Configuring_direct_tool_retry_policy .. tip:: diff --git a/wayflowcore/src/wayflowcore/exceptions.py b/wayflowcore/src/wayflowcore/exceptions.py index d0172c21..3cf607d1 100644 --- a/wayflowcore/src/wayflowcore/exceptions.py +++ b/wayflowcore/src/wayflowcore/exceptions.py @@ -70,6 +70,20 @@ class MaxNumTrialsExceededException(ValueError): class NoSuchToolFoundOnMCPServerError(ValueError): """Error thrown when MCP server returns no tools with a given signature""" + def __init__( + self, + message: str, + missing_tool_names: list[str] | None = None, + expected_tool_names: list[str] | None = None, + exposed_tool_names: list[str] | None = None, + attempts: int | None = None, + ) -> None: + super().__init__(message) + self.missing_tool_names = missing_tool_names or [] + self.expected_tool_names = expected_tool_names or [] + self.exposed_tool_names = exposed_tool_names or [] + self.attempts = attempts + class DataclassFieldDeserializationError(ValueError): """Error thrown when the deserialization of a field of a dataclass fails""" diff --git a/wayflowcore/src/wayflowcore/mcp/mcphelpers.py b/wayflowcore/src/wayflowcore/mcp/mcphelpers.py index 87d51fcf..97dc3c5d 100644 --- a/wayflowcore/src/wayflowcore/mcp/mcphelpers.py +++ b/wayflowcore/src/wayflowcore/mcp/mcphelpers.py @@ -36,6 +36,7 @@ from mcp.server.fastmcp import Context from mcp.server.session import ServerSessionT from mcp.shared.context import LifespanContextT, RequestT +from mcp.shared.exceptions import McpError from wayflowcore.events.event import ToolExecutionStreamingChunkReceivedEvent from wayflowcore.events.eventlistener import record_event @@ -45,6 +46,14 @@ get_mcp_async_runtime, ) from wayflowcore.mcp.clienttransport import ClientTransport, ClientTransportWithAuth +from wayflowcore.models._requesthelpers import ( + RetryClassification, + _classify_http_exception_for_retry, + _get_http_headers_from_exception, + _get_http_status_code_from_exception, + _get_retry_after_value_from_headers, + execute_async_with_retry, +) from wayflowcore.property import ( DictProperty, JsonSchemaParam, @@ -53,6 +62,7 @@ Property, UnionProperty, ) +from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.tools.servertools import ServerTool from wayflowcore.tools.tools import Tool from wayflowcore.tracing.span import ToolExecutionSpan, get_current_span @@ -308,24 +318,179 @@ async def _invoke_mcp_tool_call_async( tool_name: str, tool_args: Dict[str, Any], output_descriptors: List[Property], + retry_policy: Optional[RetryPolicy] = None, ) -> Any: - with _catch_and_raise_mcp_connection_errors(): - result: types.CallToolResult = await session.call_tool( - tool_name, tool_args, progress_callback=_mcp_progress_handler - ) + """Call an MCP tool and apply retry policy to retryable execution failures.""" + + async def operation() -> Any: + with _catch_and_raise_mcp_connection_errors(): + result: types.CallToolResult = await session.call_tool( + tool_name, tool_args, progress_callback=_mcp_progress_handler + ) + + output = _try_handle_structured_content_from_tool_result(result, output_descriptors) + if output is not None: + return output + + return _extract_text_content_from_tool_result(result) + + if retry_policy is None: + return await operation() + + return await execute_async_with_retry( + operation, + retry_policy=retry_policy, + classify_exception=_classify_mcp_tool_call_for_retry, + retry_budget_exhausted_message="MCP tool execution retry budget exhausted", + ) + + +def _is_missing_mcp_tool_error(exc: BaseException) -> bool: + """Return whether an MCP error means the requested tool is temporarily missing.""" + + if not isinstance(exc, McpError): + return False + + message = exc.error.message.lower() + return "tool" in message and ( + "not found" in message or "unknown" in message or "does not exist" in message + ) + - output = _try_handle_structured_content_from_tool_result(result, output_descriptors) - if output is not None: - return output +def _classify_mcp_tool_call_for_retry(exc: Exception, policy: RetryPolicy) -> RetryClassification: + """Classify MCP tool-call failures that are safe to retry under the policy.""" - return _extract_text_content_from_tool_result(result) + retry_classification = _classify_http_exception_for_retry(exc, policy) + if retry_classification is not None: + return retry_classification + + # MCP clients often wrap transport errors in domain exceptions, so inspect + # the whole exception chain before deciding the failure is not retryable. + pending: List[BaseException] = [exc] + seen: set[int] = set() + while pending: + current = pending.pop() + if id(current) in seen: + continue + seen.add(id(current)) + + status_code = _get_http_status_code_from_exception(current) + if status_code == 404: + return status_code, _get_retry_after_value_from_headers( + _get_http_headers_from_exception(current) + ) + + if _is_missing_mcp_tool_error(current): + return None, None + + if current.__context__ is not None: + pending.append(current.__context__) + if current.__cause__ is not None: + pending.append(current.__cause__) + + return None async def _get_server_signatures_from_mcp_server(session: ClientSession) -> types.ListToolsResult: + """Fetch the raw MCP tool list from the server with connection errors normalized.""" + with _catch_and_raise_mcp_connection_errors(): return await session.list_tools() +def _classify_missing_mcp_tool_for_retry( + exc: Exception, policy: RetryPolicy +) -> RetryClassification: + """Retry only semantic MCP tool-list misses detected during validation.""" + + if isinstance(exc, NoSuchToolFoundOnMCPServerError): + return None, None + return None + + +def _raise_if_expected_tools_missing( + remote_mcp_signature: types.ListToolsResult, + expected_signatures_by_name: Dict[str, Optional[Tool]], + attempts: Optional[int], +) -> None: + """Raise with diagnostics if expected tools are absent from the MCP tool list.""" + + expected_tool_names = list(expected_signatures_by_name or {}) + if not expected_tool_names: + return + + exposed_tool_names = sorted({exposed_tool.name for exposed_tool in remote_mcp_signature.tools}) + exposed_tool_names_set = set(exposed_tool_names) + missing_tool_names = [ + expected_tool_name + for expected_tool_name in expected_tool_names + if expected_tool_name not in exposed_tool_names_set + ] + if not missing_tool_names: + return + + attempts_msg = f" after {attempts} tool-list attempt(s)" if attempts is not None else "" + raise NoSuchToolFoundOnMCPServerError( + f"Expected MCP tool(s) {missing_tool_names} but they were missing from the " + f"list of exposed tools{attempts_msg}. Exposed tools: {exposed_tool_names}", + missing_tool_names=missing_tool_names, + expected_tool_names=expected_tool_names, + exposed_tool_names=exposed_tool_names, + attempts=attempts, + ) + + +async def _get_validated_server_signatures_from_mcp_server( + session: ClientSession, + expected_signatures_by_name: Dict[str, Optional[Tool]], + attempts: Optional[int], +) -> types.ListToolsResult: + """Fetch MCP tool signatures and validate that all expected tools are exposed.""" + + remote_mcp_signature = await _get_server_signatures_from_mcp_server(session) + _raise_if_expected_tools_missing( + remote_mcp_signature, + expected_signatures_by_name, + attempts, + ) + return remote_mcp_signature + + +async def _get_validated_server_signatures_with_retry( + session: ClientSession, + expected_signatures_by_name: Dict[str, Optional[Tool]], + retry_policy: Optional[RetryPolicy], +) -> types.ListToolsResult: + """Fetch MCP tool signatures with semantic retry for missing expected tools.""" + + if not expected_signatures_by_name or retry_policy is None: + # A partial list can only be identified when callers declare which tools + # they expect. For "list all tools", preserve the single-call behavior. + return await _get_validated_server_signatures_from_mcp_server( + session, + expected_signatures_by_name, + 1 if expected_signatures_by_name else None, + ) + + attempts = 0 + + async def operation() -> types.ListToolsResult: + nonlocal attempts + attempts += 1 + return await _get_validated_server_signatures_from_mcp_server( + session, + expected_signatures_by_name, + attempts, + ) + + return await execute_async_with_retry( + operation, + retry_policy=retry_policy, + classify_exception=_classify_missing_mcp_tool_for_retry, + retry_budget_exhausted_message="MCP tool-list resolution retry budget exhausted", + ) + + def _try_convert_mcp_output_schema_to_properties( schema: Optional[Dict[str, Any]], tool_title: str, @@ -386,29 +551,16 @@ async def get_server_tools_from_mcp_server( session: ClientSession, expected_signatures_by_name: Dict[str, Optional[Tool]], client_transport: ClientTransport, + retry_policy: Optional[RetryPolicy] = None, ) -> List[ServerTool]: from wayflowcore.mcp.tools import MCPTool processed_tool_signatures: List[ServerTool] = [] - remote_mcp_signature = await _get_server_signatures_from_mcp_server(session) - - if missing_tool_name := next( - ( - expected_tool_name - for expected_tool_name in expected_signatures_by_name or {} - if expected_tool_name - not in ( - exposed_tool_names := { - exposed_tool.name for exposed_tool in remote_mcp_signature.tools - } - ) - ), - None, - ): - raise NoSuchToolFoundOnMCPServerError( - f"Expected tool '{missing_tool_name}' but tool was missing from the list of exposed tools. " - f"Exposed tools: {exposed_tool_names}" - ) + remote_mcp_signature = await _get_validated_server_signatures_with_retry( + session, + expected_signatures_by_name, + retry_policy, + ) for exposed_tool in remote_mcp_signature.tools: exposed_tool_name = exposed_tool.name @@ -478,18 +630,33 @@ async def get_server_tools_from_mcp_server( client_transport=client_transport, _validate_tool_exist_on_server=False, requires_confirmation=requires_confirmation, + retry_policy=retry_policy, ) ) return processed_tool_signatures async def _get_tool_on_server( - session: ClientSession, name: str, client_transport: ClientTransport + session: ClientSession, + name: str, + client_transport: ClientTransport, + retry_policy: Optional[RetryPolicy] = None, ) -> Tool: try: - tools = await get_server_tools_from_mcp_server(session, {name: None}, client_transport) + tools = await get_server_tools_from_mcp_server( + session, + {name: None}, + client_transport, + retry_policy=retry_policy, + ) except NoSuchToolFoundOnMCPServerError as e: - tools = [] + raise NoSuchToolFoundOnMCPServerError( + f"Cannot find a tool named {name} on the MCP server. {e}", + missing_tool_names=e.missing_tool_names, + expected_tool_names=e.expected_tool_names, + exposed_tool_names=e.exposed_tool_names, + attempts=e.attempts, + ) from e except Exception as e: raise ConnectionError(f"Cannot connect to the MCP server {client_transport}") from e diff --git a/wayflowcore/src/wayflowcore/mcp/tools.py b/wayflowcore/src/wayflowcore/mcp/tools.py index 11ecff5a..0ea3a5c7 100644 --- a/wayflowcore/src/wayflowcore/mcp/tools.py +++ b/wayflowcore/src/wayflowcore/mcp/tools.py @@ -22,6 +22,7 @@ get_server_tools_from_mcp_server, ) from wayflowcore.property import Property +from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.serialization.context import DeserializationContext, SerializationContext from wayflowcore.serialization.serializer import SerializableDataclassMixin, SerializableObject from wayflowcore.tools.servertools import ServerTool @@ -44,6 +45,15 @@ class MCPTool(ServerTool, SerializableDataclassMixin, SerializableObject): client_transport: ClientTransport """Transport to use for establishing and managing connections to the MCP server.""" + retry_policy: Optional[RetryPolicy] = None + """ + Optional retry policy for MCP tool-list resolution and tool execution. + + For tool-list resolution, only the attempt and backoff fields of + ``RetryPolicy`` are used. For tool execution, the retry policy is also used + to classify retryable HTTP errors. + """ + def __init__( self, name: str, @@ -56,8 +66,12 @@ def __init__( __metadata_info__: Optional[MetadataType] = None, id: Optional[str] = None, requires_confirmation: bool = False, + retry_policy: Optional[RetryPolicy] = None, ): self.client_transport = client_transport + self.retry_policy = retry_policy + if self.retry_policy is not None and not isinstance(self.retry_policy, RetryPolicy): + raise TypeError("retry_policy must be a wayflowcore.retrypolicy.RetryPolicy instance") _validate_auth(self.client_transport) should_validate_tool = _validate_server_exists and _validate_tool_exist_on_server @@ -76,7 +90,13 @@ def __init__( if should_validate_tool: # 2. Perform the call (from the portal) - tool = mcp_runtime.call(_get_tool_on_server, session, name, self.client_transport) + tool = mcp_runtime.call( + _get_tool_on_server, + session, + name, + self.client_transport, + self.retry_policy, + ) if description is None: description = tool.description @@ -129,7 +149,12 @@ async def run_async(self, *args: Any, **kwargs: Any) -> Any: ) return await mcp_runtime.call_async( - _invoke_mcp_tool_call_async, session, self.name, kwargs, self.output_descriptors + _invoke_mcp_tool_call_async, + session, + self.name, + kwargs, + self.output_descriptors, + self.retry_policy, ) def run(self, *args: Any, **kwargs: Any) -> Any: @@ -138,7 +163,12 @@ def run(self, *args: Any, **kwargs: Any) -> Any: session = mcp_runtime.get_or_create_session(self.client_transport) return mcp_runtime.call( - _invoke_mcp_tool_call_async, session, self.name, kwargs, self.output_descriptors + _invoke_mcp_tool_call_async, + session, + self.name, + kwargs, + self.output_descriptors, + self.retry_policy, ) def _serialize_to_dict(self, serialization_context: "SerializationContext") -> Dict[str, Any]: @@ -153,6 +183,7 @@ def _serialize_to_dict(self, serialization_context: "SerializationContext") -> D "output_descriptors", "client_transport", "requires_confirmation", + "retry_policy", ] } @@ -162,19 +193,22 @@ def _deserialize_from_dict( ) -> "SerializableObject": from wayflowcore.serialization.serializer import autodeserialize_any_from_dict + field_names = [ + "name", + "description", + "input_descriptors", + "output_descriptors", + "client_transport", + "requires_confirmation", + "retry_policy", + ] return MCPTool( **{ attr_name: autodeserialize_any_from_dict( input_dict[attr_name], deserialization_context ) - for attr_name in [ - "name", - "description", - "input_descriptors", - "output_descriptors", - "client_transport", - "requires_confirmation", - ] + for attr_name in field_names + if attr_name in input_dict }, # deserialization should not require to be able to reach the server _validate_server_exists=False, @@ -214,9 +248,20 @@ class MCPToolBox(ToolBox, DataclassComponent): * Input descriptors can be provided with description of each input. The names and types should match the remote tool schema. """ + retry_policy: Optional[RetryPolicy] = None + """ + Optional retry policy for MCP tool-list resolution and tool execution. + + For tool-list resolution, only the attempt and backoff fields of + ``RetryPolicy`` are used. Generated ``MCPTool`` instances also use this + policy while executing MCP tool calls. + """ + _validate_mcp_client_transport: InitVar[bool] = field(default=True, compare=False) def __post_init__(self, _validate_mcp_client_transport: bool) -> None: + if self.retry_policy is not None and not isinstance(self.retry_policy, RetryPolicy): + raise TypeError("retry_policy must be a wayflowcore.retrypolicy.RetryPolicy instance") if _validate_mcp_client_transport: _validate_auth(self.client_transport) @@ -236,6 +281,7 @@ async def _get_tools_async_impl(self, session: ClientSession) -> Sequence[Server session=session, expected_signatures_by_name=expected_signatures_by_name, client_transport=self.client_transport, + retry_policy=self.retry_policy, ) async def _get_tools_inner_async(self) -> Sequence[ServerTool]: diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py index d562ba96..499efb0b 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_deserialization_plugin.py @@ -719,6 +719,9 @@ def convert_to_wayflow( id=agentspec_component.id, _validate_server_exists=False, _validate_tool_exist_on_server=False, + retry_policy=self._convert_retry_policy_to_runtime( + getattr(agentspec_component, "retry_policy", None) + ), ) elif isinstance(agentspec_component, AgentSpecPluginConstantValuesNode): # Map PluginConstantValuesNode -> RuntimeConstantValuesStep @@ -1740,6 +1743,9 @@ class SupportsTimeoutKwargs(TypedDict, total=False): tool_filter=tool_filter, **self._get_component_arguments(agentspec_component), requires_confirmation=agentspec_component.requires_confirmation, + retry_policy=self._convert_retry_policy_to_runtime( + getattr(agentspec_component, "retry_policy", None) + ), ) elif isinstance(agentspec_component, AgentSpecAgentNode): return RuntimeAgentExecutionStep( diff --git a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py index 47bbf26d..94742762 100644 --- a/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py +++ b/wayflowcore/src/wayflowcore/serialization/_builtins_serialization_plugin.py @@ -178,6 +178,7 @@ PluginOracleDatabaseDatastore as AgentSpecPluginOracleDatabaseDatastore, ) from wayflowcore.agentspec.components.flow import ExtendedFlow as AgentSpecExtendedFlow +from wayflowcore.agentspec.components.mcp import PluginMCPToolSpec as AgentSpecPluginMCPToolSpec from wayflowcore.agentspec.components.mcp import ( PluginSSEmTLSTransport as AgentSpecPluginSSEmTLSTransport, ) @@ -1526,7 +1527,7 @@ def _tool_convert_to_agentspec( ) # other cases: mcpservertool, server, client tools if isinstance(runtime_tool, RuntimeMCPTool): - return AgentSpecMCPTool( + mcp_tool_kwargs = dict( name=runtime_tool.name, description=runtime_tool.description, metadata=metadata, @@ -1544,8 +1545,14 @@ def _tool_convert_to_agentspec( referenced_objects, ), requires_confirmation=runtime_tool.requires_confirmation, + retry_policy=( + self._retrypolicy_convert_to_agentspec(runtime_tool.retry_policy) + if runtime_tool.retry_policy is not None + else None + ), id=runtime_tool.id, ) + return AgentSpecMCPTool(**mcp_tool_kwargs) elif isinstance(runtime_tool, RuntimeServerTool): return AgentSpecServerTool( name=runtime_tool.name, @@ -1992,21 +1999,26 @@ def _mcptoolspec_convert_to_agentspec( conversion_context: "WayflowToAgentSpecConversionContext", runtime_mcptoolspec: RuntimeTool, referenced_objects: Optional[Dict[str, Any]] = None, - ) -> AgentSpecMCPToolSpec: - - return AgentSpecMCPToolSpec( - name=runtime_mcptoolspec.name, - description=runtime_mcptoolspec.description, - inputs=[ - _runtime_property_to_pyagentspec_property(input_) - for input_ in runtime_mcptoolspec.input_descriptors or [] - ], - outputs=[ - _runtime_property_to_pyagentspec_property(output) - for output in runtime_mcptoolspec.output_descriptors or [] - ], - requires_confirmation=runtime_mcptoolspec.requires_confirmation, - metadata=_create_agentspec_metadata_from_runtime_component(runtime_mcptoolspec), + use_plugin_model: bool = False, + ) -> Union[AgentSpecMCPToolSpec, AgentSpecPluginMCPToolSpec]: + + agentspec_model = AgentSpecPluginMCPToolSpec if use_plugin_model else AgentSpecMCPToolSpec + return cast( + Union[AgentSpecMCPToolSpec, AgentSpecPluginMCPToolSpec], + agentspec_model( + name=runtime_mcptoolspec.name, + description=runtime_mcptoolspec.description, + inputs=[ + _runtime_property_to_pyagentspec_property(input_) + for input_ in runtime_mcptoolspec.input_descriptors or [] + ], + outputs=[ + _runtime_property_to_pyagentspec_property(output) + for output in runtime_mcptoolspec.output_descriptors or [] + ], + requires_confirmation=runtime_mcptoolspec.requires_confirmation, + metadata=_create_agentspec_metadata_from_runtime_component(runtime_mcptoolspec), + ), ) def _toolbox_convert_to_agentspec( @@ -2022,7 +2034,9 @@ def _toolbox_convert_to_agentspec( tool_ if isinstance(tool_, str) else self._mcptoolspec_convert_to_agentspec( - conversion_context, tool_, referenced_objects + conversion_context, + tool_, + referenced_objects, ) ) for tool_ in runtime_toolbox.tool_filter @@ -2030,7 +2044,7 @@ def _toolbox_convert_to_agentspec( if runtime_toolbox.tool_filter is not None else None ) - return AgentSpecMCPToolBox( + mcp_toolbox_kwargs = dict( name=runtime_toolbox.name, client_transport=self._mcp_clienttransport_convert_to_agentspec( conversion_context, @@ -2041,7 +2055,13 @@ def _toolbox_convert_to_agentspec( id=runtime_toolbox.id, description=runtime_toolbox.description, requires_confirmation=runtime_toolbox.requires_confirmation or False, + retry_policy=( + self._retrypolicy_convert_to_agentspec(runtime_toolbox.retry_policy) + if runtime_toolbox.retry_policy is not None + else None + ), ) + return AgentSpecMCPToolBox(**mcp_toolbox_kwargs) if isinstance(runtime_toolbox, RuntimeSearchToolBox): return AgentSpecPluginSearchToolBox( name=runtime_toolbox.name, diff --git a/wayflowcore/tests/agentspec/test_mcptool.py b/wayflowcore/tests/agentspec/test_mcptool.py index 4d22e737..3b1fd31c 100644 --- a/wayflowcore/tests/agentspec/test_mcptool.py +++ b/wayflowcore/tests/agentspec/test_mcptool.py @@ -4,11 +4,14 @@ # (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) or Universal Permissive License # (UPL) 1.0 (LICENSE-UPL or https://oss.oracle.com/licenses/upl), at your option. import json +import warnings import pytest from pyagentspec.adapters._agentspecloader import ( _DEFAULT_BLOCKED_COMPONENTS as _AGENTSPEC_DEFAULT_BLOCKED_COMPONENTS, ) +from pyagentspec.mcp import MCPTool as AgentSpecMCPTool +from pyagentspec.mcp import MCPToolBox as AgentSpecMCPToolBox from pyagentspec.versioning import AgentSpecVersionEnum from wayflowcore import Agent @@ -27,6 +30,7 @@ ) from wayflowcore.mcp._session_persistence import AsyncRuntime from wayflowcore.property import StringProperty +from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.warnings import SecurityWarning @@ -100,6 +104,62 @@ def test_mcp_tool_can_be_converted_to_agentspec_and_back( assert value == all_reloaded_agent_tools[key] +def test_mcp_toolbox_retry_policy_exports_to_native_agentspec_and_round_trips() -> None: + toolbox = MCPToolBox( + client_transport=SSETransport(url="https://example.com/sse"), + tool_filter=["expected_tool"], + retry_policy=RetryPolicy(max_attempts=4), + _validate_mcp_client_transport=False, + ) + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agentspec_toolbox = AgentSpecExporter().to_component(toolbox) + + assert not [warning for warning in captured_warnings if "retry_policy" in str(warning.message)] + assert isinstance(agentspec_toolbox, AgentSpecMCPToolBox) + assert agentspec_toolbox.retry_policy is not None + assert agentspec_toolbox.retry_policy.max_attempts == 4 + + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + deserialized_toolbox = AgentSpecLoader().load_component(agentspec_toolbox) + + assert isinstance(deserialized_toolbox, MCPToolBox) + assert deserialized_toolbox.retry_policy is not None + assert deserialized_toolbox.retry_policy.max_attempts == 4 + + +def test_mcp_tool_retry_policy_exports_to_native_agentspec_and_round_trips() -> None: + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + mcp_tool = MCPTool( + name="expected_tool", + description="Expected tool", + input_descriptors=[], + client_transport=SSETransport(url="https://example.com/sse"), + _validate_server_exists=False, + retry_policy=RetryPolicy(max_attempts=5), + ) + + with warnings.catch_warnings(record=True) as captured_warnings: + warnings.simplefilter("always") + agentspec_tool = AgentSpecExporter().to_component(mcp_tool) + + assert not [warning for warning in captured_warnings if "retry_policy" in str(warning.message)] + assert isinstance(agentspec_tool, AgentSpecMCPTool) + assert agentspec_tool.retry_policy is not None + assert agentspec_tool.retry_policy.max_attempts == 5 + + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + deserialized_tool = AgentSpecLoader().load_component(agentspec_tool) + + assert isinstance(deserialized_tool, MCPTool) + assert deserialized_tool.retry_policy is not None + assert deserialized_tool.retry_policy.max_attempts == 5 + + def _make_agent_with_mcp_stdio_transport(remotely_hosted_llm): client_transport = StdioTransport(command="command", cwd="..") mcp_tool = MCPTool( diff --git a/wayflowcore/tests/mcptools/test_mcp_tools.py b/wayflowcore/tests/mcptools/test_mcp_tools.py index e752e033..5015d262 100644 --- a/wayflowcore/tests/mcptools/test_mcp_tools.py +++ b/wayflowcore/tests/mcptools/test_mcp_tools.py @@ -8,13 +8,15 @@ import re import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Generator, List, Tuple, cast +from typing import Any, Awaitable, Callable, Dict, Generator, List, Tuple, cast from unittest.mock import patch import anyio import httpx import pytest from anyio import to_thread +from mcp import ClientSession +from mcp import types as mcp_types from wayflowcore import Agent, Flow from wayflowcore.auth import AuthChallengeResult @@ -22,6 +24,7 @@ from wayflowcore.conversation import _register_conversation from wayflowcore.events.event import Event, ToolExecutionStreamingChunkReceivedEvent from wayflowcore.events.eventlistener import EventListener, register_event_listeners +from wayflowcore.exceptions import NoSuchToolFoundOnMCPServerError from wayflowcore.executors._agentexecutor import AgentConversationExecutor from wayflowcore.executors.executionstatus import ( AuthChallengeRequestStatus, @@ -45,7 +48,12 @@ ) from wayflowcore.mcp._auth import headless_auth_flow_handler from wayflowcore.mcp._session_persistence import AsyncRuntime, get_mcp_async_runtime -from wayflowcore.mcp.mcphelpers import _reset_mcp_contextvar, mcp_streaming_tool +from wayflowcore.mcp.mcphelpers import ( + _classify_mcp_tool_call_for_retry, + _reset_mcp_contextvar, + get_server_tools_from_mcp_server, + mcp_streaming_tool, +) from wayflowcore.property import ( AnyProperty, BooleanProperty, @@ -56,6 +64,7 @@ StringProperty, UnionProperty, ) +from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.serialization import autodeserialize, serialize from wayflowcore.steps import MapStep, OutputMessageStep, ToolExecutionStep from wayflowcore.steps.flowexecutionstep import FlowExecutionStep @@ -212,6 +221,312 @@ def run_toolbox_test(transport: ClientTransport) -> None: assert mcp_tool.input_descriptors == [IntegerProperty(name="a"), IntegerProperty(name="b")] +def _make_mcp_tool_signature(name: str) -> mcp_types.Tool: + return mcp_types.Tool( + name=name, + description=f"{name} description", + inputSchema={"type": "object", "properties": {}}, + ) + + +class _ListToolsSession: + """Test double for ClientSession.list_tools with per-attempt tool snapshots. + + This models an MCP server that can return a partial tool list on early calls + and a different tool set after retries. + """ + + def __init__(self, tool_names_by_attempt: List[List[str]]) -> None: + self.tool_names_by_attempt = tool_names_by_attempt + self.calls = 0 + + async def list_tools(self) -> mcp_types.ListToolsResult: + attempt_index = min(self.calls, len(self.tool_names_by_attempt) - 1) + self.calls += 1 + return mcp_types.ListToolsResult( + tools=[ + _make_mcp_tool_signature(tool_name) + for tool_name in self.tool_names_by_attempt[attempt_index] + ] + ) + + +class _CallToolSession: + """Test double for ClientSession.call_tool with per-attempt outcomes.""" + + def __init__(self, outcomes_by_attempt: List[Any]) -> None: + self.outcomes_by_attempt = outcomes_by_attempt + self.calls = 0 + + async def call_tool( + self, + name: str, + arguments: Dict[str, Any], + progress_callback: Any = None, + ) -> mcp_types.CallToolResult: + outcome = self.outcomes_by_attempt[self.calls] + self.calls += 1 + if isinstance(outcome, BaseException): + raise outcome + return mcp_types.CallToolResult(content=[mcp_types.TextContent(type="text", text=outcome)]) + + +class _RunAsyncRuntime: + """Test double for MCPTool.run_async runtime calls using a fixed session.""" + + def __init__(self, session: _CallToolSession) -> None: + self.session = session + + def get_or_create_session(self, _transport: ClientTransport) -> object: + return self.session + + async def call_async( + self, + async_fn: Callable[..., Awaitable[Any]], + /, + *args: Any, + **kwargs: Any, + ) -> Any: + return await async_fn(*args, **kwargs) + + +def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError: + request = httpx.Request("POST", "https://mcp.example.com/mcp") + response = httpx.Response(status_code, request=request) + return httpx.HTTPStatusError( + f"HTTP status {status_code}", + request=request, + response=response, + ) + + +@pytest.mark.anyio +async def test_mcp_tool_list_resolution_retry_recovers_missing_expected_tool(with_mcp_enabled): + session = _ListToolsSession([["b_tool", "c_tool"], ["a_tool", "b_tool", "c_tool"]]) + retry_policy = RetryPolicy( + max_attempts=1, + initial_retry_delay=0.001, + max_retry_delay=0.001, + jitter=None, + ) + + with authless_mcp_enabled(): + tools = await get_server_tools_from_mcp_server( + cast(ClientSession, session), + {"a_tool": None}, + SSETransport(url="anything"), + retry_policy, + ) + + assert session.calls == 2 + assert [tool.name for tool in tools] == ["a_tool"] + assert cast(MCPTool, tools[0]).retry_policy is retry_policy + + +@pytest.mark.anyio +async def test_mcp_tool_list_resolution_retry_exhaustion_has_missing_tool_metadata( + with_mcp_enabled, +): + session = _ListToolsSession([["b_tool", "c_tool"], ["b_tool", "c_tool"]]) + retry_policy = RetryPolicy( + max_attempts=1, + initial_retry_delay=0.001, + max_retry_delay=0.001, + jitter=None, + ) + + with pytest.raises(NoSuchToolFoundOnMCPServerError) as exc_info: + await get_server_tools_from_mcp_server( + cast(ClientSession, session), + {"a_tool": None}, + SSETransport(url="anything"), + retry_policy, + ) + + assert session.calls == 2 + assert exc_info.value.missing_tool_names == ["a_tool"] + assert exc_info.value.expected_tool_names == ["a_tool"] + assert exc_info.value.exposed_tool_names == ["b_tool", "c_tool"] + assert exc_info.value.attempts == 2 + + +@pytest.mark.anyio +async def test_mcp_tool_list_resolution_does_not_retry_without_expected_tools(with_mcp_enabled): + session = _ListToolsSession([["b_tool"], ["a_tool", "b_tool"]]) + retry_policy = RetryPolicy( + max_attempts=1, + initial_retry_delay=0.001, + max_retry_delay=0.001, + jitter=None, + ) + + with authless_mcp_enabled(): + tools = await get_server_tools_from_mcp_server( + cast(ClientSession, session), + {}, + SSETransport(url="anything"), + retry_policy, + ) + + assert session.calls == 1 + assert [tool.name for tool in tools] == ["b_tool"] + assert cast(MCPTool, tools[0]).retry_policy is retry_policy + + +@pytest.mark.anyio +async def test_mcp_tool_call_retry_recovers_from_transient_404( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = _CallToolSession([_make_http_status_error(404), "tool result"]) + retry_policy = RetryPolicy( + max_attempts=1, + initial_retry_delay=0.001, + max_retry_delay=0.001, + jitter=None, + ) + + monkeypatch.setattr( + "wayflowcore.mcp.tools.get_mcp_async_runtime", + lambda: _RunAsyncRuntime(session), + ) + + with pytest.warns(SecurityWarning, match="without authentication"): + with authless_mcp_enabled(): + tool = MCPTool( + name="a_tool", + description="a_tool description", + input_descriptors=[], + output_descriptors=[StringProperty()], + client_transport=SSETransport(url="anything"), + _validate_server_exists=False, + retry_policy=retry_policy, + ) + + result = await tool.run_async() + + assert session.calls == 2 + assert result == "tool result" + + +@pytest.mark.anyio +async def test_mcp_tool_call_passes_arguments_to_public_run_async( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = _CallToolSession(["tool result"]) + received_arguments: List[Dict[str, Any]] = [] + original_call_tool = session.call_tool + + async def call_tool_with_argument_capture( + name: str, + arguments: Dict[str, Any], + progress_callback: Any = None, + ) -> mcp_types.CallToolResult: + received_arguments.append(arguments) + return await original_call_tool(name, arguments, progress_callback) + + session.call_tool = call_tool_with_argument_capture # type: ignore[method-assign] + + monkeypatch.setattr( + "wayflowcore.mcp.tools.get_mcp_async_runtime", + lambda: _RunAsyncRuntime(session), + ) + + with pytest.warns(SecurityWarning, match="without authentication"): + with authless_mcp_enabled(): + tool = MCPTool( + name="a_tool", + description="a_tool description", + input_descriptors=[StringProperty(name="query")], + output_descriptors=[StringProperty()], + client_transport=SSETransport(url="anything"), + _validate_server_exists=False, + retry_policy=RetryPolicy(max_attempts=1), + ) + + result = await tool.run_async(query="hello") + + assert result == "tool result" + assert received_arguments == [{"query": "hello"}] + + +def test_mcp_tool_call_retry_classification_detects_wrapped_404() -> None: + retry_policy = RetryPolicy(max_attempts=1) + wrapped_error = RuntimeError("MCP call failed") + wrapped_error.__cause__ = _make_http_status_error(404) + + assert _classify_mcp_tool_call_for_retry(wrapped_error, retry_policy) == (404, None) + + +@pytest.mark.anyio +async def test_mcp_tool_call_does_not_retry_without_retry_policy( + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = _CallToolSession([_make_http_status_error(404), "tool result"]) + + monkeypatch.setattr( + "wayflowcore.mcp.tools.get_mcp_async_runtime", + lambda: _RunAsyncRuntime(session), + ) + + with pytest.warns(SecurityWarning, match="without authentication"): + with authless_mcp_enabled(): + tool = MCPTool( + name="a_tool", + description="a_tool description", + input_descriptors=[], + output_descriptors=[StringProperty()], + client_transport=SSETransport(url="anything"), + _validate_server_exists=False, + ) + + with pytest.raises(httpx.HTTPStatusError): + await tool.run_async() + + assert session.calls == 1 + + +def test_direct_mcp_tool_passes_retry_policy( + monkeypatch: pytest.MonkeyPatch, + with_mcp_enabled, +) -> None: + retry_policy = RetryPolicy(max_attempts=3) + + class FakeRuntime: + def get_or_create_session(self, transport: ClientTransport) -> object: + return object() + + def call(self, async_fn, /, *args, **kwargs): + return anyio.run(async_fn, *args, **kwargs) + + async def fake_get_tool_on_server( + session: ClientSession, + name: str, + client_transport: ClientTransport, + received_retry_policy: RetryPolicy | None = None, + ) -> Tool: + assert name == "a_tool" + assert received_retry_policy is retry_policy + return Tool( + name="a_tool", + description="a_tool description", + input_descriptors=[], + output_descriptors=[StringProperty()], + ) + + monkeypatch.setattr("wayflowcore.mcp.tools.get_mcp_async_runtime", lambda: FakeRuntime()) + monkeypatch.setattr("wayflowcore.mcp.tools._get_tool_on_server", fake_get_tool_on_server) + + with authless_mcp_enabled(): + tool = MCPTool( + name="a_tool", + client_transport=SSETransport(url="anything"), + retry_policy=retry_policy, + ) + + assert tool.retry_policy is retry_policy + assert tool.name == "a_tool" + + @pytest.mark.parametrize( "client_transport_name", [ diff --git a/wayflowcore/tests/serialization/test_tool_serialization.py b/wayflowcore/tests/serialization/test_tool_serialization.py index f279188e..0c4226f8 100644 --- a/wayflowcore/tests/serialization/test_tool_serialization.py +++ b/wayflowcore/tests/serialization/test_tool_serialization.py @@ -10,6 +10,7 @@ from wayflowcore.agent import Agent from wayflowcore.flow import Flow +from wayflowcore.mcp import MCPTool, MCPToolBox, SSETransport, authless_mcp_enabled from wayflowcore.property import IntegerProperty, NullProperty, StringProperty, UnionProperty from wayflowcore.retrypolicy import RetryPolicy from wayflowcore.serialization import autodeserialize, deserialize, serialize @@ -253,6 +254,44 @@ def test_remote_tool_retry_policy_round_trips(remote_tool): assert deserialized_tool.retry_policy.max_attempts == 4 +def test_mcp_toolbox_retry_policy_round_trips() -> None: + toolbox = MCPToolBox( + client_transport=SSETransport(url="https://example.com/sse"), + tool_filter=["expected_tool"], + retry_policy=RetryPolicy(max_attempts=4), + _validate_mcp_client_transport=False, + ) + + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + deserialized_toolbox = autodeserialize(serialize(toolbox)) + + assert isinstance(deserialized_toolbox, MCPToolBox) + assert deserialized_toolbox.retry_policy is not None + assert deserialized_toolbox.retry_policy.max_attempts == 4 + + +def test_mcp_tool_retry_policy_round_trips() -> None: + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + mcp_tool = MCPTool( + name="expected_tool", + description="Expected tool", + input_descriptors=[], + client_transport=SSETransport(url="https://example.com/sse"), + _validate_server_exists=False, + retry_policy=RetryPolicy(max_attempts=5), + ) + + with pytest.warns(match="without authentication"): + with authless_mcp_enabled(): + deserialized_tool = autodeserialize(serialize(mcp_tool)) + + assert isinstance(deserialized_tool, MCPTool) + assert deserialized_tool.retry_policy is not None + assert deserialized_tool.retry_policy.max_attempts == 5 + + def test_serialize_agent_with_remote_tools(remotely_hosted_llm, remote_tool): agent = Agent(llm=remotely_hosted_llm, tools=[remote_tool]) serialize_agent = serialize(agent)