Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,46 @@ async def delete_session(
) -> None:
"""Deletes a session."""

async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
"""Returns the user-scoped state for the given app and user.

User state is keyed by ``(app_name, user_id)`` and shared across all
sessions of the same user within the same app. The returned dictionary
uses raw keys **without** the ``user:`` prefix (e.g. ``"my_key"`` rather
than ``"user:my_key"``).

This method exists so that callers can read user state without holding an
active ``session_id``. A common use case is bootstrapping context at the
start of a new session before calling ``create_session``, which would
otherwise require an expensive ``list_sessions`` call just to access
user-scoped data.

Returns an empty dict when no user state has been stored for this
``(app_name, user_id)`` combination.

Args:
app_name: The name of the app.
user_id: The ID of the user.

Returns:
A dictionary of raw (un-prefixed) user-scoped key/value pairs, or an
empty dict when no user state exists.

Raises:
NotImplementedError: When the concrete ``BaseSessionService``
implementation does not support reading user state independently of a
session. Callers should catch this, then enumerate sessions via
``list_sessions`` and call ``get_session`` on each result to access
the merged state, or accept that user state is unavailable.
"""
raise NotImplementedError(
f'{type(self).__name__} does not support get_user_state. '
'To read user state, enumerate sessions via list_sessions and '
'call get_session on each result to access the merged state.'
)

async def append_event(self, session: Session, event: Event) -> Event:
"""Appends an event to a session object."""
if event.partial:
Expand Down
16 changes: 16 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,22 @@ async def delete_session(
await sql_session.execute(stmt)
await sql_session.commit()

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
await self._prepare_tables()
schema = self._get_schema_classes()
async with self._rollback_on_exception_session(
read_only=True
) as sql_session:
storage_user_state = await sql_session.get(
schema.StorageUserState, (app_name, user_id)
)
if storage_user_state is None:
return {}
return dict(storage_user_state.state or {})

@override
async def append_event(self, session: Session, event: Event) -> Event:
await self._prepare_tables()
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def _delete_session_impl(

self.sessions[app_name][user_id].pop(session_id)

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
return dict(self.user_state.get(app_name, {}).get(user_id, {}))

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ async def delete_session(
)
await db.commit()

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
async with self._get_db_connection() as db:
return await self._get_user_state(db, app_name, user_id)

@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
Expand Down
21 changes: 21 additions & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,27 @@ async def delete_session(
logger.error('Error deleting session %s: %s', session_id, e)
raise

@override
async def get_user_state(
self, *, app_name: str, user_id: str
) -> dict[str, Any]:
"""Not supported by the Vertex AI Agent Engine backend.

The Vertex AI Agent Engine API does not expose user state independently of
a session. To read user state, enumerate sessions via ``list_sessions``
and call ``get_session`` on each result to access the merged state.

Raises:
NotImplementedError: Always, because the Vertex AI Agent Engine API does
not provide a way to query user state without a session.
"""
raise NotImplementedError(
'VertexAiSessionService does not support get_user_state. '
'The Vertex AI Agent Engine API does not expose user state '
'independently of a session. To read user state, enumerate sessions '
'via list_sessions and call get_session on each result.'
)

@override
async def append_event(self, session: Session, event: Event) -> Event:
# Update the in-memory session.
Expand Down
171 changes: 171 additions & 0 deletions tests/integration/test_database_session_service_tool_confirmation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for DatabaseSessionService with tool confirmation support."""

from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.database_session_service import DatabaseSessionService
from google.adk.tools.tool_confirmation import ToolConfirmation
from google.genai import types
import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True)
async def test_database_session_service_save_retrieve_tool_confirmation():
"""Test that DatabaseSessionService correctly saves and retrieves events with requested_tool_confirmations."""
async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service:
# Create a session
session = await session_service.create_session(
app_name='test_app', user_id='test_user', state={}
)

# Create an event with requested_tool_confirmations
tool_confirmation = ToolConfirmation(
confirmed=True, hint='Approve this', payload={'amount': 100}
)
event_actions = EventActions(
requested_tool_confirmations={'func_call_456': tool_confirmation}
)

event = Event(
id='event_2',
invocation_id='inv_2',
author='agent',
content=types.Content(
role='model',
parts=[types.Part.from_text(text='Test message')],
),
actions=event_actions,
)

# Append event to session
await session_service.append_event(session, event)

# Retrieve the session
retrieved_session = await session_service.get_session(
app_name='test_app', user_id='test_user', session_id=session.id
)

# Verify the event was saved and retrieved correctly
assert len(retrieved_session.events) == 1
retrieved_event = retrieved_session.events[0]

# Verify EventActions structure
assert isinstance(retrieved_event.actions, EventActions)
assert 'func_call_456' in retrieved_event.actions.requested_tool_confirmations

# Verify ToolConfirmation object was correctly reconstructed
retrieved_confirmation = (
retrieved_event.actions.requested_tool_confirmations['func_call_456']
)
assert isinstance(retrieved_confirmation, ToolConfirmation)
assert retrieved_confirmation.confirmed is True
assert retrieved_confirmation.hint == 'Approve this'
assert retrieved_confirmation.payload == {'amount': 100}


@pytest.mark.asyncio
@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True)
async def test_database_session_service_multiple_tool_confirmations():
"""Test that DatabaseSessionService handles multiple tool confirmations in a single event."""
async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service:
session = await session_service.create_session(
app_name='test_app', user_id='test_user', state={}
)

# Create an event with multiple requested_tool_confirmations
tool_confirmation_1 = ToolConfirmation(
confirmed=False, hint='First action', payload={'step': 1}
)
tool_confirmation_2 = ToolConfirmation(
confirmed=False, hint='Second action', payload={'step': 2}
)
event_actions = EventActions(
requested_tool_confirmations={
'func_call_1': tool_confirmation_1,
'func_call_2': tool_confirmation_2,
}
)

event = Event(
id='event_3',
invocation_id='inv_3',
author='agent',
content=types.Content(
role='model',
parts=[types.Part.from_text(text='Multiple confirmations')],
),
actions=event_actions,
)

await session_service.append_event(session, event)

# Retrieve and verify
retrieved_session = await session_service.get_session(
app_name='test_app', user_id='test_user', session_id=session.id
)
assert len(retrieved_session.events) == 1
retrieved_event = retrieved_session.events[0]

assert isinstance(retrieved_event.actions, EventActions)
assert len(retrieved_event.actions.requested_tool_confirmations) == 2

# Verify both confirmations
conf_1 = retrieved_event.actions.requested_tool_confirmations['func_call_1']
assert isinstance(conf_1, ToolConfirmation)
assert conf_1.confirmed is False
assert conf_1.hint == 'First action'
assert conf_1.payload == {'step': 1}

conf_2 = retrieved_event.actions.requested_tool_confirmations['func_call_2']
assert isinstance(conf_2, ToolConfirmation)
assert conf_2.confirmed is False
assert conf_2.hint == 'Second action'
assert conf_2.payload == {'step': 2}


@pytest.mark.asyncio
@pytest.mark.parametrize('llm_backend', ['GOOGLE_AI'], indirect=True)
async def test_database_session_service_empty_tool_confirmations():
"""Test that DatabaseSessionService handles events without tool confirmations."""
async with DatabaseSessionService('sqlite+aiosqlite:///:memory:') as session_service:
session = await session_service.create_session(
app_name='test_app', user_id='test_user', state={}
)

# Create an event without requested_tool_confirmations
event = Event(
id='event_4',
invocation_id='inv_4',
author='agent',
content=types.Content(
role='model',
parts=[types.Part.from_text(text='No confirmations')],
),
actions=EventActions(),
)

await session_service.append_event(session, event)

# Retrieve and verify
retrieved_session = await session_service.get_session(
app_name='test_app', user_id='test_user', session_id=session.id
)
assert len(retrieved_session.events) == 1
retrieved_event = retrieved_session.events[0]

assert isinstance(retrieved_event.actions, EventActions)
assert not retrieved_event.actions.requested_tool_confirmations
Loading