Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions scripts/gen_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@

from acp import schema

SIGNATURE_OPTIONAL_FIELDS: set[tuple[str, str]] = {
("LoadSessionRequest", "mcp_servers"),
("NewSessionRequest", "mcp_servers"),
}


class NodeTransformer(ast.NodeTransformer):
def __init__(self) -> None:
self._type_import_node: ast.ImportFrom | None = None
self._schema_import_node: ast.ImportFrom | None = None
self._should_rewrite = False
self._literals = {name: value for name, value in schema.__dict__.items() if t.get_origin(value) is t.Literal}
self._current_model_name: str | None = None

def _add_typing_import(self, name: str) -> None:
if not self._type_import_node:
Expand Down Expand Up @@ -71,9 +77,13 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
self._should_rewrite = True
model_name = t.cast(ast.Name, decorator.args[0]).id
model = t.cast(type[schema.BaseModel], getattr(schema, model_name))
param_defaults = [
self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta"
]
self._current_model_name = model_name
try:
param_defaults = [
self._to_param_def(name, field) for name, field in model.model_fields.items() if name != "field_meta"
]
finally:
self._current_model_name = None
param_defaults.sort(key=lambda x: x[1] is not None)
node.args.args[1:] = [param for param, _ in param_defaults]
node.args.defaults = [default for _, default in param_defaults if default is not None]
Expand All @@ -84,12 +94,18 @@ def visit_func(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.AST:
def _to_param_def(self, name: str, field: FieldInfo) -> tuple[ast.arg, ast.expr | None]:
arg = ast.arg(arg=name)
ann = field.annotation
if field.default is PydanticUndefined:
default = None
elif isinstance(field.default, dict | BaseModel):
override_optional = (self._current_model_name, name) in SIGNATURE_OPTIONAL_FIELDS
if override_optional:
if ann is not None:
ann = ann | None
default = ast.Constant(None)
else:
default = ast.Constant(value=field.default)
if field.default is PydanticUndefined:
default = None
elif isinstance(field.default, dict | BaseModel):
default = ast.Constant(None)
else:
default = ast.Constant(value=field.default)
if ann is not None:
arg.annotation = self._format_annotation(ann)
return arg, default
Expand Down
16 changes: 12 additions & 4 deletions src/acp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,31 @@ async def initialize(

@param_model(NewSessionRequest)
async def new_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any
) -> NewSessionResponse:
resolved_mcp_servers = mcp_servers or []
return await request_model(
self._conn,
AGENT_METHODS["session_new"],
NewSessionRequest(cwd=cwd, mcp_servers=mcp_servers, field_meta=kwargs or None),
NewSessionRequest(cwd=cwd, mcp_servers=resolved_mcp_servers, field_meta=kwargs or None),
NewSessionResponse,
)

@param_model(LoadSessionRequest)
async def load_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
self,
cwd: str,
session_id: str,
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
**kwargs: Any,
) -> LoadSessionResponse:
resolved_mcp_servers = mcp_servers or []
return await request_model_from_dict(
self._conn,
AGENT_METHODS["session_load"],
LoadSessionRequest(cwd=cwd, mcp_servers=mcp_servers, session_id=session_id, field_meta=kwargs or None),
LoadSessionRequest(
cwd=cwd, mcp_servers=resolved_mcp_servers, session_id=session_id, field_meta=kwargs or None
),
LoadSessionResponse,
)

Expand Down
8 changes: 6 additions & 2 deletions src/acp/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,16 @@ async def initialize(

@param_model(NewSessionRequest)
async def new_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], **kwargs: Any
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None, **kwargs: Any
) -> NewSessionResponse: ...

@param_model(LoadSessionRequest)
async def load_session(
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio], session_id: str, **kwargs: Any
self,
cwd: str,
session_id: str,
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
**kwargs: Any,
) -> LoadSessionResponse | None: ...

@param_model(ListSessionsRequest)
Expand Down
68 changes: 68 additions & 0 deletions tests/real_user/test_mcp_servers_optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio
from typing import Any

import pytest

from acp import InitializeResponse, LoadSessionResponse, NewSessionResponse
from acp.core import AgentSideConnection, ClientSideConnection
from acp.schema import HttpMcpServer, McpServerStdio, SseMcpServer
from tests.conftest import TestAgent, TestClient


class McpOptionalAgent(TestAgent):
def __init__(self) -> None:
super().__init__()
self.seen_new_session: tuple[str, Any] | None = None
self.seen_load_session: tuple[str, str, Any] | None = None

async def new_session(
self,
cwd: str,
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
**kwargs: Any,
) -> NewSessionResponse:
resolved_mcp_servers = mcp_servers or []
self.seen_new_session = (cwd, resolved_mcp_servers)
return await super().new_session(cwd=cwd, mcp_servers=resolved_mcp_servers, **kwargs)

async def load_session(
self,
cwd: str,
session_id: str,
mcp_servers: list[HttpMcpServer | SseMcpServer | McpServerStdio] | None = None,
**kwargs: Any,
) -> LoadSessionResponse | None:
resolved_mcp_servers = mcp_servers or []
self.seen_load_session = (cwd, session_id, resolved_mcp_servers)
return await super().load_session(cwd=cwd, session_id=session_id, mcp_servers=resolved_mcp_servers, **kwargs)


@pytest.mark.asyncio
async def test_session_requests_default_empty_mcp_servers(server) -> None:
client = TestClient()
captured_agent: list[McpOptionalAgent] = []

agent_conn = ClientSideConnection(client, server._client_writer, server._client_reader) # type: ignore[arg-type]
_agent_side = AgentSideConnection(
lambda _conn: captured_agent.append(McpOptionalAgent()) or captured_agent[-1],
server._server_writer,
server._server_reader,
listening=True,
)

init = await asyncio.wait_for(agent_conn.initialize(protocol_version=1), timeout=1.0)
assert isinstance(init, InitializeResponse)

new_session = await asyncio.wait_for(agent_conn.new_session(cwd="/workspace"), timeout=1.0)
assert isinstance(new_session, NewSessionResponse)

load_session = await asyncio.wait_for(
agent_conn.load_session(cwd="/workspace", session_id=new_session.session_id),
timeout=1.0,
)
assert isinstance(load_session, LoadSessionResponse)

assert captured_agent, "Agent was not constructed"
[agent] = captured_agent
assert agent.seen_new_session == ("/workspace", [])
assert agent.seen_load_session == ("/workspace", new_session.session_id, [])