diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index a9b55f526e..db0dde0b7a 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -38,6 +38,7 @@ from google.adk.runners import Runner from typing_extensions import override +from ...errors.session_not_found_error import SessionNotFoundError from ...utils.context_utils import Aclosing from ..agent.interceptors.new_integration_extension import _NEW_A2A_ADK_INTEGRATION_EXTENSION from ..converters.request_converter import AgentRunRequest @@ -75,12 +76,14 @@ def __init__( config: Optional[A2aAgentExecutorConfig] = None, use_legacy: bool = False, force_new_version: bool = False, + auto_create_session: bool = True, ): super().__init__() self._runner = runner self._config = config or A2aAgentExecutorConfig() self._use_legacy = use_legacy self._force_new_version = force_new_version + self._auto_create_session = auto_create_session self._executor_impl = None async def _resolve_runner(self) -> Runner: @@ -141,6 +144,7 @@ async def execute( self._executor_impl = ExecutorImpl( runner=self._runner, config=self._config, + auto_create_session=self._auto_create_session, ) await self._executor_impl.execute(context, event_queue) return @@ -328,28 +332,27 @@ async def _prepare_session( run_request: AgentRunRequest, runner: Runner, ): - - session_id = run_request.session_id - # create a new session if not exists - user_id = run_request.user_id session = await runner.session_service.get_session( app_name=runner.app_name, - user_id=user_id, - session_id=session_id, + user_id=run_request.user_id, + session_id=run_request.session_id, ) - if session is None: - session = await runner.session_service.create_session( - app_name=runner.app_name, - user_id=user_id, - state={}, - session_id=session_id, - ) - # Update run_request with the new session_id - run_request.session_id = session.id - + if not session: + if self._auto_create_session: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + ) + else: + raise SessionNotFoundError( + f'Session not found: {run_request.session_id}' + ) + # Update run_request with the new session_id + run_request.session_id = session.id return session - def _check_new_version_extension(self, context: RequestContext): + def _check_new_version_extension(self, context: RequestContext) -> bool: """Check if the extension for the new version is requested and activate it.""" if _NEW_A2A_ADK_INTEGRATION_EXTENSION in context.requested_extensions: context.add_activated_extension(_NEW_A2A_ADK_INTEGRATION_EXTENSION) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py index 320af124df..95ba93e1fa 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor_impl.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor_impl.py @@ -18,6 +18,7 @@ from datetime import timezone import inspect import logging +from typing import Any from typing import Awaitable from typing import Callable from typing import Optional @@ -37,6 +38,7 @@ from a2a.types import TextPart from typing_extensions import override +from ...errors.session_not_found_error import SessionNotFoundError from ...runners import Runner from ...sessions import base_session_service from ...utils.context_utils import Aclosing @@ -69,10 +71,12 @@ def __init__( *, runner: Runner | Callable[..., Runner | Awaitable[Runner]], config: Optional[A2aAgentExecutorConfig] = None, + auto_create_session: bool = True, ): super().__init__() self._runner = runner self._config = config or A2aAgentExecutorConfig() + self._auto_create_session = auto_create_session @override async def cancel(self, context: RequestContext, event_queue: EventQueue): @@ -281,29 +285,34 @@ async def _resolve_session( run_request: AgentRunRequest, runner: Runner, ): - session_id = run_request.session_id - # create a new session if not exists - user_id = run_request.user_id + if not run_request.user_id: + raise ValueError('user_id must be set in AgentRunRequest') + if not run_request.session_id: + raise ValueError('session_id must be set in AgentRunRequest') session = await runner.session_service.get_session( app_name=runner.app_name, - user_id=user_id, - session_id=session_id, + user_id=run_request.user_id, + session_id=run_request.session_id, # Checking existence doesn't require event history. config=base_session_service.GetSessionConfig(num_recent_events=0), ) - if session is None: - session = await runner.session_service.create_session( - app_name=runner.app_name, - user_id=user_id, - state={}, - session_id=session_id, - ) - # Update run_request with the new session_id - run_request.session_id = session.id + if not session: + if self._auto_create_session: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=run_request.user_id, + session_id=run_request.session_id, + ) + else: + raise SessionNotFoundError( + f'Session not found: {run_request.session_id}' + ) + # Update run_request with the new session_id + run_request.session_id = session.id def _get_invocation_metadata( self, executor_context: ExecutorContext - ) -> dict[str, str]: + ) -> dict[str, Any]: return { _get_adk_metadata_key('app_name'): executor_context.app_name, _get_adk_metadata_key('user_id'): executor_context.user_id, diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 3e8ed461e2..ddeb3129ce 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -14,6 +14,7 @@ from __future__ import annotations +from contextlib import AbstractAsyncContextManager from contextlib import asynccontextmanager import logging from typing import AsyncIterator @@ -85,7 +86,9 @@ def to_a2a( agent_card: Optional[Union[AgentCard, str]] = None, push_config_store: Optional[PushNotificationConfigStore] = None, runner: Optional[Runner] = None, - lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None, + lifespan: Optional[ + Callable[[Starlette], AbstractAsyncContextManager[None]] + ] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -170,7 +173,7 @@ async def create_runner() -> Runner: ) # Build the agent card and configure A2A routes - async def setup_a2a(app: Starlette): + async def setup_a2a(app: Starlette) -> None: # Use provided agent card or build one asynchronously if provided_agent_card is not None: final_agent_card = provided_agent_card diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index 0906e9a6ba..c95fa06d2e 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -41,7 +41,7 @@ # Strong references to fire-and-forget tasks to prevent garbage collection. # See https://docs.python.org/3/library/asyncio-task.html#creating-tasks -_background_tasks: set[asyncio.Task] = set() +_background_tasks: set[asyncio.Task[None]] = set() _GENERATE_MEMORIES_CONFIG_FALLBACK_KEYS = frozenset({ 'disable_consolidation', @@ -565,7 +565,7 @@ def _get_api_client(self) -> vertexai.AsyncClient: return vertexai.Client(project=self._project, location=self._location).aio -def _log_ingest_task_error(task: asyncio.Task) -> None: +def _log_ingest_task_error(task: asyncio.Task[None]) -> None: """Logs errors from fire-and-forget ingest_events tasks.""" if task.cancelled(): return diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4f44e1363c..3ced00b21f 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -224,18 +224,17 @@ async def mock_run_async(**kwargs): @pytest.mark.asyncio async def test_prepare_session_new_session(self): - """Test session preparation when session doesn't exist.""" + """Test session preparation creates a new session when none is found.""" run_args = AgentRunRequest( user_id="test-user", - session_id=None, + session_id="new-session-id", new_message=Mock(spec=Content), run_config=Mock(spec=RunConfig), ) - # Mock session service - self.mock_runner.session_service.get_session = AsyncMock(return_value=None) mock_session = Mock() mock_session.id = "new-session-id" + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) self.mock_runner.session_service.create_session = AsyncMock( return_value=mock_session ) @@ -245,14 +244,23 @@ async def test_prepare_session_new_session(self): self.mock_context, run_args, self.mock_runner ) - # Verify session was created + # Verify session was created and run_request updated assert result == mock_session - assert run_args.session_id is not None - self.mock_runner.session_service.create_session.assert_called_once() + assert run_args.session_id == "new-session-id" + self.mock_runner.session_service.get_session.assert_called_once_with( + app_name="test-app", + user_id="test-user", + session_id="new-session-id", + ) + self.mock_runner.session_service.create_session.assert_called_once_with( + app_name="test-app", + user_id="test-user", + session_id="new-session-id", + ) @pytest.mark.asyncio async def test_prepare_session_existing_session(self): - """Test session preparation when session exists.""" + """Test session preparation returns existing session without creating one.""" run_args = AgentRunRequest( user_id="test-user", session_id="existing-session", @@ -260,20 +268,25 @@ async def test_prepare_session_existing_session(self): run_config=Mock(spec=RunConfig), ) - # Mock session service mock_session = Mock() mock_session.id = "existing-session" self.mock_runner.session_service.get_session = AsyncMock( return_value=mock_session ) + self.mock_runner.session_service.create_session = AsyncMock() # Execute result = await self.executor._prepare_session( self.mock_context, run_args, self.mock_runner ) - # Verify existing session was returned + # Verify existing session was returned without creating a new one assert result == mock_session + self.mock_runner.session_service.get_session.assert_called_once_with( + app_name="test-app", + user_id="test-user", + session_id="existing-session", + ) self.mock_runner.session_service.create_session.assert_not_called() def test_constructor_with_callable_runner(self): diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py index 940b79a0b9..b0af5df7a2 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor_impl.py @@ -97,7 +97,7 @@ async def test_execute_success_new_task(self): new_message=Mock(spec=Content), run_config=Mock(spec=RunConfig), ) - # Mock session service + # Mock session lookup returning existing session mock_session = Mock() mock_session.id = "test-session" self.mock_runner.session_service.get_session = AsyncMock( @@ -200,7 +200,7 @@ async def test_execute_existing_task(self): run_config=Mock(spec=RunConfig), ) - # Mock session service + # Mock session lookup returning existing session mock_session = Mock() mock_session.id = "test-session" self.mock_runner.session_service.get_session = AsyncMock( @@ -638,18 +638,17 @@ async def test_execute_missing_user_input(self, mock_handle_user_input): @pytest.mark.asyncio async def test_resolve_session_creates_new_session(self): - """Test that _resolve_session creates a new session if it doesn't exist.""" - self.mock_runner.session_service.get_session = AsyncMock(return_value=None) - + """Test that _resolve_session creates a session when none is found.""" new_session = Mock() new_session.id = "new-session-id" + self.mock_runner.session_service.get_session = AsyncMock(return_value=None) self.mock_runner.session_service.create_session = AsyncMock( return_value=new_session ) run_request = AgentRunRequest( user_id="test-user", - session_id="old-session-id", + session_id="new-session-id", new_message=Mock(spec=Content), run_config=Mock(spec=RunConfig), ) @@ -657,16 +656,15 @@ async def test_resolve_session_creates_new_session(self): await self.executor._resolve_session(run_request, self.mock_runner) self.mock_runner.session_service.get_session.assert_called_once_with( - app_name=self.mock_runner.app_name, + app_name="test-app", user_id="test-user", - session_id="old-session-id", + session_id="new-session-id", config=GetSessionConfig(num_recent_events=0, after_timestamp=None), ) self.mock_runner.session_service.create_session.assert_called_once_with( - app_name=self.mock_runner.app_name, + app_name="test-app", user_id="test-user", - state={}, - session_id="old-session-id", + session_id="new-session-id", ) assert run_request.session_id == "new-session-id" diff --git a/tests/unittests/a2a/integration/server.py b/tests/unittests/a2a/integration/server.py index c965a71091..aa0c9d6314 100644 --- a/tests/unittests/a2a/integration/server.py +++ b/tests/unittests/a2a/integration/server.py @@ -44,6 +44,7 @@ def __init__(self, run_async_fn): app_name="FakeApp", agent=agent, session_service=session_service, + auto_create_session=True, ) self.run_async_fn = run_async_fn