diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/local.db b/local.db new file mode 100644 index 0000000000..d309c2f8d9 Binary files /dev/null and b/local.db differ diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 07ccc15892..7103172060 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -25,6 +25,8 @@ from pathlib import Path import tempfile import textwrap +from typing import Any +from typing import Literal from typing import Optional import click @@ -2071,18 +2073,35 @@ def migrate(): default="INFO", help="Optional. Set the logging level", ) +@click.option( + "--force-untrusted-source", + is_flag=True, + default=False, + help=( + "Optional. Force migration from untrusted or remote database sources " + "(e.g., SMB shares or external IPs). Use with CAUTION as it poses RCE " + "risks if the source is malicious." 
+ ), +) +@click.pass_context def cli_migrate_session( - *, source_db_url: str, dest_db_url: str, log_level: str + ctx, + *, + source_db_url: str, + dest_db_url: str, + log_level: str, + force_untrusted_source: bool, ): """Migrates a session database to the latest schema version.""" logs.setup_adk_logger(getattr(logging, log_level.upper())) try: from ..sessions.migration import migration_runner - migration_runner.upgrade(source_db_url, dest_db_url) + migration_runner.upgrade(source_db_url, dest_db_url, force_untrusted_source) click.secho("Migration check and upgrade process finished.", fg="green") except Exception as e: click.secho(f"Migration failed: {e}", fg="red", err=True) + ctx.exit(1) @deploy.command("agent_engine") @@ -2464,14 +2483,14 @@ def cli_deploy_gke( otel_to_cloud: bool, with_ui: bool, adk_version: str, - service_type: str, - log_level: Optional[str] = None, + service_type: Literal["ClusterIP", "NodePort", "LoadBalancer"], + log_level: str = "INFO", session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = False, trigger_sources: Optional[str] = None, -): +) -> None: """Deploys an agent to GKE. AGENT: The path to the agent source code folder. 
diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index a6d1ad2a78..a12132ed17 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -20,7 +20,6 @@ from datetime import timezone import json import logging -import pickle import sys from typing import Any @@ -29,6 +28,7 @@ from google.adk.sessions import _session_util from google.adk.sessions.migration import _schema_check_utils from google.adk.sessions.schemas import v1 +from google.adk.utils import serialization_utils from google.genai import types import sqlalchemy from sqlalchemy import create_engine @@ -59,7 +59,7 @@ def _row_to_event(row: dict) -> Event: if actions_val is not None: try: if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) + actions = serialization_utils.secure_loads(actions_val) else: # for spanner - it might return object directly actions = actions_val except Exception as e: diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py index c46bab2179..16b70318be 100644 --- a/src/google/adk/sessions/migration/migration_runner.py +++ b/src/google/adk/sessions/migration/migration_runner.py @@ -16,9 +16,11 @@ from __future__ import annotations +import ipaddress import logging import os import tempfile +from urllib.parse import urlparse from google.adk.sessions.migration import _schema_check_utils from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle @@ -39,10 +41,64 @@ } # The most recent schema version. The migration process stops once this version # is reached. +# Reached. 
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION -def upgrade(source_db_url: str, dest_db_url: str): +class SecurityError(Exception): + """Raised when a security policy is violated during migration.""" + + pass + + +def _is_trusted_url(db_url: str) -> bool: + r"""Checks if a database URL points to a trusted local source. + + Trusted sources include: + - Localhost (127.0.0.1, ::1, 'localhost') + - Private network addresses (if explicitly allowed in future, but blocked now for safety) + - Local file paths (sqlite:///path/to/db) + + Untrusted sources include: + - External IPs. + - Remote hostnames. + - Windows UNC paths (\\host\share). + """ + try: + parsed = urlparse(db_url) + host = parsed.hostname + + # SQLite local paths (sqlite:///path) have no hostname in urlparse usually + if not host: + # Check for Windows UNC paths in the path component + # sqlite:///\\host\share -> path starts with /\\ + if parsed.path.startswith("/\\\\") or parsed.path.startswith("//"): + return False + return True + + # Check for localhost/loopback + if host.lower() in ("localhost", "127.0.0.1", "::1"): + return True + + # Check if host is an IP and if it's a loopback IP + try: + ip = ipaddress.ip_address(host) + return ip.is_loopback + except ValueError: + # Not an IP address, probably a hostname. + # If it's not 'localhost', we treat it as untrusted for safety. + return False + + except Exception: + # On parsing error, fail closed + return False + + +def upgrade( + source_db_url: str, + dest_db_url: str, + force_untrusted_source: bool = False, +): """Migrates a database from its current version to the latest version. If the source database schema is older than the latest version, this @@ -72,6 +128,15 @@ def upgrade(source_db_url: str, dest_db_url: str): "Please provide a different URL for dest_db_url." 
class JsonEncodedType(DynamicJSON):
  """A JSON-encoded type with hybrid support for secure legacy pickles.

  New data is always stored as JSON. When reading, it first attempts to
  decode JSON. If that fails and the value is binary, it attempts to
  deserialize using serialization_utils.secure_loads (HMAC-verified).
  """

  def process_result_value(self, value, dialect: Dialect):
    # NULL column values pass straight through.
    if value is None:
      return None

    # PostgreSQL JSON/JSONB values arrive already decoded by the driver.
    if dialect.name == "postgresql":
      return value

    # Try JSON first — the format used for all newly written rows.
    if isinstance(value, str):
      try:
        return json.loads(value)
      except json.JSONDecodeError:
        # If it's a string that's not JSON, it might be a corrupted entry
        # or an unexpected format. Fall through to the binary check below.
        pass

    # If JSON failed, check if it's binary legacy data (HMAC signed).
    if isinstance(value, bytes):
      try:
        return serialization_utils.secure_loads(value)
      except serialization_utils.SecurityError:
        # Surface tampered/unsigned legacy blobs as a hard error rather
        # than silently returning garbage. (No-op re-raise kept for
        # explicitness.)
        raise

    # Anything else defers to the base DynamicJSON handling.
    return super().process_result_value(value, dialect)
+ __name__) @@ -89,35 +90,6 @@ def _truncate_str(value: Optional[str], max_length: int) -> Optional[str]: return value -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) - return value - - class Base(DeclarativeBase): """Base class for v0 database tables.""" @@ -234,7 +206,7 @@ class StorageEvent(Base): invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) + actions: Mapped[MutableDict[str, Any]] = mapped_column(JsonEncodedType) long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( Text, nullable=True ) diff --git a/src/google/adk/utils/serialization_utils.py b/src/google/adk/utils/serialization_utils.py new file mode 100644 index 0000000000..f3d19609a0 --- /dev/null +++ b/src/google/adk/utils/serialization_utils.py @@ -0,0 +1,89 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for secure serialization and deserialization.""" + +from __future__ import annotations + +import hashlib +import hmac +import os +import pickle +from typing import Any + + +class SecurityError(Exception): + """Raised when a security validation fails during deserialization.""" + + pass + + +def _get_secret_key() -> bytes: + """Retrieves the secret key used for HMAC signing. + + In a production environment, this should be fetched from a secure secret + manager or KMS. For this remediation, we default to an environment variable. + """ + secret = os.environ.get("ADK_SECURITY_SECRET") + if not secret: + # Fallback for demonstration/local development only. + # WARNING: This should be replaced with mandatory secret fetching in prod. + return b"default_insecure_development_secret" + return secret.encode("utf-8") + + +def secure_dumps(obj: Any) -> bytes: + """Serializes an object using pickle and appends an HMAC signature. + + Args: + obj: The Python object to serialize. + + Returns: + The signed binary blob. + """ + serialized = pickle.dumps(obj) + key = _get_secret_key() + signature = hmac.new(key, serialized, hashlib.sha256).digest() + return signature + serialized + + +def secure_loads(data: bytes) -> Any: + """Verifies the HMAC signature and deserializes a binary blob. + + Args: + data: The signed binary blob. + + Returns: + The deserialized Python object. + + Raises: + SecurityError: If the signature is invalid or missing. 
+ """ + if len(data) < 32: + raise SecurityError("Data too short to contain a valid signature.") + + signature = data[:32] + serialized = data[32:] + + key = _get_secret_key() + expected_signature = hmac.new(key, serialized, hashlib.sha256).digest() + + if not hmac.compare_digest(signature, expected_signature): + raise SecurityError( + "Invalid signature detected during deserialization. " + "The data may have been tampered with or originated from an " + "untrusted source." + ) + + return pickle.loads(serialized) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 3e63f31222..2a169fc72e 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1983,7 +1983,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): assert not (app_root / "tmp").exists() response = builder_test_client.get("/builder/app/app?tmp=true") assert response.status_code == 200 - assert response.text == "name: app\n" + assert response.text.replace("\r\n", "\n") == "name: app\n" tmp_agent_root = app_root / "tmp" / "app" assert (tmp_agent_root / "root_agent.yaml").is_file() @@ -1993,7 +1993,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): "/builder/app/app?tmp=true&file_path=nested/nested.yaml" ) assert response.status_code == 200 - assert response.text == "nested: true\n" + assert response.text.replace("\r\n", "\n") == "nested: true\n" def test_builder_get_tmp_true_missing_app_returns_empty( @@ -2140,11 +2140,11 @@ def test_builder_get_allows_yaml_file_paths(builder_test_client, tmp_path): "/builder/app/app?file_path=sub_agent.yaml" ) assert response.status_code == 200 - assert response.text == "name: sub\n" + assert response.text.replace("\r\n", "\n") == "name: sub\n" response = builder_test_client.get("/builder/app/app?file_path=tool.yml") assert response.status_code == 200 - assert response.text == "name: tool\n" + assert response.text.replace("\r\n", "\n") 
== "name: tool\n" def test_builder_endpoints_not_registered_without_web( diff --git a/tests/unittests/security/test_migration_cli.py b/tests/unittests/security/test_migration_cli.py new file mode 100644 index 0000000000..cc93a57197 --- /dev/null +++ b/tests/unittests/security/test_migration_cli.py @@ -0,0 +1,113 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock + +from click.testing import CliRunner +from google.adk.cli.cli_tools_click import main + + +class TestMigrationCLISecurity(unittest.TestCase): + + def setUp(self): + self.runner = CliRunner() + + def test_migrate_session_blocks_remote_ip(self): + """Verifies that the CLI blocks migration from a remote IP source.""" + # Using an external IP address + untrusted_url = "sqlite://1.2.3.4/malicious.db" + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + ], + ) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Untrusted source database URL detected", result.output) + self.assertIn("--force-untrusted-source", result.output) + + def test_migrate_session_blocks_unc_path(self): + """Verifies that the CLI blocks migration from a Windows UNC path.""" + # Using a Windows UNC path (Samba style) + untrusted_url = "sqlite:///\\\\192.168.1.90\\lab_share\\malicious.db" + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + 
"--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + ], + ) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Untrusted source database URL detected", result.output) + + @mock.patch("google.adk.sessions.migration.migration_runner.upgrade") + def test_migrate_session_allows_localhost(self, mock_upgrade): + """Verifies that localhost URLs are trusted by default.""" + trusted_url = "sqlite:///local.db" + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + trusted_url, + "--dest_db_url", + "sqlite:///dest.db", + ], + ) + + # It should call upgrade (we mock it call because we don't want to run real migration) + mock_upgrade.assert_called_once() + self.assertEqual(result.exit_code, 0) + + @mock.patch("google.adk.sessions.migration.migration_runner.upgrade") + def test_migrate_session_force_flag_works(self, mock_upgrade): + """Verifies that the --force-untrusted-source flag bypasses the block.""" + untrusted_url = "sqlite://8.8.8.8/remote.db" + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + "--force-untrusted-source", + ], + ) + + # Should call upgrade with force_untrusted_source=True + mock_upgrade.assert_called_with(untrusted_url, "sqlite:///local.db", True) + self.assertEqual(result.exit_code, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unittests/security/test_pickle_rce.py b/tests/unittests/security/test_pickle_rce.py new file mode 100644 index 0000000000..745bb7c97f --- /dev/null +++ b/tests/unittests/security/test_pickle_rce.py @@ -0,0 +1,75 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TestPickleRCE(unittest.TestCase):
  """Unit tests for the HMAC-signed pickle serialization helpers."""

  def test_secure_serialization_roundtrip(self):
    """Verifies that an object can be securely serialized and deserialized."""
    test_obj = {"key": "value", "list": [1, 2, 3]}
    signed_data = serialization_utils.secure_dumps(test_obj)

    # Verification: the round trip must reproduce the original object.
    decoded_obj = serialization_utils.secure_loads(signed_data)
    self.assertEqual(test_obj, decoded_obj)

  def test_secure_loads_rejects_unsigned_data(self):
    """Verifies that legacy unsigned pickle data is rejected."""
    # A bare pickle has no signature prefix, so verification must fail.
    raw_pickle = pickle.dumps({"malicious": "payload"})

    with self.assertRaises(serialization_utils.SecurityError) as cm:
      serialization_utils.secure_loads(raw_pickle)

    self.assertIn("Invalid signature detected", str(cm.exception))

  def test_secure_loads_rejects_tampered_data(self):
    """Verifies that tampered data (wrong signature) is rejected."""
    test_obj = "safe_data"
    signed_data = serialization_utils.secure_dumps(test_obj)

    # Tamper with the data (change the payload but keep the signature length)
    tampered_data = signed_data[:32] + b"tampered_payload"

    with self.assertRaises(serialization_utils.SecurityError):
      serialization_utils.secure_loads(tampered_data)

  def test_secure_loads_rejects_wrong_key(self):
    """Verifies that data signed with a different key is rejected."""
    test_obj = "secret_data"

    # Sign under one key, then attempt to verify under another.
    with mock.patch.dict(os.environ, {"ADK_SECURITY_SECRET": "key_one"}):
      signed_data = serialization_utils.secure_dumps(test_obj)

    with mock.patch.dict(os.environ, {"ADK_SECURITY_SECRET": "key_two"}):
      with self.assertRaises(serialization_utils.SecurityError):
        serialization_utils.secure_loads(signed_data)

  def test_secure_loads_too_short(self):
    """Verifies that data shorter than the HMAC signature is rejected."""
    with self.assertRaises(serialization_utils.SecurityError) as cm:
      serialization_utils.secure_loads(b"short")
    self.assertEqual(
        str(cm.exception), "Data too short to contain a valid signature."
    )