From baa9ca5ce36be0ed9d3305c81f83ac52f7545672 Mon Sep 17 00:00:00 2001 From: Nicolas Mota Date: Tue, 5 May 2026 13:19:05 -0300 Subject: [PATCH 1/3] feat(sessions): add get_user_state to BaseSessionService Adds a public get_user_state(app_name, user_id) method so callers can read user-scoped state directly without requiring an active session_id. This closes the API gap that forced callers to either (a) call list_sessions just to access user state or (b) maintain a separate process-level cache as a workaround. Implementations: - InMemorySessionService: reads from self.user_state dict - DatabaseSessionService: queries StorageUserState by (app_name, user_id) - SqliteSessionService: delegates to existing _get_user_state helper - VertexAiSessionService: raises NotImplementedError (Vertex AI Agent Engine API does not expose user state independently of a session) - BaseSessionService: default raises NotImplementedError so existing custom subclasses continue to compile Returns raw keys without the user: prefix, consistent with how user state is stored internally (the prefix is applied by the state-merging layer when building the merged session view). Co-authored-by: Cursor --- .../adk/sessions/base_session_service.py | 38 +++++ .../adk/sessions/database_session_service.py | 16 ++ .../adk/sessions/in_memory_session_service.py | 8 + .../adk/sessions/sqlite_session_service.py | 7 + .../adk/sessions/vertex_ai_session_service.py | 13 ++ .../unittests/sessions/test_get_user_state.py | 159 ++++++++++++++++++ 6 files changed, 241 insertions(+) create mode 100644 tests/unittests/sessions/test_get_user_state.py diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index af94bb9eeb..f9cf3d06f0 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -111,6 +111,44 @@ async def delete_session( ) -> None: """Deletes a session.""" + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Returns the user-scoped state for the given app and user. + + User state is keyed by ``(app_name, user_id)`` and shared across all + sessions of the same user within the same app. The returned dictionary + uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather + than ``"user:my_key"``). + + This method exists so that callers can read user state without holding an + active ``session_id``. A common use case is bootstrapping context at the + start of a new session before calling ``create_session``, which would + otherwise require an expensive ``list_sessions`` call just to access + user-scoped data. + + Returns an empty dict when no user state has been stored for this + ``(app_name, user_id)`` combination. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + + Returns: + A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an + empty dict when no user state exists. + + Raises: + NotImplementedError: When the concrete ``BaseSessionService`` + implementation does not support reading user state independently of a + session. Callers should catch this and fall back to + ``list_sessions`` or accept that user state is unavailable. + """ + raise NotImplementedError( + f'{type(self).__name__} does not support get_user_state. ' + 'Use list_sessions to retrieve user state indirectly.' + ) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..a07b9db7f7 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -616,6 +616,22 @@ async def delete_session( await sql_session.execute(stmt) await sql_session.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + await self._prepare_tables() + schema = self._get_schema_classes() + async with self._rollback_on_exception_session( + read_only=True + ) as sql_session: + storage_user_state = await sql_session.get( + schema.StorageUserState, (app_name, user_id) + ) + if storage_user_state is None: + return {} + return dict(storage_user_state.state) + @override async def append_event(self, session: Session, event: Event) -> Event: await self._prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b8f6cfab46..3089fb377e 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -312,6 +312,14 @@ def _delete_session_impl( self.sessions[app_name][user_id].pop(session_id) + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + return dict( + self.user_state.get(app_name, {}).get(user_id, {}) + ) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 427bc3e73e..33f02a17bb 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -359,6 +359,13 @@ async def delete_session( ) await db.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + async with self._get_db_connection() as db: + return await self._get_user_state(db, app_name, user_id) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 8c1fdc134e..82dc2f3791 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -269,6 +269,19 @@ async def delete_session( logger.error('Error deleting session %s: %s', session_id, e) raise + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + # The Vertex AI Agent Engine API does not expose user state independently + # of a session. Callers that require cross-session user state should use + # list_sessions to retrieve it indirectly. + raise NotImplementedError( + 'VertexAiSessionService does not support get_user_state. ' + 'The Vertex AI Agent Engine API does not expose user state ' + 'independently of a session. Use list_sessions instead.' + ) + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/sessions/test_get_user_state.py b/tests/unittests/sessions/test_get_user_state.py new file mode 100644 index 0000000000..b65a9cdeae --- /dev/null +++ b/tests/unittests/sessions/test_get_user_state.py @@ -0,0 +1,159 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BaseSessionService.get_user_state across concrete implementations.""" + +import enum + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.sqlite_session_service import SqliteSessionService +import pytest + + +_APP = 'test-app' +_USER = 'user-42' +_OTHER_USER = 'user-99' + + +class SessionServiceType(enum.Enum): + IN_MEMORY = 'IN_MEMORY' + DATABASE = 'DATABASE' + SQLITE = 'SQLITE' + + +def _make_service(service_type: SessionServiceType, tmp_path=None): + if service_type == SessionServiceType.DATABASE: + return DatabaseSessionService('sqlite+aiosqlite:///:memory:') + if service_type == SessionServiceType.SQLITE: + return SqliteSessionService(str(tmp_path / 'sqlite.db')) + return InMemorySessionService() + + +@pytest.fixture( + params=[ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ] +) +async def session_service(request, tmp_path): + """Provides a session service and closes database backends on teardown.""" + service = _make_service(request.param, tmp_path) + yield service + if isinstance(service, DatabaseSessionService): + await service.close() + + +@pytest.mark.asyncio +async def test_get_user_state_returns_empty_dict_when_no_state_exists( + session_service, +): + """Returns {} when (app_name, user_id) has never had state written.""" + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state == {} + + +@pytest.mark.asyncio +async def test_get_user_state_returns_state_written_via_append_event( + session_service, +): + """State written with the user: prefix is returned without the prefix.""" + session = await session_service.create_session( + app_name=_APP, user_id=_USER + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions( + state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1} + ), + ), + ) + + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + + assert state == {'profile': {'name': 'Alice'}} + assert 'session_key' not in state + + +@pytest.mark.asyncio +async def test_get_user_state_is_not_visible_across_users(session_service): + """User state is scoped to (app_name, user_id) — other users see {}.""" + session = await session_service.create_session( + app_name=_APP, user_id=_USER + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:secret': 'only-for-user-42'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name=_APP, user_id=_OTHER_USER + ) + assert other_state == {} + + +@pytest.mark.asyncio +async def test_get_user_state_available_before_session_is_created( + session_service, +): + """Core use case: user state is readable without an active session_id.""" + # Write state via a first session. + first_session = await session_service.create_session( + app_name=_APP, user_id=_USER + ) + await session_service.append_event( + first_session, + Event( + author='system', + actions=EventActions(state_delta={'user:ctx': {'v': 1}}), + ), + ) + + # Simulate a brand-new session_id (not yet created) — get_user_state must + # still return the persisted user state. + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state == {'ctx': {'v': 1}} + + +@pytest.mark.asyncio +async def test_get_user_state_reflects_latest_write(session_service): + """Subsequent writes overwrite earlier values under the same key.""" + session = await session_service.create_session( + app_name=_APP, user_id=_USER + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 1}), + ), + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 2}), + ), + ) + + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state['counter'] == 2 From 4825af3b542d1bdcf4505fa54b357566c916ebcb Mon Sep 17 00:00:00 2001 From: Nicolas Mota Date: Tue, 5 May 2026 13:23:08 -0300 Subject: [PATCH 2/3] style: apply pre-commit formatting (isort + pyink) Co-authored-by: Cursor --- .../adk/sessions/in_memory_session_service.py | 4 +--- tests/unittests/sessions/test_get_user_state.py | 13 +++---------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 3089fb377e..73a54f398b 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -316,9 +316,7 @@ def _delete_session_impl( async def get_user_state( self, *, app_name: str, user_id: str ) -> dict[str, Any]: - return dict( - self.user_state.get(app_name, {}).get(user_id, {}) - ) + return dict(self.user_state.get(app_name, {}).get(user_id, {})) @override async def append_event(self, session: Session, event: Event) -> Event: diff --git a/tests/unittests/sessions/test_get_user_state.py b/tests/unittests/sessions/test_get_user_state.py index b65a9cdeae..33ce72bbc2 100644 --- a/tests/unittests/sessions/test_get_user_state.py +++ b/tests/unittests/sessions/test_get_user_state.py @@ -23,7 +23,6 @@ from google.adk.sessions.sqlite_session_service import SqliteSessionService import pytest - _APP = 'test-app' _USER = 'user-42' _OTHER_USER = 'user-99' @@ -72,9 +71,7 @@ async def test_get_user_state_returns_state_written_via_append_event( session_service, ): """State written with the user: prefix is returned without the prefix.""" - session = await session_service.create_session( - app_name=_APP, user_id=_USER - ) + session = await session_service.create_session(app_name=_APP, user_id=_USER) await session_service.append_event( session, Event( @@ -94,9 +91,7 @@ async def test_get_user_state_returns_state_written_via_append_event( @pytest.mark.asyncio async def test_get_user_state_is_not_visible_across_users(session_service): """User state is scoped to (app_name, user_id) — other users see {}.""" - session = await session_service.create_session( - app_name=_APP, user_id=_USER - ) + session = await session_service.create_session(app_name=_APP, user_id=_USER) await session_service.append_event( session, Event( @@ -137,9 +132,7 @@ async def test_get_user_state_available_before_session_is_created( @pytest.mark.asyncio async def test_get_user_state_reflects_latest_write(session_service): """Subsequent writes overwrite earlier values under the same key.""" - session = await session_service.create_session( - app_name=_APP, user_id=_USER - ) + session = await session_service.create_session(app_name=_APP, user_id=_USER) await session_service.append_event( session, Event( From 662da6a2a5204c030809dac855c417c546f3f5b8 Mon Sep 17 00:00:00 2001 From: Nicolas Mota Date: Tue, 5 May 2026 14:08:57 -0300 Subject: [PATCH 3/3] fix(sessions): address code-review findings on get_user_state - base_session_service: fix NotImplementedError message that incorrectly suggested list_sessions returns user state; now describes the correct workaround (list_sessions + get_session) - database_session_service: defensive dict(state or {}) guard against NULL state column values from legacy migrations - vertex_ai_session_service: replace inline comment with proper docstring explaining why the method is not supported - tests: add app-level isolation test, add VertexAiSessionService NotImplementedError test, remove redundant @pytest.mark.asyncio markers (asyncio_mode=auto), add return type annotation to _make_service Co-authored-by: Cursor --- .../adk/sessions/base_session_service.py | 8 +- .../adk/sessions/database_session_service.py | 2 +- .../adk/sessions/vertex_ai_session_service.py | 16 +- ...abase_session_service_tool_confirmation.py | 171 +++++ ...abase_session_service_tool_confirmation.py | 619 ++++++++++++++++++ .../unittests/sessions/test_get_user_state.py | 37 +- 6 files changed, 838 insertions(+), 15 deletions(-) create mode 100644 tests/integration/test_database_session_service_tool_confirmation.py create mode 100644 tests/unittests/sessions/test_database_session_service_tool_confirmation.py diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f9cf3d06f0..f06374a523 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -141,12 +141,14 @@ async def get_user_state( Raises: NotImplementedError: When the concrete ``BaseSessionService`` implementation does not support reading user state independently of a - session. Callers should catch this and fall back to - ``list_sessions`` or accept that user state is unavailable. + session. Callers should catch this, then enumerate sessions via + ``list_sessions`` and call ``get_session`` on each result to access + the merged state, or accept that user state is unavailable. """ raise NotImplementedError( f'{type(self).__name__} does not support get_user_state. ' - 'Use list_sessions to retrieve user state indirectly.' + 'To read user state, enumerate sessions via list_sessions and ' + 'call get_session on each result to access the merged state.' ) async def append_event(self, session: Session, event: Event) -> Event: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a07b9db7f7..521a9b5d8b 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -630,7 +630,7 @@ async def get_user_state( ) if storage_user_state is None: return {} - return dict(storage_user_state.state) + return dict(storage_user_state.state or {}) @override async def append_event(self, session: Session, event: Event) -> Event: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 82dc2f3791..a4eca9b5c9 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -273,13 +273,21 @@ async def delete_session( async def get_user_state( self, *, app_name: str, user_id: str ) -> dict[str, Any]: - # The Vertex AI Agent Engine API does not expose user state independently - # of a session. Callers that require cross-session user state should use - # list_sessions to retrieve it indirectly. + """Not supported by the Vertex AI Agent Engine backend. + + The Vertex AI Agent Engine API does not expose user state independently of + a session. To read user state, enumerate sessions via ``list_sessions`` + and call ``get_session`` on each result to access the merged state. + + Raises: + NotImplementedError: Always, because the Vertex AI Agent Engine API does + not provide a way to query user state without a session. + """ raise NotImplementedError( 'VertexAiSessionService does not support get_user_state. ' 'The Vertex AI Agent Engine API does not expose user state ' - 'independently of a session. Use list_sessions instead.' + 'independently of a session. To read user state, enumerate sessions ' + 'via list_sessions and call get_session on each result.' ) @override diff --git a/tests/integration/test_database_session_service_tool_confirmation.py b/tests/integration/test_database_session_service_tool_confirmation.py new file mode 100644 index 0000000000..f1290f7b34 --- /dev/null +++ b/tests/integration/test_database_session_service_tool_confirmation.py @@ -0,0 +1,171 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for DatabaseSessionService with tool confirmation support.""" + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.tools.tool_confirmation import ToolConfirmation +from google.genai import types +import pytest + + +@pytest.mark.asyncio +@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True) +async def test_database_session_service_save_retrieve_tool_confirmation(): + """Test that DatabaseSessionService correctly saves and retrieves events with requested_tool_confirmations.""" + async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service: + # Create a session + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create an event with requested_tool_confirmations + tool_confirmation = ToolConfirmation( + confirmed=True, hint='Approve this', payload={'amount': 100} + ) + event_actions = EventActions( + requested_tool_confirmations={'func_call_456': tool_confirmation} + ) + + event = Event( + id='event_2', + invocation_id='inv_2', + author='agent', + content=types.Content( + role='model', + parts=[types.Part.from_text(text='Test message')], + ), + actions=event_actions, + ) + + # Append event to session + await session_service.append_event(session, event) + + # Retrieve the session + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + + # Verify the event was saved and retrieved correctly + assert len(retrieved_session.events) == 1 + retrieved_event = retrieved_session.events[0] + + # Verify EventActions structure + assert isinstance(retrieved_event.actions, EventActions) + assert 'func_call_456' in retrieved_event.actions.requested_tool_confirmations + + # Verify ToolConfirmation object was correctly reconstructed + retrieved_confirmation = ( + retrieved_event.actions.requested_tool_confirmations['func_call_456'] + ) + assert isinstance(retrieved_confirmation, ToolConfirmation) + assert retrieved_confirmation.confirmed is True + assert retrieved_confirmation.hint == 'Approve this' + assert retrieved_confirmation.payload == {'amount': 100} + + +@pytest.mark.asyncio +@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True) +async def test_database_session_service_multiple_tool_confirmations(): + """Test that DatabaseSessionService handles multiple tool confirmations in a single event.""" + async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service: + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create an event with multiple requested_tool_confirmations + tool_confirmation_1 = ToolConfirmation( + confirmed=False, hint='First action', payload={'step': 1} + ) + tool_confirmation_2 = ToolConfirmation( + confirmed=False, hint='Second action', payload={'step': 2} + ) + event_actions = EventActions( + requested_tool_confirmations={ + 'func_call_1': tool_confirmation_1, + 'func_call_2': tool_confirmation_2, + } + ) + + event = Event( + id='event_3', + invocation_id='inv_3', + author='agent', + content=types.Content( + role='model', + parts=[types.Part.from_text(text='Multiple confirmations')], + ), + actions=event_actions, + ) + + await session_service.append_event(session, event) + + # Retrieve and verify + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + assert len(retrieved_session.events) == 1 + retrieved_event = retrieved_session.events[0] + + assert isinstance(retrieved_event.actions, EventActions) + assert len(retrieved_event.actions.requested_tool_confirmations) == 2 + + # Verify both confirmations + conf_1 = retrieved_event.actions.requested_tool_confirmations['func_call_1'] + assert isinstance(conf_1, ToolConfirmation) + assert conf_1.confirmed is False + assert conf_1.hint == 'First action' + assert conf_1.payload == {'step': 1} + + conf_2 = retrieved_event.actions.requested_tool_confirmations['func_call_2'] + assert isinstance(conf_2, ToolConfirmation) + assert conf_2.confirmed is False + assert conf_2.hint == 'Second action' + assert conf_2.payload == {'step': 2} + + +@pytest.mark.asyncio +@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True) +async def test_database_session_service_empty_tool_confirmations(): + """Test that DatabaseSessionService handles events without tool confirmations.""" + async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service: + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create an event without requested_tool_confirmations + event = Event( + id='event_4', + invocation_id='inv_4', + author='agent', + content=types.Content( + role='model', + parts=[types.Part.from_text(text='No confirmations')], + ), + actions=EventActions(), + ) + + await session_service.append_event(session, event) + + # Retrieve and verify + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + assert len(retrieved_session.events) == 1 + retrieved_event = retrieved_session.events[0] + + assert isinstance(retrieved_event.actions, EventActions) + assert not retrieved_event.actions.requested_tool_confirmations diff --git a/tests/unittests/sessions/test_database_session_service_tool_confirmation.py b/tests/unittests/sessions/test_database_session_service_tool_confirmation.py new file mode 100644 index 0000000000..7e7763fe84 --- /dev/null +++ b/tests/unittests/sessions/test_database_session_service_tool_confirmation.py @@ -0,0 +1,619 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DatabaseSessionService with tool confirmation support.""" + +import copy +from unittest import mock + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.schemas.v0 import StorageEvent +from google.adk.sessions.session import Session +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.tool_confirmation import ToolConfirmation +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from google.genai.types import FunctionCall +from google.genai.types import FunctionResponse +from google.genai.types import GenerateContentResponse +from google.genai.types import Part +import pytest + +from .. import testing_utils + + +def get_database_urls(): + """Returns a list of database URLs to test with. + + For unit tests, we only use SQLite in-memory database to keep tests + isolated and fast. Integration tests with real databases should be + in a separate test file. + """ + return [ + ('sqlite+aiosqlite:///:memory:', 'sqlite'), + ] + + +@pytest.mark.asyncio +async def test_storage_event_serialize_deserialize_tool_confirmation(): + """Test that StorageEvent correctly serializes and deserializes events with requested_tool_confirmations.""" + # Create a session for testing + session = Session( + app_name='test_app', + user_id='test_user', + id='test_session', + state={}, + events=[], + last_update_time=0.0, + ) + + # Create an event with requested_tool_confirmations + tool_confirmation = ToolConfirmation( + confirmed=False, hint='Please approve this action', payload={'key': 'value'} + ) + event_actions = EventActions( + requested_tool_confirmations={'function_call_123': tool_confirmation} + ) + + event = Event( + id='event_1', + invocation_id='inv_1', + author='agent', + content=types.Content( + role='model', + parts=[types.Part.from_text(text='Test message')], + ), + actions=event_actions, + ) + + # Test serialization: Event -> StorageEvent + storage_event = StorageEvent.from_event(session, event) + + # Verify that actions was serialized as a dict + assert isinstance(storage_event.actions, dict) + assert 'requested_tool_confirmations' in storage_event.actions + assert 'function_call_123' in storage_event.actions['requested_tool_confirmations'] + + # Verify the tool confirmation data is preserved + stored_confirmation = storage_event.actions['requested_tool_confirmations'][ + 'function_call_123' + ] + assert isinstance(stored_confirmation, dict) + assert stored_confirmation['confirmed'] is False + assert stored_confirmation['hint'] == 'Please approve this action' + assert stored_confirmation['payload'] == {'key': 'value'} + + # Test deserialization: StorageEvent -> Event + deserialized_event = storage_event.to_event() + + # Verify the event structure + assert deserialized_event.id == event.id + assert deserialized_event.invocation_id == event.invocation_id + assert deserialized_event.author == event.author + + # Verify EventActions was correctly reconstructed + assert isinstance(deserialized_event.actions, EventActions) + assert 'function_call_123' in deserialized_event.actions.requested_tool_confirmations + + # Verify ToolConfirmation objects were correctly reconstructed + deserialized_confirmation = ( + deserialized_event.actions.requested_tool_confirmations['function_call_123'] + ) + assert isinstance(deserialized_confirmation, ToolConfirmation) + assert deserialized_confirmation.confirmed is False + assert deserialized_confirmation.hint == 'Please approve this action' + assert deserialized_confirmation.payload == {'key': 'value'} + + +@pytest.mark.asyncio +async def test_pickle_limitation_with_tool_confirmation(): + """Test that demonstrates the pickle serialization limitation with ToolConfirmation. + + This test simulates the old behavior where EventActions with ToolConfirmation + objects were pickled directly without using model_dump(). The problem is that + when pickling Pydantic models nested inside EventActions, the deserialized + objects may not be properly reconstructed as Pydantic models, losing validation + and type safety. + + The current implementation fixes this by: + 1. Using model_dump(mode='python') before pickling (converts to dict) + 2. Using EventActions.model_validate() after unpickling (reconstructs Pydantic models) + """ + import pickle + + # Create an event with requested_tool_confirmations + tool_confirmation = ToolConfirmation( + confirmed=False, hint='Test hint', payload={'test': 'data'} + ) + event_actions = EventActions( + requested_tool_confirmations={'func_call_1': tool_confirmation} + ) + + # Simulate the OLD behavior: pickle EventActions directly (without model_dump) + # This is what would happen in the old implementation + pickled_data = pickle.dumps(event_actions) + unpickled_actions = pickle.loads(pickled_data) + + # Verify that the unpickled object is still an EventActions instance + assert isinstance(unpickled_actions, EventActions) + assert 'func_call_1' in unpickled_actions.requested_tool_confirmations + + # THE PROBLEM: The ToolConfirmation object may not be properly reconstructed + # In some cases, pickle might deserialize it, but it may lose Pydantic validation + unpickled_confirmation = ( + unpickled_actions.requested_tool_confirmations['func_call_1'] + ) + + # This might work in some cases, but the object may not be a proper ToolConfirmation + # instance with full Pydantic validation + # The issue is that pickle doesn't guarantee proper Pydantic model reconstruction + if isinstance(unpickled_confirmation, ToolConfirmation): + # If it works, verify the data + assert unpickled_confirmation.confirmed is False + assert unpickled_confirmation.hint == 'Test hint' + assert unpickled_confirmation.payload == {'test': 'data'} + else: + # This demonstrates the problem: the object might not be a ToolConfirmation + # It could be a dict or a generic object without Pydantic validation + pytest.fail( + 'ToolConfirmation was not properly reconstructed after pickle/unpickle. ' + 'This demonstrates the limitation that required the fix.' + ) + + # NOW demonstrate the CORRECT approach (current implementation): + # 1. Convert to dict using model_dump before pickling + actions_dict = event_actions.model_dump(mode='python') + pickled_dict = pickle.dumps(actions_dict) + unpickled_dict = pickle.loads(pickled_dict) + + # 2. Reconstruct using model_validate (ensures proper Pydantic model creation) + reconstructed_actions = EventActions.model_validate(unpickled_dict) + + # Verify proper reconstruction + assert isinstance(reconstructed_actions, EventActions) + assert 'func_call_1' in reconstructed_actions.requested_tool_confirmations + + reconstructed_confirmation = ( + reconstructed_actions.requested_tool_confirmations['func_call_1'] + ) + + # This ALWAYS works because we use model_validate + assert isinstance(reconstructed_confirmation, ToolConfirmation) + assert reconstructed_confirmation.confirmed is False + assert reconstructed_confirmation.hint == 'Test hint' + assert reconstructed_confirmation.payload == {'test': 'data'} + + # Verify Pydantic validation still works (this might fail with direct pickle) + # Try to create a new ToolConfirmation with invalid data to test validation + try: + # This should work because it's a proper Pydantic model + invalid_confirmation = ToolConfirmation( + confirmed='not a boolean', hint='test' # type: ignore + ) + pytest.fail('Pydantic validation should have failed') + except Exception: + # Expected: Pydantic validation should catch the type error + pass + + +@pytest.mark.parametrize('db_url,db_name', get_database_urls()) +@pytest.mark.asyncio +async def test_database_session_service_save_retrieve_tool_confirmation( + db_url, db_name +): + """Test that DatabaseSessionService correctly saves and retrieves events with requested_tool_confirmations.""" + async with DatabaseSessionService(db_url) as session_service: + # Create a session + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create an event with requested_tool_confirmations + tool_confirmation = ToolConfirmation( + confirmed=True, hint='Approve this', payload={'amount': 100} + ) + event_actions = EventActions( + requested_tool_confirmations={'func_call_456': tool_confirmation} + ) + + event = Event( + id='event_2', + invocation_id='inv_2', + author='agent', + content=types.Content( + role='model', parts=[types.Part.from_text(text='Requesting confirmation')] + ), + actions=event_actions, + ) + + # Append the event to the session + appended_event = await session_service.append_event(session, event) + assert appended_event.id == event.id + + # Retrieve the session + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + + assert retrieved_session is not None + assert len(retrieved_session.events) == 1 + + # Verify the retrieved event + retrieved_event = retrieved_session.events[0] + assert retrieved_event.id == event.id + assert isinstance(retrieved_event.actions, EventActions) + + # Verify requested_tool_confirmations was preserved + assert 'func_call_456' in retrieved_event.actions.requested_tool_confirmations + retrieved_confirmation = ( + retrieved_event.actions.requested_tool_confirmations['func_call_456'] + ) + assert isinstance(retrieved_confirmation, ToolConfirmation) + assert retrieved_confirmation.confirmed is True + assert retrieved_confirmation.hint == 'Approve this' + assert retrieved_confirmation.payload == {'amount': 100} + + +@pytest.mark.parametrize('db_url,db_name', get_database_urls()) +@pytest.mark.asyncio +async def test_database_session_service_multiple_tool_confirmations( + db_url, db_name +): + """Test that DatabaseSessionService handles multiple tool confirmations in one event.""" + async with DatabaseSessionService(db_url) as session_service: + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create event with multiple tool confirmations + tool_confirmation_1 = ToolConfirmation(confirmed=False, hint='First action') + tool_confirmation_2 = ToolConfirmation( + confirmed=True, hint='Second action', payload={'value': 42} + ) + + event_actions = EventActions( + requested_tool_confirmations={ + 'func_call_1': tool_confirmation_1, + 'func_call_2': tool_confirmation_2, + } + ) + + event = Event( + id='event_3', + invocation_id='inv_3', + author='agent', + content=types.Content( + role='model', parts=[types.Part.from_text(text='Multiple confirmations')] + ), + actions=event_actions, + ) + + await session_service.append_event(session, event) + + # Retrieve and verify + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + + retrieved_event = retrieved_session.events[0] + assert len(retrieved_event.actions.requested_tool_confirmations) == 2 + + conf1 = retrieved_event.actions.requested_tool_confirmations['func_call_1'] + assert isinstance(conf1, ToolConfirmation) + assert conf1.confirmed is False + assert conf1.hint == 'First action' + + conf2 = retrieved_event.actions.requested_tool_confirmations['func_call_2'] + assert isinstance(conf2, ToolConfirmation) + assert conf2.confirmed is True + assert conf2.hint == 'Second action' + assert conf2.payload == {'value': 42} + + +@pytest.mark.parametrize('db_url,db_name', get_database_urls()) +@pytest.mark.asyncio +async def test_database_session_service_empty_tool_confirmations( + db_url, db_name +): + """Test that DatabaseSessionService handles events without tool confirmations correctly.""" + async with DatabaseSessionService(db_url) as session_service: + session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + + # Create event without tool confirmations + event = Event( + id='event_4', + invocation_id='inv_4', + author='user', + content=types.Content( + role='user', parts=[types.Part.from_text(text='Regular message')] + ), + actions=EventActions(), + ) + + await session_service.append_event(session, event) + + # Retrieve and verify + retrieved_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session.id + ) + + retrieved_event = retrieved_session.events[0] + assert isinstance(retrieved_event.actions, EventActions) + assert len(retrieved_event.actions.requested_tool_confirmations) == 0 + + +def _test_function(tool_context: ToolContext) -> dict[str, str]: + """Test function that requires confirmation.""" + return {"result": f"confirmed={tool_context.tool_confirmation.confirmed}"} + + +def _create_llm_response_from_tools( + tools: list[FunctionTool], +) -> GenerateContentResponse: + """Creates a mock LLM response containing a function call.""" + parts = [ + Part(function_call=FunctionCall(name=tool.name, args={})) + for tool in tools + ] + return testing_utils.LlmResponse( + content=testing_utils.ModelContent(parts=parts) + ) + + +def _create_llm_response_from_text(text: str) -> GenerateContentResponse: + """Creates a mock LLM response containing text.""" + return testing_utils.LlmResponse( + content=testing_utils.ModelContent(parts=[Part(text=text)]) + ) + + +HINT_TEXT = ( + "Please approve or reject the tool call _test_function() by" + " responding with a FunctionResponse with an" + " expected ToolConfirmation payload." +) + + +@pytest.mark.parametrize('db_url,db_name', get_database_urls()) +@pytest.mark.asyncio +async def test_tool_confirmation_flow_with_database_session_service( + db_url, db_name +): + """Test the complete tool confirmation flow using DatabaseSessionService.""" + async with DatabaseSessionService(db_url) as session_service: + # Create tools with confirmation requirement + tools = [FunctionTool(func=_test_function, require_confirmation=True)] + + # Create mock LLM responses + llm_responses = [ + _create_llm_response_from_tools(tools), + _create_llm_response_from_text("test llm response after tool call"), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + + # Create agent + agent = LlmAgent(name="test_agent", model=mock_model, tools=tools) + + # Create runner with DatabaseSessionService + runner = Runner( + app_name='test_app', + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=session_service, + memory_service=InMemoryMemoryService(), + ) + + # Create a session first + test_session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + session_id = test_session.id + + # First invocation: user query triggers tool call that requires confirmation + user_query = testing_utils.UserContent("test user query") + events = [] + async for event in runner.run_async( + user_id='test_user', session_id=session_id, new_message=user_query + ): + events.append(event) + + # Verify that confirmation was requested + assert len(events) >= 3 + # Find the request confirmation event + request_confirmation_event = None + for event in events: + if ( + event.content + and event.content.parts + and event.content.parts[0].function_call + and event.content.parts[0].function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + request_confirmation_event = event + break + + assert request_confirmation_event is not None + ask_for_confirmation_function_call_id = ( + request_confirmation_event.content.parts[0].function_call.id + ) + invocation_id = request_confirmation_event.invocation_id + + # Get the session to verify events were persisted (using the same session_id) + session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session_id + ) + assert session is not None + assert len(session.events) > 0 + + # Verify that the request confirmation event was persisted correctly + persisted_confirmation_event = None + for event in session.events: + if ( + event.content + and event.content.parts + and event.content.parts[0].function_call + and event.content.parts[0].function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + persisted_confirmation_event = event + break + + assert persisted_confirmation_event is not None + assert ( + persisted_confirmation_event.content.parts[0].function_call.id + == ask_for_confirmation_function_call_id + ) + + # Second invocation: user provides confirmation + user_confirmation = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=ask_for_confirmation_function_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": True}, + ) + ) + ) + + # Run with the confirmation + final_events = [] + async for event in runner.run_async( + user_id='test_user', session_id=session_id, new_message=user_confirmation + ): + final_events.append(event) + + # Verify the tool was executed after confirmation + assert len(final_events) > 0 + tool_response_found = False + for event in final_events: + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name == tools[0].name + ): + tool_response_found = True + assert event.content.parts[0].function_response.response == { + "result": "confirmed=True" + } + break + + assert tool_response_found, "Tool response not found in final events" + + # Verify all events were persisted + final_session = await session_service.get_session( + app_name='test_app', user_id='test_user', session_id=session_id + ) + assert final_session is not None + assert len(final_session.events) > len(session.events) + + +@pytest.mark.parametrize('db_url,db_name', get_database_urls()) +@pytest.mark.asyncio +async def test_tool_confirmation_rejected_with_database_session_service( + db_url, db_name +): + """Test tool confirmation rejection flow using DatabaseSessionService.""" + async with DatabaseSessionService(db_url) as session_service: + tools = [FunctionTool(func=_test_function, require_confirmation=True)] + llm_responses = [ + _create_llm_response_from_tools(tools), + _create_llm_response_from_text("test response"), + ] + mock_model = testing_utils.MockModel(responses=llm_responses) + agent = LlmAgent(name="test_agent", model=mock_model, tools=tools) + + runner = Runner( + app_name='test_app', + agent=agent, + artifact_service=InMemoryArtifactService(), + session_service=session_service, + memory_service=InMemoryMemoryService(), + ) + + # Create a session first + test_session = await session_service.create_session( + app_name='test_app', user_id='test_user', state={} + ) + session_id = test_session.id + + # First invocation + user_query = testing_utils.UserContent("test query") + events = [] + async for event in runner.run_async( + user_id='test_user', session_id=session_id, new_message=user_query + ): + events.append(event) + + # Find confirmation request + request_confirmation_event = None + for event in events: + if ( + event.content + and event.content.parts + and event.content.parts[0].function_call + and event.content.parts[0].function_call.name + == REQUEST_CONFIRMATION_FUNCTION_CALL_NAME + ): + request_confirmation_event = event + break + + assert request_confirmation_event is not None + confirmation_call_id = ( + request_confirmation_event.content.parts[0].function_call.id + ) + + # User rejects the confirmation + user_rejection = testing_utils.UserContent( + Part( + function_response=FunctionResponse( + id=confirmation_call_id, + name=REQUEST_CONFIRMATION_FUNCTION_CALL_NAME, + response={"confirmed": False}, + ) + ) + ) + + final_events = [] + async for event in runner.run_async( + user_id='test_user', session_id=session_id, new_message=user_rejection + ): + final_events.append(event) + + # Verify rejection was handled + rejection_found = False + for event in final_events: + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name == tools[0].name + ): + rejection_found = True + assert event.content.parts[0].function_response.response == { + "error": "This tool call is rejected." + } + break + + assert rejection_found, "Rejection response not found" + diff --git a/tests/unittests/sessions/test_get_user_state.py b/tests/unittests/sessions/test_get_user_state.py index 33ce72bbc2..f5944668ca 100644 --- a/tests/unittests/sessions/test_get_user_state.py +++ b/tests/unittests/sessions/test_get_user_state.py @@ -18,12 +18,15 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import BaseSessionService from google.adk.sessions.database_session_service import DatabaseSessionService from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.sqlite_session_service import SqliteSessionService +from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService import pytest _APP = 'test-app' +_OTHER_APP = 'other-app' _USER = 'user-42' _OTHER_USER = 'user-99' @@ -34,7 +37,9 @@ class SessionServiceType(enum.Enum): SQLITE = 'SQLITE' -def _make_service(service_type: SessionServiceType, tmp_path=None): +def _make_service( + service_type: SessionServiceType, tmp_path=None +) -> BaseSessionService: if service_type == SessionServiceType.DATABASE: return DatabaseSessionService('sqlite+aiosqlite:///:memory:') if service_type == SessionServiceType.SQLITE: @@ -57,7 +62,6 @@ async def session_service(request, tmp_path): await service.close() -@pytest.mark.asyncio async def test_get_user_state_returns_empty_dict_when_no_state_exists( session_service, ): @@ -66,7 +70,6 @@ async def test_get_user_state_returns_empty_dict_when_no_state_exists( assert state == {} -@pytest.mark.asyncio async def test_get_user_state_returns_state_written_via_append_event( session_service, ): @@ -88,7 +91,6 @@ async def test_get_user_state_returns_state_written_via_append_event( assert 'session_key' not in state -@pytest.mark.asyncio async def test_get_user_state_is_not_visible_across_users(session_service): """User state is scoped to (app_name, user_id) — other users see {}.""" session = await session_service.create_session(app_name=_APP, user_id=_USER) @@ -106,12 +108,27 @@ async def test_get_user_state_is_not_visible_across_users(session_service): assert other_state == {} -@pytest.mark.asyncio +async def test_get_user_state_is_not_visible_across_apps(session_service): + """User state is scoped to (app_name, user_id) — other apps see {}.""" + session = await session_service.create_session(app_name=_APP, user_id=_USER) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:data': 'only-app-a'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name=_OTHER_APP, user_id=_USER + ) + assert other_state == {} + + async def test_get_user_state_available_before_session_is_created( session_service, ): """Core use case: user state is readable without an active session_id.""" - # Write state via a first session. first_session = await session_service.create_session( app_name=_APP, user_id=_USER ) @@ -129,7 +146,6 @@ async def test_get_user_state_available_before_session_is_created( assert state == {'ctx': {'v': 1}} -@pytest.mark.asyncio async def test_get_user_state_reflects_latest_write(session_service): """Subsequent writes overwrite earlier values under the same key.""" session = await session_service.create_session(app_name=_APP, user_id=_USER) @@ -150,3 +166,10 @@ async def test_get_user_state_reflects_latest_write(session_service): state = await session_service.get_user_state(app_name=_APP, user_id=_USER) assert state['counter'] == 2 + + +async def test_vertex_ai_session_service_raises_not_implemented(): + """VertexAiSessionService raises NotImplementedError for get_user_state.""" + service = VertexAiSessionService(project='proj', location='us-central1') + with pytest.raises(NotImplementedError): + await service.get_user_state(app_name=_APP, user_id=_USER)