diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index af94bb9eeb..f06374a523 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -111,6 +111,46 @@ async def delete_session( ) -> None: """Deletes a session.""" + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Returns the user-scoped state for the given app and user. + + User state is keyed by ``(app_name, user_id)`` and shared across all + sessions of the same user within the same app. The returned dictionary + uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather + than ``"user:my_key"``). + + This method exists so that callers can read user state without holding an + active ``session_id``. A common use case is bootstrapping context at the + start of a new session before calling ``create_session``, which would + otherwise require an expensive ``list_sessions`` call just to access + user-scoped data. + + Returns an empty dict when no user state has been stored for this + ``(app_name, user_id)`` combination. + + Args: + app_name: The name of the app. + user_id: The ID of the user. + + Returns: + A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an + empty dict when no user state exists. + + Raises: + NotImplementedError: When the concrete ``BaseSessionService`` + implementation does not support reading user state independently of a + session. Callers should catch this, then enumerate sessions via + ``list_sessions`` and call ``get_session`` on each result to access + the merged state, or accept that user state is unavailable. + """ + raise NotImplementedError( + f'{type(self).__name__} does not support get_user_state. ' + 'To read user state, enumerate sessions via list_sessions and ' + 'call get_session on each result to access the merged state.' + ) + async def append_event(self, session: Session, event: Event) -> Event: """Appends an event to a session object.""" if event.partial: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..521a9b5d8b 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -616,6 +616,22 @@ async def delete_session( await sql_session.execute(stmt) await sql_session.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + await self._prepare_tables() + schema = self._get_schema_classes() + async with self._rollback_on_exception_session( + read_only=True + ) as sql_session: + storage_user_state = await sql_session.get( + schema.StorageUserState, (app_name, user_id) + ) + if storage_user_state is None: + return {} + return dict(storage_user_state.state or {}) + @override async def append_event(self, session: Session, event: Event) -> Event: await self._prepare_tables() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index b8f6cfab46..73a54f398b 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -312,6 +312,12 @@ def _delete_session_impl( self.sessions[app_name][user_id].pop(session_id) + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + return dict(self.user_state.get(app_name, {}).get(user_id, {})) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 427bc3e73e..33f02a17bb 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -359,6 +359,13 @@ async def delete_session( ) await db.commit() + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + async with self._get_db_connection() as db: + return await self._get_user_state(db, app_name, user_id) + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 8c1fdc134e..a4eca9b5c9 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -269,6 +269,27 @@ async def delete_session( logger.error('Error deleting session %s: %s', session_id, e) raise + @override + async def get_user_state( + self, *, app_name: str, user_id: str + ) -> dict[str, Any]: + """Not supported by the Vertex AI Agent Engine backend. + + The Vertex AI Agent Engine API does not expose user state independently of + a session. To read user state, enumerate sessions via ``list_sessions`` + and call ``get_session`` on each result to access the merged state. + + Raises: + NotImplementedError: Always, because the Vertex AI Agent Engine API does + not provide a way to query user state without a session. + """ + raise NotImplementedError( + 'VertexAiSessionService does not support get_user_state. ' + 'The Vertex AI Agent Engine API does not expose user state ' + 'independently of a session. To read user state, enumerate sessions ' + 'via list_sessions and call get_session on each result.' + ) + @override async def append_event(self, session: Session, event: Event) -> Event: # Update the in-memory session. diff --git a/tests/unittests/sessions/test_get_user_state.py b/tests/unittests/sessions/test_get_user_state.py new file mode 100644 index 0000000000..f5944668ca --- /dev/null +++ b/tests/unittests/sessions/test_get_user_state.py @@ -0,0 +1,175 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BaseSessionService.get_user_state across concrete implementations.""" + +import enum + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.sqlite_session_service import SqliteSessionService +from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService +import pytest + +_APP = 'test-app' +_OTHER_APP = 'other-app' +_USER = 'user-42' +_OTHER_USER = 'user-99' + + +class SessionServiceType(enum.Enum): + IN_MEMORY = 'IN_MEMORY' + DATABASE = 'DATABASE' + SQLITE = 'SQLITE' + + +def _make_service( + service_type: SessionServiceType, tmp_path=None +) -> BaseSessionService: + if service_type == SessionServiceType.DATABASE: + return DatabaseSessionService('sqlite+aiosqlite:///:memory:') + if service_type == SessionServiceType.SQLITE: + return SqliteSessionService(str(tmp_path / 'sqlite.db')) + return InMemorySessionService() + + +@pytest.fixture( + params=[ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ] +) +async def session_service(request, tmp_path): + """Provides a session service and closes database backends on teardown.""" + service = _make_service(request.param, tmp_path) + yield service + if isinstance(service, DatabaseSessionService): + await service.close() + + +async def test_get_user_state_returns_empty_dict_when_no_state_exists( + session_service, +): + """Returns {} when (app_name, user_id) has never had state written.""" + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state == {} + + +async def test_get_user_state_returns_state_written_via_append_event( + session_service, +): + """State written with the user: prefix is returned without the prefix.""" + session = await session_service.create_session(app_name=_APP, user_id=_USER) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions( + state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1} + ), + ), + ) + + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + + assert state == {'profile': {'name': 'Alice'}} + assert 'session_key' not in state + + +async def test_get_user_state_is_not_visible_across_users(session_service): + """User state is scoped to (app_name, user_id) — other users see {}.""" + session = await session_service.create_session(app_name=_APP, user_id=_USER) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:secret': 'only-for-user-42'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name=_APP, user_id=_OTHER_USER + ) + assert other_state == {} + + +async def test_get_user_state_is_not_visible_across_apps(session_service): + """User state is scoped to (app_name, user_id) — other apps see {}.""" + session = await session_service.create_session(app_name=_APP, user_id=_USER) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:data': 'only-app-a'}), + ), + ) + + other_state = await session_service.get_user_state( + app_name=_OTHER_APP, user_id=_USER + ) + assert other_state == {} + + +async def test_get_user_state_available_before_session_is_created( + session_service, +): + """Core use case: user state is readable without an active session_id.""" + first_session = await session_service.create_session( + app_name=_APP, user_id=_USER + ) + await session_service.append_event( + first_session, + Event( + author='system', + actions=EventActions(state_delta={'user:ctx': {'v': 1}}), + ), + ) + + # Simulate a brand-new session_id (not yet created) — get_user_state must + # still return the persisted user state. + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state == {'ctx': {'v': 1}} + + +async def test_get_user_state_reflects_latest_write(session_service): + """Subsequent writes overwrite earlier values under the same key.""" + session = await session_service.create_session(app_name=_APP, user_id=_USER) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 1}), + ), + ) + await session_service.append_event( + session, + Event( + author='system', + actions=EventActions(state_delta={'user:counter': 2}), + ), + ) + + state = await session_service.get_user_state(app_name=_APP, user_id=_USER) + assert state['counter'] == 2 + + +async def test_vertex_ai_session_service_raises_not_implemented(): + """VertexAiSessionService raises NotImplementedError for get_user_state.""" + service = VertexAiSessionService(project='proj', location='us-central1') + with pytest.raises(NotImplementedError): + await service.get_user_state(app_name=_APP, user_id=_USER)