Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/google/adk/a2a/utils/agent_to_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from a2a.server.tasks import InMemoryPushNotificationConfigStore
from a2a.server.tasks import InMemoryTaskStore
from a2a.server.tasks import PushNotificationConfigStore
from a2a.server.tasks import TaskStore
from a2a.types import AgentCard
from starlette.applications import Starlette

Expand Down Expand Up @@ -84,6 +85,7 @@ def to_a2a(
protocol: str = "http",
agent_card: Optional[Union[AgentCard, str]] = None,
push_config_store: Optional[PushNotificationConfigStore] = None,
task_store: Optional[TaskStore] = None,
runner: Optional[Runner] = None,
lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None,
) -> Starlette:
Expand All @@ -100,6 +102,8 @@ def to_a2a(
push_config_store: Optional A2A push notification config store. If not
provided, an in-memory store will be created so push-notification
config RPC methods are supported.
task_store: Optional A2A task store for persisting task state. If not
provided, an in-memory store will be created.
runner: Optional pre-built Runner object. If not provided, a default
runner will be created using in-memory services.
lifespan: Optional async context manager for Starlette lifespan
Expand Down Expand Up @@ -127,6 +131,11 @@ async def lifespan(app):
await app.state.db.close()

app = to_a2a(agent, lifespan=lifespan)

# Or with a persistent task store:
from a2a.server.tasks import DatabaseTaskStore
task_store = DatabaseTaskStore(db_url="postgresql+asyncpg://...")
app = to_a2a(agent, task_store=task_store)
"""
# Set up ADK logging to ensure logs are visible when using uvicorn directly
adk_logger = logging.getLogger("google_adk")
Expand All @@ -145,7 +154,8 @@ async def create_runner() -> Runner:
)

# Create A2A components
task_store = InMemoryTaskStore()
if task_store is None:
task_store = InMemoryTaskStore()

agent_executor = A2aAgentExecutor(
runner=runner or create_runner,
Expand Down
9 changes: 7 additions & 2 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from .utils.service_factory import create_artifact_service_from_options
from .utils.service_factory import create_memory_service_from_options
from .utils.service_factory import create_session_service_from_options
from .utils.service_factory import create_task_store_from_options

logger = logging.getLogger("google_adk." + __name__)

Expand Down Expand Up @@ -85,6 +86,7 @@ def get_fast_api_app(
allow_origins: Optional[list[str]] = None,
web: bool,
a2a: bool = False,
task_store_uri: Optional[str] = None,
host: str = "127.0.0.1",
port: int = 8000,
url_prefix: Optional[str] = None,
Expand Down Expand Up @@ -128,6 +130,8 @@ def get_fast_api_app(
allow_origins: List of allowed origins for CORS.
web: Whether to enable the web UI and serve its assets.
a2a: Whether to enable Agent-to-Agent (A2A) protocol support.
task_store_uri: URI for the A2A task store. Uses in-memory task store if
None. Only used when ``a2a=True``.
host: Host address for the server (defaults to 127.0.0.1).
port: Port number for the server (defaults to 8000).
url_prefix: Optional prefix for all URL routes.
Expand Down Expand Up @@ -588,7 +592,6 @@ async def get_agent_builder(
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryPushNotificationConfigStore
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCard
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH

Expand All @@ -598,7 +601,9 @@ async def get_agent_builder(
base_path = Path.cwd() / agents_dir
# the root agents directory should be an existing folder
if base_path.exists() and base_path.is_dir():
a2a_task_store = InMemoryTaskStore()
a2a_task_store = create_task_store_from_options(
task_store_uri=task_store_uri,
)

def create_a2a_runner_loader(captured_app_name: str):
"""Factory function to create A2A runner with proper closure."""
Expand Down
40 changes: 40 additions & 0 deletions src/google/adk/cli/service_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(self):
self._session_factories: dict[str, ServiceFactory] = {}
self._artifact_factories: dict[str, ServiceFactory] = {}
self._memory_factories: dict[str, ServiceFactory] = {}
self._task_store_factories: dict[str, ServiceFactory] = {}

def register_session_service(
self, scheme: str, factory: ServiceFactory
Expand All @@ -123,6 +124,12 @@ def register_memory_service(
"""Register a factory for a custom memory service URI scheme."""
self._memory_factories[scheme] = factory

def register_task_store_service(
self, scheme: str, factory: ServiceFactory
) -> None:
"""Register a factory for a custom A2A task store URI scheme."""
self._task_store_factories[scheme] = factory

def create_session_service(
self, uri: str, **kwargs
) -> BaseSessionService | None:
Expand Down Expand Up @@ -150,6 +157,13 @@ def create_memory_service(
return self._memory_factories[scheme](uri, **kwargs)
return None

def create_task_store_service(self, uri: str, **kwargs: Any) -> Any:
"""Create A2A task store from URI using registered factories."""
scheme = urlparse(uri).scheme
if scheme and scheme in self._task_store_factories:
return self._task_store_factories[scheme](uri, **kwargs)
return None


def get_service_registry() -> ServiceRegistry:
"""Gets the singleton ServiceRegistry instance, initializing it if needed."""
Expand Down Expand Up @@ -333,6 +347,30 @@ def agentengine_memory_factory(uri: str, **kwargs):
registry.register_memory_service("rag", rag_memory_factory)
registry.register_memory_service("agentengine", agentengine_memory_factory)

# -- A2A Task Store Services --
def memory_task_store_factory(uri: str, **kwargs: Any) -> Any:
from a2a.server.tasks import InMemoryTaskStore

return InMemoryTaskStore()

def database_task_store_factory(uri: str, **kwargs: Any) -> Any:
from a2a.server.tasks import DatabaseTaskStore
from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine(uri)
return DatabaseTaskStore(engine=engine)

registry.register_task_store_service("memory", memory_task_store_factory)
for scheme in [
"postgresql",
"postgresql+asyncpg",
"mysql",
"mysql+aiomysql",
"sqlite",
"sqlite+aiosqlite",
]:
registry.register_task_store_service(scheme, database_task_store_factory)


def _load_gcp_config(
agents_dir: Optional[str], service_name: str
Expand Down Expand Up @@ -437,5 +475,7 @@ def _register_services_from_yaml_config(
registry.register_artifact_service(scheme, factory)
elif service_type == "memory":
registry.register_memory_service(scheme, factory)
elif service_type == "task_store":
registry.register_task_store_service(scheme, factory)
else:
logger.warning("Unknown service type in YAML: %s", service_type)
27 changes: 27 additions & 0 deletions src/google/adk/cli/utils/service_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,30 @@ def create_artifact_service_from_options(
base_path,
exc,
)


def create_task_store_from_options(
*,
task_store_uri: Optional[str] = None,
) -> Any:
"""Creates an A2A task store based on CLI/web options."""
from a2a.server.tasks import InMemoryTaskStore

registry = get_service_registry()

if task_store_uri:
logger.info(
"Using A2A task store URI: %s",
_redact_uri_for_log(task_store_uri),
)
service = registry.create_task_store_service(task_store_uri)
if service is not None:
return service

raise ValueError(
"Unsupported A2A task store URI: %s"
% _redact_uri_for_log(task_store_uri)
)

logger.info("Using in-memory A2A task store")
return InMemoryTaskStore()
10 changes: 7 additions & 3 deletions tests/unittests/a2a/integration/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,25 @@ async def run_async(self, **kwargs):


def create_server_app(
run_async_fn=None, config: A2aAgentExecutorConfig | None = None
run_async_fn=None,
config: A2aAgentExecutorConfig | None = None,
task_store=None,
):
"""Creates an A2A FastAPI application with a mocked runner.

Args:
run_async_fn: A generator function that takes **kwargs and yields Event
objects.
include_artifacts: Whether to include artifacts in A2A events.
config: Optional executor configuration.
task_store: Optional task store instance. Defaults to InMemoryTaskStore.

Returns:
A FastAPI application instance.
"""
runner = FakeRunner(run_async_fn)
executor = A2aAgentExecutor(runner=runner, config=config)
task_store = InMemoryTaskStore()
if task_store is None:
task_store = InMemoryTaskStore()
handler = DefaultRequestHandler(
agent_executor=executor, task_store=task_store
)
Expand Down
74 changes: 74 additions & 0 deletions tests/unittests/a2a/utils/test_agent_to_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,80 @@ def test_to_a2a_passes_custom_push_config_store(
task_store=mock_task_store,
)

@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder")
@patch("google.adk.a2a.utils.agent_to_a2a.Starlette")
def test_to_a2a_with_custom_task_store(
self,
mock_starlette_class,
mock_card_builder_class,
mock_task_store_class,
mock_request_handler_class,
mock_agent_executor_class,
):
"""Test to_a2a with a custom task store."""
# Arrange
mock_app = Mock(spec=Starlette)
mock_starlette_class.return_value = mock_app
mock_agent_executor = Mock(spec=A2aAgentExecutor)
mock_agent_executor_class.return_value = mock_agent_executor
mock_request_handler = Mock(spec=DefaultRequestHandler)
mock_request_handler_class.return_value = mock_request_handler
mock_card_builder = Mock(spec=AgentCardBuilder)
mock_card_builder_class.return_value = mock_card_builder
custom_task_store = Mock()

# Act
result = to_a2a(self.mock_agent, task_store=custom_task_store)

# Assert
assert result == mock_app
mock_task_store_class.assert_not_called()
mock_request_handler_class.assert_called_once_with(
agent_executor=mock_agent_executor,
push_config_store=ANY,
task_store=custom_task_store,
)

@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder")
@patch("google.adk.a2a.utils.agent_to_a2a.Starlette")
def test_to_a2a_default_task_store_when_none(
self,
mock_starlette_class,
mock_card_builder_class,
mock_task_store_class,
mock_request_handler_class,
mock_agent_executor_class,
):
"""Test to_a2a defaults to InMemoryTaskStore when task_store is None."""
# Arrange
mock_app = Mock(spec=Starlette)
mock_starlette_class.return_value = mock_app
mock_task_store = Mock(spec=InMemoryTaskStore)
mock_task_store_class.return_value = mock_task_store
mock_agent_executor = Mock(spec=A2aAgentExecutor)
mock_agent_executor_class.return_value = mock_agent_executor
mock_request_handler = Mock(spec=DefaultRequestHandler)
mock_request_handler_class.return_value = mock_request_handler
mock_card_builder = Mock(spec=AgentCardBuilder)
mock_card_builder_class.return_value = mock_card_builder

# Act
result = to_a2a(self.mock_agent, task_store=None)

# Assert
mock_task_store_class.assert_called_once()
mock_request_handler_class.assert_called_once_with(
agent_executor=mock_agent_executor,
push_config_store=ANY,
task_store=mock_task_store,
)

@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
Expand Down
Loading
Loading