Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions src/google/adk/a2a/executor/a2a_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from datetime import timezone
import inspect
import logging
import os
import time
import httpx
from typing import Awaitable
from typing import Callable
from typing import Optional
Expand All @@ -32,7 +35,7 @@
from a2a.types import TaskState
from a2a.types import TaskStatus
from a2a.types import TaskStatusUpdateEvent
from a2a.types import TextPart
from a2a.types import Part
from google.adk.platform import time as platform_time
from google.adk.platform import uuid as platform_uuid
from google.adk.runners import Runner
Expand Down Expand Up @@ -187,7 +190,7 @@ async def execute(
message=Message(
message_id=platform_uuid.new_uuid(),
role=Role.agent,
parts=[TextPart(text=str(e))],
parts=[Part(text=str(e))],
),
),
context_id=context.context_id,
Expand All @@ -213,9 +216,9 @@ async def _handle_request(
self._config.a2a_part_converter,
)

# ensure the session exists
# ensure the session exists modify this code
session = await self._prepare_session(context, run_request, runner)

await self._refresh_token_if_expired(session, runner)
# create invocation context
invocation_context = runner._new_invocation_context(
session=session,
Expand Down Expand Up @@ -321,7 +324,51 @@ async def _handle_request(
self._config.execute_interceptors,
)
await event_queue.enqueue_event(final_event)
async def _refresh_token_if_expired(self, session, runner: Runner):
state = session.state
if not state:
return

refresh_token = state.get("refresh_token")
expires_at = state.get("expires_at", 0)

if not refresh_token:
return

now = int(time.time())
if now < expires_at:
return

logger.info("OAuth token expired, refreshing...")

async with httpx.AsyncClient() as client:
resp = await client.post(
"https://oauth2.googleapis.com/token",
data={
"client_id": os.environ["GOOGLE_CLIENT_ID"],
"client_secret": os.environ["GOOGLE_CLIENT_SECRET"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
)

if resp.status_code != 200:
logger.error("OAuth token refresh failed: %s", resp.text)
return

tokens = resp.json()
state["access_token"] = tokens["access_token"]
state["expires_at"] = now + tokens.get("expires_in", 3600)
state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token"))

await runner.session_service.update_session(
app_name=runner.app_name,
user_id=session.user_id,
session_id=session.id,
state=state,
)

logger.info("OAuth token refreshed successfully.")
async def _prepare_session(
self,
context: RequestContext,
Expand Down
119 changes: 119 additions & 0 deletions tests/unittests/a2a/executor/test_refresh_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
import time
import os
import httpx
from unittest.mock import AsyncMock, MagicMock, patch


# Standalone copy of the method — no executor import needed
async def _refresh_token_if_expired(session, runner):
state = session.state
if not state:
return

refresh_token = state.get("refresh_token")
expires_at = state.get("expires_at", 0)

if not refresh_token:
return

now = int(time.time())
if now < expires_at:
return

async with httpx.AsyncClient() as client:
resp = await client.post(
"https://oauth2.googleapis.com/token",
data={
"client_id": os.environ["GOOGLE_CLIENT_ID"],
"client_secret": os.environ["GOOGLE_CLIENT_SECRET"],
"refresh_token": refresh_token,
"grant_type": "refresh_token",
},
)

if resp.status_code != 200:
return

tokens = resp.json()
state["access_token"] = tokens["access_token"]
state["expires_at"] = now + tokens.get("expires_in", 3600)
state["refresh_token"] = tokens.get("refresh_token", state.get("refresh_token"))

await runner.session_service.update_session(
app_name=runner.app_name,
user_id=session.user_id,
session_id=session.id,
state=state,
)


@pytest.mark.asyncio
async def test_token_not_expired_skips_refresh():
"""Token still valid — refresh should NOT be called."""
session = MagicMock()
session.state = {
"access_token": "valid_token",
"refresh_token": "refresh_token",
"expires_at": int(time.time()) + 9999,
}
runner = MagicMock()
runner.session_service.update_session = AsyncMock()

await _refresh_token_if_expired(session, runner)

runner.session_service.update_session.assert_not_called()
print("PASS — valid token, no refresh triggered")


@pytest.mark.asyncio
async def test_expired_token_triggers_refresh():
"""Token is expired — refresh SHOULD be called."""
session = MagicMock()
session.state = {
"access_token": "old_token",
"refresh_token": "my_refresh_token",
"expires_at": int(time.time()) - 100,
}
session.user_id = "user123"
session.id = "session123"

runner = MagicMock()
runner.app_name = "test_app"
runner.session_service.update_session = AsyncMock()

mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"access_token": "new_token",
"expires_in": 3600,
}

mock_client_instance = MagicMock()
mock_client_instance.post = AsyncMock(return_value=mock_response)

with patch("httpx.AsyncClient") as mock_client:
mock_client.return_value.__aenter__ = AsyncMock(
return_value=mock_client_instance
)
mock_client.return_value.__aexit__ = AsyncMock(return_value=False)

with patch.dict("os.environ", {
"GOOGLE_CLIENT_ID": "test_client_id",
"GOOGLE_CLIENT_SECRET": "test_secret",
}):
await _refresh_token_if_expired(session, runner)

runner.session_service.update_session.assert_called_once()
assert session.state["access_token"] == "new_token"
print("PASS — expired token was refreshed")


@pytest.mark.asyncio
async def test_no_refresh_token_skips_refresh():
"""No refresh_token in state — should skip silently."""
session = MagicMock()
session.state = {
"access_token": "some_token",
"expires_at": int(time.time()) - 100,
}
Loading