diff --git a/scripts/gen_signature.py b/scripts/gen_signature.py index b3a7add..b435e2c 100644 --- a/scripts/gen_signature.py +++ b/scripts/gen_signature.py @@ -9,6 +9,11 @@ from acp import schema +SIGNATURE_OPTIONAL_FIELDS: set[tuple[str, str]] = { + ("LoadSessionRequest", "mcp_servers"), + ("NewSessionRequest", "mcp_servers"), +} + class NodeTransformer(ast.NodeTransformer): def __init__(self) -> None: @@ -16,6 +21,7 @@ def __init__(self) -> None: self._schema_import_node: ast.ImportFrom | None = None self._should_rewrite = False self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal} + self._current_model_name: str | None = None def _add_typing_import(self, name: str) -> None: if not self._type_import_node: @@ -71,9 +77,13 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: self._should_rewrite = True model_name = t.cast(ast.Name, decorator.args[0]).id model = t.cast(type[schema.BaseModel], getattr(schema, model_name)) - param_defaults = [ - self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" - ] + self._current_model_name = model_name + try: + param_defaults = [ + self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta" + ] + finally: + self._current_model_name = None param_defaults.sort(key=lambda x: x[1] is not None) node.args.args[1:] = [param for param, _ in param_defaults] node.args.defaults = [default for _, default in param_defaults if default is not None] @@ -84,12 +94,18 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST: def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]: arg = ast.arg(arg=name) ann = field.annotation - if field.default is PydanticUndefined: - default = None - elif isinstance(field.default, dict | BaseModel): + override_optional = (self._current_model_name, name) in SIGNATURE_OPTIONAL_FIELDS + if override_optional: + if ann is not None: + ann = ann | None default = ast.Constant(None) else: - default = ast.Constant(value=field.default) + if field.default is PydanticUndefined: + default = None + elif isinstance(field.default, dict | BaseModel): + default = ast.Constant(None) + else: + default = ast.Constant(value=field.default) if ann is not None: arg.annotation = self._format_annotation(ann) return arg, default diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index ac0d34f..9831d7e 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -93,23 +93,31 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any ) -> NewSessionResponse: + resolved_mcp_servers = mcp_servers or [] return await request_model( self._conn, AGENT_METHODS["session_new"], - NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None), + NewSessionRequest(cwd=cwd, mcp_servers=resolved_mcp_servers, field_meta=kwargs or None), NewSessionResponse, ) @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse: + resolved_mcp_servers = mcp_servers or [] return await request_model_from_dict( self._conn, AGENT_METHODS["session_load"], - LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None), + LoadSessionRequest( + cwd=cwd, mcp_servers=resolved_mcp_servers, session_id=session_id, field_meta=kwargs or None + ), LoadSessionResponse, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 457dfe7..55c00f3 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -154,12 +154,16 @@ async def initialize( @param_model(NewSessionRequest) async def new_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any + self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any ) -> NewSessionResponse: ... @param_model(LoadSessionRequest) async def load_session( - self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, ) -> LoadSessionResponse | None: ... @param_model(ListSessionsRequest) diff --git a/tests/real_user/test_mcp_servers_optional.py b/tests/real_user/test_mcp_servers_optional.py new file mode 100644 index 0000000..96aae75 --- /dev/null +++ b/tests/real_user/test_mcp_servers_optional.py @@ -0,0 +1,68 @@ +import asyncio +from typing import Any + +import pytest + +from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse +from acp.core import AgentSideConnection, ClientSideConnection +from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer +from tests.conftest import TestAgent, TestClient + + +class McpOptionalAgent(TestAgent): + def __init__(self) -> None: + super().__init__() + self.seen_new_session: tuple[str, Any] | None = None + self.seen_load_session: tuple[str, str, Any] | None = None + + async def new_session( + self, + cwd: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> NewSessionResponse: + resolved_mcp_servers = mcp_servers or [] + self.seen_new_session = (cwd, resolved_mcp_servers) + return await super().new_session(cwd=cwd, mcp_servers=resolved_mcp_servers, **kwargs) + + async def load_session( + self, + cwd: str, + session_id: str, + mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, + **kwargs: Any, + ) -> LoadSessionResponse | None: + resolved_mcp_servers = mcp_servers or [] + self.seen_load_session = (cwd, session_id, resolved_mcp_servers) + return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=resolved_mcp_servers, **kwargs) + + +@pytest.mark.asyncio +async def test_session_requests_default_empty_mcp_servers(server) -> None: + client = TestClient() + captured_agent: list[McpOptionalAgent] = [] + + agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type] + _agent_side = AgentSideConnection( + lambda _conn: captured_agent.append(McpOptionalAgent()) or captured_agent[-1], + server._server_writer, + server._server_reader, + listening=True, + ) + + init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0) + assert isinstance(init, InitializeResponse) + + new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0) + assert isinstance(new_session, NewSessionResponse) + + load_session = await asyncio.wait_for( + agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id), + timeout=1.0, + ) + assert isinstance(load_session, LoadSessionResponse) + + assert captured_agent, "Agent was not constructed" + [agent] = captured_agent + assert agent.seen_new_session == ("/workspace", []) + assert agent.seen_load_session == ("/workspace", new_session.session_id, [])