diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 3e8ed461e2..534549b18d 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -26,6 +26,7 @@ from a2a.server.tasks import InMemoryPushNotificationConfigStore from a2a.server.tasks import InMemoryTaskStore from a2a.server.tasks import PushNotificationConfigStore +from a2a.server.tasks import TaskStore from a2a.types import AgentCard from starlette.applications import Starlette @@ -84,6 +85,7 @@ def to_a2a( protocol: str = "http", agent_card: Optional[Union[AgentCard, str]] = None, push_config_store: Optional[PushNotificationConfigStore] = None, + task_store: Optional[TaskStore] = None, runner: Optional[Runner] = None, lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None, ) -> Starlette: @@ -100,6 +102,8 @@ def to_a2a( push_config_store: Optional A2A push notification config store. If not provided, an in-memory store will be created so push-notification config RPC methods are supported. + task_store: Optional A2A task store for persisting task state. If not + provided, an in-memory store will be created. runner: Optional pre-built Runner object. If not provided, a default runner will be created using in-memory services. lifespan: Optional async context manager for Starlette lifespan @@ -127,6 +131,11 @@ async def lifespan(app): await app.state.db.close() app = to_a2a(agent, lifespan=lifespan) + + # Or with a persistent task store: + from a2a.server.tasks import DatabaseTaskStore + task_store = DatabaseTaskStore(db_url="postgresql+asyncpg://...") + app = to_a2a(agent, task_store=task_store) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly adk_logger = logging.getLogger("google_adk") @@ -145,7 +154,8 @@ async def create_runner() -> Runner: ) # Create A2A components - task_store = InMemoryTaskStore() + if task_store is None: + task_store = InMemoryTaskStore() agent_executor = A2aAgentExecutor( runner=runner or create_runner, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index fa1948d4e2..be303a9190 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -51,6 +51,7 @@ from .utils.service_factory import create_artifact_service_from_options from .utils.service_factory import create_memory_service_from_options from .utils.service_factory import create_session_service_from_options +from .utils.service_factory import create_task_store_from_options logger = logging.getLogger("google_adk." + __name__) @@ -85,6 +86,7 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, + task_store_uri: Optional[str] = None, host: str = "127.0.0.1", port: int = 8000, url_prefix: Optional[str] = None, @@ -128,6 +130,8 @@ def get_fast_api_app( allow_origins: List of allowed origins for CORS. web: Whether to enable the web UI and serve its assets. a2a: Whether to enable Agent-to-Agent (A2A) protocol support. + task_store_uri: URI for the A2A task store. Uses in-memory task store if + None. Only used when ``a2a=True``. host: Host address for the server (defaults to 127.0.0.1). port: Port number for the server (defaults to 8000). url_prefix: Optional prefix for all URL routes. @@ -588,7 +592,6 @@ async def get_agent_builder( from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryPushNotificationConfigStore - from a2a.server.tasks import InMemoryTaskStore from a2a.types import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -598,7 +601,9 @@ async def get_agent_builder( base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder if base_path.exists() and base_path.is_dir(): - a2a_task_store = InMemoryTaskStore() + a2a_task_store = create_task_store_from_options( + task_store_uri=task_store_uri, + ) def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index b1328958ef..7b8340c837 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -98,6 +98,7 @@ def __init__(self): self._session_factories: dict[str, ServiceFactory] = {} self._artifact_factories: dict[str, ServiceFactory] = {} self._memory_factories: dict[str, ServiceFactory] = {} + self._task_store_factories: dict[str, ServiceFactory] = {} def register_session_service( self, scheme: str, factory: ServiceFactory @@ -123,6 +124,12 @@ def register_memory_service( """Register a factory for a custom memory service URI scheme.""" self._memory_factories[scheme] = factory + def register_task_store_service( + self, scheme: str, factory: ServiceFactory + ) -> None: + """Register a factory for a custom A2A task store URI scheme.""" + self._task_store_factories[scheme] = factory + def create_session_service( self, uri: str, **kwargs ) -> BaseSessionService | None: @@ -150,6 +157,13 @@ def create_memory_service( return self._memory_factories[scheme](uri, **kwargs) return None + def create_task_store_service(self, uri: str, **kwargs: Any) -> Any: + """Create A2A task store from URI using registered factories.""" + scheme = urlparse(uri).scheme + if scheme and scheme in self._task_store_factories: + return self._task_store_factories[scheme](uri, **kwargs) + return None + def get_service_registry() -> ServiceRegistry: """Gets the singleton ServiceRegistry instance, initializing it if needed.""" @@ -333,6 +347,30 @@ def agentengine_memory_factory(uri: str, **kwargs): registry.register_memory_service("rag", rag_memory_factory) registry.register_memory_service("agentengine", agentengine_memory_factory) + # -- A2A Task Store Services -- + def memory_task_store_factory(uri: str, **kwargs: Any) -> Any: + from a2a.server.tasks import InMemoryTaskStore + + return InMemoryTaskStore() + + def database_task_store_factory(uri: str, **kwargs: Any) -> Any: + from a2a.server.tasks import DatabaseTaskStore + from sqlalchemy.ext.asyncio import create_async_engine + + engine = create_async_engine(uri) + return DatabaseTaskStore(engine=engine) + + registry.register_task_store_service("memory", memory_task_store_factory) + for scheme in [ + "postgresql", + "postgresql+asyncpg", + "mysql", + "mysql+aiomysql", + "sqlite", + "sqlite+aiosqlite", + ]: + registry.register_task_store_service(scheme, database_task_store_factory) + def _load_gcp_config( agents_dir: Optional[str], service_name: str @@ -437,5 +475,7 @@ def _register_services_from_yaml_config( registry.register_artifact_service(scheme, factory) elif service_type == "memory": registry.register_memory_service(scheme, factory) + elif service_type == "task_store": + registry.register_task_store_service(scheme, factory) else: logger.warning("Unknown service type in YAML: %s", service_type) diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index f0ffcbe9eb..5c75f27020 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -326,3 +326,30 @@ def create_artifact_service_from_options( base_path, exc, ) + + +def create_task_store_from_options( + *, + task_store_uri: Optional[str] = None, +) -> Any: + """Creates an A2A task store based on CLI/web options.""" + from a2a.server.tasks import InMemoryTaskStore + + registry = get_service_registry() + + if task_store_uri: + logger.info( + "Using A2A task store URI: %s", + _redact_uri_for_log(task_store_uri), + ) + service = registry.create_task_store_service(task_store_uri) + if service is not None: + return service + + raise ValueError( + "Unsupported A2A task store URI: %s" + % _redact_uri_for_log(task_store_uri) + ) + + logger.info("Using in-memory A2A task store") + return InMemoryTaskStore() diff --git a/tests/unittests/a2a/integration/server.py b/tests/unittests/a2a/integration/server.py index c965a71091..bd01d824f2 100644 --- a/tests/unittests/a2a/integration/server.py +++ b/tests/unittests/a2a/integration/server.py @@ -74,21 +74,25 @@ async def run_async(self, **kwargs): def create_server_app( - run_async_fn=None, config: A2aAgentExecutorConfig | None = None + run_async_fn=None, + config: A2aAgentExecutorConfig | None = None, + task_store=None, ): """Creates an A2A FastAPI application with a mocked runner. Args: run_async_fn: A generator function that takes **kwargs and yields Event objects. - include_artifacts: Whether to include artifacts in A2A events. + config: Optional executor configuration. + task_store: Optional task store instance. Defaults to InMemoryTaskStore. Returns: A FastAPI application instance. """ runner = FakeRunner(run_async_fn) executor = A2aAgentExecutor(runner=runner, config=config) - task_store = InMemoryTaskStore() + if task_store is None: + task_store = InMemoryTaskStore() handler = DefaultRequestHandler( agent_executor=executor, task_store=task_store ) diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9e2458ebd..b5012a6535 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -168,6 +168,80 @@ def test_to_a2a_passes_custom_push_config_store( task_store=mock_task_store, ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_with_custom_task_store( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a with a custom task store.""" + # Arrange + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + mock_request_handler = Mock(spec=DefaultRequestHandler) + mock_request_handler_class.return_value = mock_request_handler + mock_card_builder = Mock(spec=AgentCardBuilder) + mock_card_builder_class.return_value = mock_card_builder + custom_task_store = Mock() + + # Act + result = to_a2a(self.mock_agent, task_store=custom_task_store) + + # Assert + assert result == mock_app + mock_task_store_class.assert_not_called() + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=custom_task_store, + ) + + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_default_task_store_when_none( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a defaults to InMemoryTaskStore when task_store is None.""" + # Arrange + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + mock_task_store = Mock(spec=InMemoryTaskStore) + mock_task_store_class.return_value = mock_task_store + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + mock_request_handler = Mock(spec=DefaultRequestHandler) + mock_request_handler_class.return_value = mock_request_handler + mock_card_builder = Mock(spec=AgentCardBuilder) + mock_card_builder_class.return_value = mock_card_builder + + # Act + result = to_a2a(self.mock_agent, task_store=None) + + # Assert + mock_task_store_class.assert_called_once() + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, + push_config_store=ANY, + task_store=mock_task_store, + ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 3e63f31222..2a568a420c 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -849,7 +849,10 @@ def test_app_with_a2a( "google.adk.cli.fast_api.LocalEvalSetResultsManager", return_value=mock_eval_set_results_manager, ), - patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "google.adk.cli.fast_api.create_task_store_from_options", + return_value=MagicMock(), + ), patch( "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" ) as mock_executor, @@ -859,7 +862,6 @@ def test_app_with_a2a( patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, ): # Configure mocks - mock_task_store.return_value = MagicMock() mock_executor.return_value = MagicMock() mock_handler.return_value = MagicMock() @@ -1814,7 +1816,9 @@ def test_a2a_request_handler_uses_push_config_store( "google.adk.cli.fast_api.LocalEvalSetResultsManager", return_value=mock_eval_set_results_manager, ), - patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store, + patch( + "google.adk.cli.fast_api.create_task_store_from_options", + ) as mock_create_task_store, patch( "a2a.server.tasks.InMemoryPushNotificationConfigStore" ) as mock_push_config_store_class, @@ -1827,7 +1831,7 @@ def test_a2a_request_handler_uses_push_config_store( patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, ): mock_task_store_instance = MagicMock() - mock_task_store.return_value = mock_task_store_instance + mock_create_task_store.return_value = mock_task_store_instance mock_push_config_store = MagicMock() mock_push_config_store_class.return_value = mock_push_config_store mock_executor_instance = MagicMock() @@ -1857,6 +1861,86 @@ def test_a2a_request_handler_uses_push_config_store( ) +def test_a2a_request_handler_uses_task_store_uri( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test A2A request handler uses task store created from URI.""" + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch( + "google.adk.cli.fast_api.create_task_store_from_options", + ) as mock_create_task_store, + patch( + "google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor" + ) as mock_executor, + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + custom_task_store = MagicMock() + mock_create_task_store.return_value = custom_task_store + mock_executor_instance = MagicMock() + mock_executor.return_value = mock_executor_instance + mock_handler.return_value = MagicMock() + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance + + test_uri = "postgresql+asyncpg://user:pass@host/db" + monkeypatch.chdir(temp_agents_dir_with_a2a) + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + task_store_uri=test_uri, + host="127.0.0.1", + port=8000, + ) + + mock_create_task_store.assert_called_once_with( + task_store_uri=test_uri, + ) + mock_handler.assert_called_once() + call_kwargs = mock_handler.call_args[1] + assert call_kwargs["task_store"] is custom_task_store + + def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False diff --git a/tests/unittests/cli/test_service_registry.py b/tests/unittests/cli/test_service_registry.py index dd33e00641..58dd1a88dc 100644 --- a/tests/unittests/cli/test_service_registry.py +++ b/tests/unittests/cli/test_service_registry.py @@ -172,14 +172,37 @@ def test_create_memory_service_memory(registry): assert isinstance(memory_service, InMemoryMemoryService) +# Task Store Tests +def test_create_task_store_memory(registry): + from a2a.server.tasks import InMemoryTaskStore + + task_store = registry.create_task_store_service("memory://") + assert isinstance(task_store, InMemoryTaskStore) + + +@patch("sqlalchemy.ext.asyncio.create_async_engine") +@patch("a2a.server.tasks.DatabaseTaskStore") +def test_create_task_store_postgresql( + mock_db_task_store, mock_create_engine, registry +): + mock_engine = mock_create_engine.return_value + registry.create_task_store_service("postgresql+asyncpg://user:pass@host/db") + mock_create_engine.assert_called_once_with( + "postgresql+asyncpg://user:pass@host/db" + ) + mock_db_task_store.assert_called_once_with(engine=mock_engine) + + # General Tests def test_unsupported_scheme(registry, mock_services): session_service = registry.create_session_service("unsupported://foo") artifact_service = registry.create_artifact_service("unsupported://foo") memory_service = registry.create_memory_service("unsupported://foo") + task_store = registry.create_task_store_service("unsupported://foo") assert session_service is None assert artifact_service is None assert memory_service is None + assert task_store is None for service in [ "vertex_session", "db_session", diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index d6f1426a81..e7953c602d 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -445,3 +445,38 @@ def _raise_permission_error(*_args, **_kwargs): ) assert isinstance(service, InMemoryArtifactService) + + +def test_create_task_store_uses_registry(monkeypatch): + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) + expected = object() + registry.create_task_store_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_task_store_from_options( + task_store_uri="postgresql+asyncpg://user:pass@host/db", + ) + + assert result is expected + registry.create_task_store_service.assert_called_once_with( + "postgresql+asyncpg://user:pass@host/db", + ) + + +def test_create_task_store_defaults_to_in_memory(): + from a2a.server.tasks import InMemoryTaskStore + + service = service_factory.create_task_store_from_options() + + assert isinstance(service, InMemoryTaskStore) + + +def test_create_task_store_raises_on_unknown_scheme(monkeypatch): + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) + registry.create_task_store_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + with pytest.raises(ValueError): + service_factory.create_task_store_from_options( + task_store_uri="unknown://foo", + )