Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,46 @@ async def delete_session(
) -> None:
"""Deletes a session."""

async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
"""Returns the user-scoped state for the given app and user.

User state is keyed by ``(app_name, user_id)`` and shared across all
sessions of the same user within the same app. The returned dictionary
uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather
than ``"user:my_key"``).

This method exists so that callers can read user state without holding an
active ``session_id``. A common use case is bootstrapping context at the
start of a new session before calling ``create_session``, which would
otherwise require an expensive ``list_sessions`` call just to access
user-scoped data.

Returns an empty dict when no user state has been stored for this
``(app_name, user_id)`` combination.

Args:
app_name: The name of the app.
user_id: The ID of the user.

Returns:
A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an
empty dict when no user state exists.

Raises:
NotImplementedError: When the concrete ``BaseSessionService``
implementation does not support reading user state independently of a
session. Callers should catch this, then enumerate sessions via
``list_sessions`` and call ``get_session`` on each result to access
the merged state, or accept that user state is unavailable.
"""
raise NotImplementedError(
f'{type(self).__name__} does not support get_user_state. '
'To read user state, enumerate sessions via list_sessions and '
'call get_session on each result to access the merged state.'
)

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,22 @@ async def delete_session(
await sql_session.execute(stmt)
await sql_session.commit()

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
await self._prepare_tables()
schema = self._get_schema_classes()
async with self._rollback_on_exception_session(
read_only=True
) as sql_session:
storage_user_state = await sql_session.get(
schema.StorageUserState, (app_name, user_id)
)
if storage_user_state is None:
return {}
return dict(storage_user_state.state or {})

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._prepare_tables()
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def _delete_session_impl(

self.sessions[app_name][user_id].pop(session_id)

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
return dict(self.user_state.get(app_name, {}).get(user_id, {}))

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ async def delete_session(
)
await db.commit()

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
async with self._get_db_connection() as db:
return await self._get_user_state(db, app_name, user_id)

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
21 changes: 21 additions & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,27 @@ async def delete_session(
logger.error('Error deleting session %s: %s', session_id, e)
raise

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
"""Not supported by the Vertex AI Agent Engine backend.

The Vertex AI Agent Engine API does not expose user state independently of
a session. To read user state, enumerate sessions via ``list_sessions``
and call ``get_session`` on each result to access the merged state.

Raises:
NotImplementedError: Always, because the Vertex AI Agent Engine API does
not provide a way to query user state without a session.
"""
raise NotImplementedError(
'VertexAiSessionService does not support get_user_state. '
'The Vertex AI Agent Engine API does not expose user state '
'independently of a session. To read user state, enumerate sessions '
'via list_sessions and call get_session on each result.'
)

@override
async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
Expand Down
175 changes: 175 additions & 0 deletions tests/unittests/sessions/test_get_user_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for BaseSessionService.get_user_state across concrete implementations."""

import enum

from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.base_session_service import BaseSessionService
from google.adk.sessions.database_session_service import DatabaseSessionService
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.sessions.sqlite_session_service import SqliteSessionService
from google.adk.sessions.vertex_ai_session_service import VertexAiSessionService
import pytest

_APP = 'test-app'
_OTHER_APP = 'other-app'
_USER = 'user-42'
_OTHER_USER = 'user-99'


class SessionServiceType(enum.Enum):
IN_MEMORY = 'IN_MEMORY'
DATABASE = 'DATABASE'
SQLITE = 'SQLITE'


def _make_service(
service_type: SessionServiceType, tmp_path=None
) -> BaseSessionService:
if service_type == SessionServiceType.DATABASE:
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
if service_type == SessionServiceType.SQLITE:
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
return InMemorySessionService()


@pytest.fixture(
params=[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
]
)
async def session_service(request, tmp_path):
"""Provides a session service and closes database backends on teardown."""
service = _make_service(request.param, tmp_path)
yield service
if isinstance(service, DatabaseSessionService):
await service.close()


async def test_get_user_state_returns_empty_dict_when_no_state_exists(
session_service,
):
"""Returns {} when (app_name, user_id) has never had state written."""
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
assert state == {}


async def test_get_user_state_returns_state_written_via_append_event(
session_service,
):
"""State written with the user: prefix is returned without the prefix."""
session = await session_service.create_session(app_name=_APP, user_id=_USER)
await session_service.append_event(
session,
Event(
author='system',
actions=EventActions(
state_delta={'user:profile': {'name': 'Alice'}, 'session_key': 1}
),
),
)

state = await session_service.get_user_state(app_name=_APP, user_id=_USER)

assert state == {'profile': {'name': 'Alice'}}
assert 'session_key' not in state


async def test_get_user_state_is_not_visible_across_users(session_service):
"""User state is scoped to (app_name, user_id) — other users see {}."""
session = await session_service.create_session(app_name=_APP, user_id=_USER)
await session_service.append_event(
session,
Event(
author='system',
actions=EventActions(state_delta={'user:secret': 'only-for-user-42'}),
),
)

other_state = await session_service.get_user_state(
app_name=_APP, user_id=_OTHER_USER
)
assert other_state == {}


async def test_get_user_state_is_not_visible_across_apps(session_service):
"""User state is scoped to (app_name, user_id) — other apps see {}."""
session = await session_service.create_session(app_name=_APP, user_id=_USER)
await session_service.append_event(
session,
Event(
author='system',
actions=EventActions(state_delta={'user:data': 'only-app-a'}),
),
)

other_state = await session_service.get_user_state(
app_name=_OTHER_APP, user_id=_USER
)
assert other_state == {}


async def test_get_user_state_available_before_session_is_created(
session_service,
):
"""Core use case: user state is readable without an active session_id."""
first_session = await session_service.create_session(
app_name=_APP, user_id=_USER
)
await session_service.append_event(
first_session,
Event(
author='system',
actions=EventActions(state_delta={'user:ctx': {'v': 1}}),
),
)

# Simulate a brand-new session_id (not yet created) — get_user_state must
# still return the persisted user state.
state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
assert state == {'ctx': {'v': 1}}


async def test_get_user_state_reflects_latest_write(session_service):
"""Subsequent writes overwrite earlier values under the same key."""
session = await session_service.create_session(app_name=_APP, user_id=_USER)
await session_service.append_event(
session,
Event(
author='system',
actions=EventActions(state_delta={'user:counter': 1}),
),
)
await session_service.append_event(
session,
Event(
author='system',
actions=EventActions(state_delta={'user:counter': 2}),
),
)

state = await session_service.get_user_state(app_name=_APP, user_id=_USER)
assert state['counter'] == 2


async def test_vertex_ai_session_service_raises_not_implemented():
"""VertexAiSessionService raises NotImplementedError for get_user_state."""
service = VertexAiSessionService(project='proj', location='us-central1')
with pytest.raises(NotImplementedError):
await service.get_user_state(app_name=_APP, user_id=_USER)