From 95d0947e905d4d115baf2223af83d8222aa3da6b Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 18:52:38 -0400 Subject: [PATCH 1/3] security: remediate pickle RCE sinks using HMAC-SHA256 signatures --- .../migrate_from_sqlalchemy_pickle.py | 4 +- src/google/adk/sessions/schemas/v0.py | 6 +- src/google/adk/utils/serialization_utils.py | 87 +++++++++++++++++++ tests/unittests/security/test_pickle_rce.py | 75 ++++++++++++++++ 4 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 src/google/adk/utils/serialization_utils.py create mode 100644 tests/unittests/security/test_pickle_rce.py diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index a6d1ad2a78..a12132ed17 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -20,7 +20,6 @@ from datetime import timezone import json import logging -import pickle import sys from typing import Any @@ -29,6 +28,7 @@ from google.adk.sessions import _session_util from google.adk.sessions.migration import _schema_check_utils from google.adk.sessions.schemas import v1 +from google.adk.utils import serialization_utils from google.genai import types import sqlalchemy from sqlalchemy import create_engine @@ -59,7 +59,7 @@ def _row_to_event(row: dict) -> Event: if actions_val is not None: try: if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) + actions = serialization_utils.secure_loads(actions_val) else: # for spanner - it might return object directly actions = actions_val except Exception as e: diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index e4a4368c6d..41579e321c 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -30,11 +30,11 @@ from datetime import timezone import json import logging -import pickle from typing import Any from typing import Optional from google.adk.platform import uuid as platform_uuid +from google.adk.utils import serialization_utils from google.genai import types from sqlalchemy import Boolean from sqlalchemy import desc @@ -107,14 +107,14 @@ def process_bind_param(self, value, dialect): """Ensures the pickled value is a bytes object before passing it to the database dialect.""" if value is not None: if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) + return serialization_utils.secure_dumps(value) return value def process_result_value(self, value, dialect): """Ensures the raw bytes from the database are unpickled back into a Python object.""" if value is not None: if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) + return serialization_utils.secure_loads(value) return value diff --git a/src/google/adk/utils/serialization_utils.py b/src/google/adk/utils/serialization_utils.py new file mode 100644 index 0000000000..2ea03d1e86 --- /dev/null +++ b/src/google/adk/utils/serialization_utils.py @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for secure serialization and deserialization.""" + +import hashlib +import hmac +import os +import pickle +from typing import Any + + +class SecurityError(Exception): + """Raised when a security validation fails during deserialization.""" + + pass + + +def _get_secret_key() -> bytes: + """Retrieves the secret key used for HMAC signing. + + In a production environment, this should be fetched from a secure secret + manager or KMS. For this remediation, we default to an environment variable. + """ + secret = os.environ.get("ADK_SECURITY_SECRET") + if not secret: + # Fallback for demonstration/local development only. + # WARNING: This should be replaced with mandatory secret fetching in prod. + return b"default_insecure_development_secret" + return secret.encode("utf-8") + + +def secure_dumps(obj: Any) -> bytes: + """Serializes an object using pickle and appends an HMAC signature. + + Args: + obj: The Python object to serialize. + + Returns: + The signed binary blob. + """ + serialized = pickle.dumps(obj) + key = _get_secret_key() + signature = hmac.new(key, serialized, hashlib.sha256).digest() + return signature + serialized + + +def secure_loads(data: bytes) -> Any: + """Verifies the HMAC signature and deserializes a binary blob. + + Args: + data: The signed binary blob. + + Returns: + The deserialized Python object. + + Raises: + SecurityError: If the signature is invalid or missing. + """ + if len(data) < 32: + raise SecurityError("Data too short to contain a valid signature.") + + signature = data[:32] + serialized = data[32:] + + key = _get_secret_key() + expected_signature = hmac.new(key, serialized, hashlib.sha256).digest() + + if not hmac.compare_digest(signature, expected_signature): + raise SecurityError( + "Invalid signature detected during deserialization. " + "The data may have been tampered with or originated from an " + "untrusted source." + ) + + return pickle.loads(serialized) diff --git a/tests/unittests/security/test_pickle_rce.py b/tests/unittests/security/test_pickle_rce.py new file mode 100644 index 0000000000..745bb7c97f --- /dev/null +++ b/tests/unittests/security/test_pickle_rce.py @@ -0,0 +1,75 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import unittest +from unittest import mock + +from google.adk.utils import serialization_utils + + +class TestPickleRCE(unittest.TestCase): + + def test_secure_serialization_roundtrip(self): + """Verifies that an object can be securely serialized and deserialized.""" + test_obj = {"key": "value", "list": [1, 2, 3]} + signed_data = serialization_utils.secure_dumps(test_obj) + + # Verification + decoded_obj = serialization_utils.secure_loads(signed_data) + self.assertEqual(test_obj, decoded_obj) + + def test_secure_loads_rejects_unsigned_data(self): + """Verifies that legacy unsigned pickle data is rejected.""" + raw_pickle = pickle.dumps({"malicious": "payload"}) + + with self.assertRaises(serialization_utils.SecurityError) as cm: + serialization_utils.secure_loads(raw_pickle) + + self.assertIn("Invalid signature detected", str(cm.exception)) + + def test_secure_loads_rejects_tampered_data(self): + """Verifies that tampered data (wrong signature) is rejected.""" + test_obj = "safe_data" + signed_data = serialization_utils.secure_dumps(test_obj) + + # Tamper with the data (change the payload but keep the signature length) + tampered_data = signed_data[:32] + b"tampered_payload" + + with self.assertRaises(serialization_utils.SecurityError): + serialization_utils.secure_loads(tampered_data) + + def test_secure_loads_rejects_wrong_key(self): + """Verifies that data signed with a different key is rejected.""" + test_obj = "secret_data" + + with mock.patch.dict(os.environ, {"ADK_SECURITY_SECRET": "key_one"}): + signed_data = serialization_utils.secure_dumps(test_obj) + + with mock.patch.dict(os.environ, {"ADK_SECURITY_SECRET": "key_two"}): + with self.assertRaises(serialization_utils.SecurityError): + serialization_utils.secure_loads(signed_data) + + def test_secure_loads_too_short(self): + """Verifies that data shorter than the HMAC signature is rejected.""" + with self.assertRaises(serialization_utils.SecurityError) as cm: + serialization_utils.secure_loads(b"short") + self.assertEqual( + str(cm.exception), "Data too short to contain a valid signature." + ) + + +if __name__ == "__main__": + unittest.main() From 0f0957b728f1edb49377979fb862aa4830c3b3e0 Mon Sep 17 00:00:00 2001 From: Joshua Provoste <8358462+JoshuaProvoste@users.noreply.github.com> Date: Tue, 14 Apr 2026 19:36:56 -0400 Subject: [PATCH 2/3] security: finalize RCE remediation with CLI safety flags and JSON refactoring --- local.db | Bin 0 -> 49152 bytes src/google/adk/cli/cli_tools_click.py | 24 ++++- .../sessions/migration/migration_runner.py | 67 +++++++++++++- src/google/adk/sessions/schemas/shared.py | 37 ++++++++ src/google/adk/sessions/schemas/v0.py | 31 +------ .../unittests/security/test_migration_cli.py | 86 ++++++++++++++++++ tests_output.txt | Bin 0 -> 17320 bytes 7 files changed, 212 insertions(+), 33 deletions(-) create mode 100644 local.db create mode 100644 tests/unittests/security/test_migration_cli.py create mode 100644 tests_output.txt diff --git a/local.db b/local.db new file mode 100644 index 0000000000000000000000000000000000000000..d309c2f8d95b6a237c0db1561f4a4abc27fcf10d GIT binary patch literal 49152 zcmeI*yKmb@90zbRwk1n`C~1Kb9t@x5p(bj(qTwgxP@$>OZWZ@nL2f8|9}hy+BIwQ2Na!)PQ4>dJz&uY&=hf_=>#={x$toEg&l5u6t{dp$!mok;u zR%ZWB{V?<6>@U;5&3rxe)AaX=->1G4i!ng}0uX=z1Rwx`ODNEJZz`@8=hV(U$FbPeyS(p+T5J_?JB2V)mTtT@R&SnKqdd_j@hj~nQ&}! zH(xG%mM>=#n(U3-U~Bqf z6_S^z7{~#IDKAuw*e+YR`YvS*6i?(>f zpAW~A5p>>E=V3gqEiI`h$9~5KLyS|-yuRe*Ko|FAkma}2m+}RD+yVI3p+IMr>>jc; ztY-$g;anH!N&UuTOe-!$QS=338GSQ}&e}v=TU=C6-uF>t&0<>+CcRW_fTXwl%|!g# z-d4aO&g-$5c55*z&R***qnD1KygL@x=I7PZ=YEuD8lbxcor)yJ$(tu|}Q z!_h`k#HFVGRKyJv1Rwwb2tWV=5P$##AOHafKmYPYfOa2+se%P*N{OB0vNI z0SG_<0uX=z1Rwwb2tWV=5V*_&_tf~(+`_ed-e&u>Rpat|W&W2n!tXO!#?f&$lg+JW z)>n;nk~KCnxy|*D)-wJSofDk@f32imUuGV(9Rd)500bZa0SG_<0uX=z1R(JL7Fd`N z#|;wEM-Alpe{s=2Ob~zo1Rwwb2tWV=5P$##AOHafjIe+_|BuiABiz9V4FV8=00bZa z0SG_<0uX=z1Rx*;aQ+{|00Izz00bZa0SG_<0uX=z1RyZ_0yzIa`hAQDApijgKmY;| zfB*y_009U<00RE?|7U+y#0?V!AOHafKmY;|fB*y_009U<;8F;j7;0icT<&fre_Ca2 z=5i~S%SB)MF8ZSQM`_1xH>#974cBt)^d?ECS4f(+_~S-X&d~!=yT7O^%XnhhQQG5_ zGs|W*PUZZuT%*aIeOl#%OPlOzi+P;cmaNgLHr##Zz^_BD)63n zQgksI2>}Q|00Izz00bZa0SG_<0uX>eS3r-+;~ bool: + r"""Checks if a database URL points to a trusted local source. + + Trusted sources include: + - Localhost (127.0.0.1, ::1, 'localhost') + - Private network addresses (if explicitly allowed in future, but blocked now for safety) + - Local file paths (sqlite:///path/to/db) + + Untrusted sources include: + - External IPs. + - Remote hostnames. + - Windows UNC paths (\\host\share). + """ + try: + parsed = urlparse(db_url) + host = parsed.hostname + + # SQLite local paths (sqlite:///path) have no hostname in urlparse usually + if not host: + # Check for Windows UNC paths in the path component + # sqlite:///\\host\share -> path starts with /\\ + if parsed.path.startswith("/\\\\") or parsed.path.startswith("//"): + return False + return True + + # Check for localhost/loopback + if host.lower() in ("localhost", "127.0.0.1", "::1"): + return True + + # Check if host is an IP and if it's a loopback IP + try: + ip = ipaddress.ip_address(host) + return ip.is_loopback + except ValueError: + # Not an IP address, probably a hostname. + # If it's not 'localhost', we treat it as untrusted for safety. + return False + + except Exception: + # On parsing error, fail closed + return False + + +def upgrade( + source_db_url: str, + dest_db_url: str, + force_untrusted_source: bool = False, +): """Migrates a database from its current version to the latest version. If the source database schema is older than the latest version, this @@ -72,6 +128,15 @@ def upgrade(source_db_url: str, dest_db_url: str): "Please provide a different URL for dest_db_url." ) + if not _is_trusted_url(source_db_url) and not force_untrusted_source: + raise SecurityError( + f"Untrusted source database URL detected: {source_db_url}\n" + "Migrating from remote or untrusted sources (e.g., SMB shares or " + "external IPs) poses a SIGNIFICANT Remote Code Execution (RCE) " + "risk if the source data is malicious.\n" + "To proceed anyway, use the --force-untrusted-source flag." + ) + current_version = _schema_check_utils.get_db_schema_version(source_db_url) if current_version == LATEST_VERSION: logger.info( diff --git a/src/google/adk/sessions/schemas/shared.py b/src/google/adk/sessions/schemas/shared.py index 25d4ea9e95..63bcf42648 100644 --- a/src/google/adk/sessions/schemas/shared.py +++ b/src/google/adk/sessions/schemas/shared.py @@ -22,6 +22,8 @@ from sqlalchemy.types import DateTime from sqlalchemy.types import TypeDecorator +from google.adk.utils import serialization_utils + DEFAULT_MAX_KEY_LENGTH = 128 DEFAULT_MAX_VARCHAR_LENGTH = 256 @@ -55,6 +57,41 @@ def process_result_value(self, value, dialect: Dialect): return value +class JsonEncodedType(DynamicJSON): + """A JSON-encoded type with hybrid support for secure legacy pickles. + + New data is always stored as JSON. When reading, it first attempts to + decode JSON. If that fails and the value is binary, it attempts to + deserialize using serialization_utils.secure_loads (HMAC-verified). + """ + + def process_result_value(self, value, dialect: Dialect): + if value is None: + return None + + # Try JSON first (for new data or PostgreSQL JSONB) + if dialect.name == "postgresql": + return value + + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError: + # If it's a string that's not JSON, it might be a corrupted entry + # or an unexpected format. Logic continues to check for binary. + pass + + # If JSON failed, check if it's binary legacy data (HMAC signed) + if isinstance(value, bytes): + try: + return serialization_utils.secure_loads(value) + except serialization_utils.SecurityError: + # If both JSON and secure_loads fail, re-raise or handle as Error + raise + + return super().process_result_value(value, dialect) + + class PreciseTimestamp(TypeDecorator): """Represents a timestamp precise to the microsecond.""" diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 41579e321c..521a2efb57 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -57,9 +57,9 @@ from ...events.event import Event from ...events.event_actions import EventActions from ..session import Session -from .shared import DEFAULT_MAX_KEY_LENGTH from .shared import DEFAULT_MAX_VARCHAR_LENGTH from .shared import DynamicJSON +from .shared import JsonEncodedType from .shared import PreciseTimestamp logger = logging.getLogger("google_adk." + __name__) @@ -89,33 +89,6 @@ def _truncate_str(value: Optional[str], max_length: int) -> Optional[str]: return value -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return serialization_utils.secure_dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return serialization_utils.secure_loads(value) - return value class Base(DeclarativeBase): @@ -234,7 +207,7 @@ class StorageEvent(Base): invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) + actions: Mapped[MutableDict[str, Any]] = mapped_column(JsonEncodedType) long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( Text, nullable=True ) diff --git a/tests/unittests/security/test_migration_cli.py b/tests/unittests/security/test_migration_cli.py new file mode 100644 index 0000000000..cc9c6b345f --- /dev/null +++ b/tests/unittests/security/test_migration_cli.py @@ -0,0 +1,86 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock +from click.testing import CliRunner +from google.adk.cli.cli_tools_click import main + +class TestMigrationCLISecurity(unittest.TestCase): + + def setUp(self): + self.runner = CliRunner() + + def test_migrate_session_blocks_remote_ip(self): + """Verifies that the CLI blocks migration from a remote IP source.""" + # Using an external IP address + untrusted_url = "sqlite://1.2.3.4/malicious.db" + + result = self.runner.invoke(main, [ + "migrate", "session", + "--source_db_url", untrusted_url, + "--dest_db_url", "sqlite:///local.db" + ]) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Untrusted source database URL detected", result.output) + self.assertIn("--force-untrusted-source", result.output) + + def test_migrate_session_blocks_unc_path(self): + """Verifies that the CLI blocks migration from a Windows UNC path.""" + # Using a Windows UNC path (Samba style) + untrusted_url = "sqlite:///\\\\192.168.1.90\\lab_share\\malicious.db" + + result = self.runner.invoke(main, [ + "migrate", "session", + "--source_db_url", untrusted_url, + "--dest_db_url", "sqlite:///local.db" + ]) + + self.assertNotEqual(result.exit_code, 0) + self.assertIn("Untrusted source database URL detected", result.output) + + @mock.patch("google.adk.sessions.migration.migration_runner.upgrade") + def test_migrate_session_allows_localhost(self, mock_upgrade): + """Verifies that localhost URLs are trusted by default.""" + trusted_url = "sqlite:///local.db" + + result = self.runner.invoke(main, [ + "migrate", "session", + "--source_db_url", trusted_url, + "--dest_db_url", "sqlite:///dest.db" + ]) + + # It should call upgrade (we mock it call because we don't want to run real migration) + mock_upgrade.assert_called_once() + self.assertEqual(result.exit_code, 0) + + @mock.patch("google.adk.sessions.migration.migration_runner.upgrade") + def test_migrate_session_force_flag_works(self, mock_upgrade): + """Verifies that the --force-untrusted-source flag bypasses the block.""" + untrusted_url = "sqlite://8.8.8.8/remote.db" + + result = self.runner.invoke(main, [ + "migrate", "session", + "--source_db_url", untrusted_url, + "--dest_db_url", "sqlite:///local.db", + "--force-untrusted-source" + ]) + + # Should call upgrade with force_untrusted_source=True + mock_upgrade.assert_called_with(untrusted_url, "sqlite:///local.db", True) + self.assertEqual(result.exit_code, 0) + +if __name__ == "__main__": + unittest.main() diff --git a/tests_output.txt b/tests_output.txt new file mode 100644 index 0000000000000000000000000000000000000000..b7b8cf47ad4ecbf1d0f630bf2b7d0c148627acfe GIT binary patch literal 17320 zcmeI4{Zku77{~W#XZ#;-_=R*Vk+#qlw9X)cGDES_Vs(mUXp$mB2rmNAA9#$DUe3<4e`( zi(;rgyqMuD&E*GCzM*+-xu?4CxEF3sGyh3E$WoAK&M`$h>hVM+b+zAid+uvpYwk;Z z-*63m+fz^9sFx#Akf>f)^0F-x3er=zFLW-?4xlJsZeUdVbDFs(YkwEw4MZ3ASPK8m|GX z>U$tcOe61T#8=X~ySndcoD)BKTe|eSINlZ&1J%Sn^Yhx2JUY70-FNzS;Polcx}pC2 zYSq_lTk6^71s{&QH(1xR6I}=Dx23T{Z5tzugia#&W9fC)CK@V_^}6Qnc>01~HPrW! zIBxdYR(mw)Oqb22;py#&igzk&i1t&B!yNEMrkPMt6UC;fXR?I6_EMH=CVG9YlF%Yc zLm0Q|`BaxhS-(#sI<3ySWV|IA1&try?^Whzy!u`vpamu~deDfRCK|2QGN3hu2;M=oWKm>~Q#Q0xS*6=g!O6TBN@b$}X z*=shM@?3N8YQ6`uo69~QTYIZ#_*^)J?9n%DJn=g8tN4XZV{?g{7iDMq43;n_e>*2n z66^sfAUR|K$B`Od{=oN@Cf+)SQYId^|A*3W!p>?pyuw zt60B_^pL279v!PsIGB2uhN$X^4*L6V+SsJIO|~(t4OyGLV?8Oik^;RXFTke~N0DL7 z&qN;MnaZ-VOD%6xF@H-MaUJ>L{hE?vPoqQoha3_wbc*R9nc{6gFtQv}lLC z$SUyX7C*7*)RJR4VtfDY%BG5b*N=Wk@yRW^peK#b% zdPA;aD=;#aP=lQNnR4ZT`aBT&a9?HH>fyES`85{VxaE+RbgYQQcMwSM3pyCq*r zrW<0YtsPqz6hh@YjbWJ!p0FcdUiYh&6v+vzG%~Y}qA}TiZ5Wk9k9RNn40F2vtf2rt zx3zj0Xk>UXdcKdM*)|l_O-9hlFJv1q1!QPQ7Mx-ALjN@%v26Z_zO_Zn&?g7FF%fWV|s9^`ryS4*@|Jy?5U~AMyv)bODjrL%p zc{BP+psPInimA65?TZiKAclzZz1ZWgxW$N*aX(ou1e)Px*6%kIP1xxG1!8pr?o`)! zhqAwxry=Eb?Ls!MJa3U#M-kl+V%9r&MOKq-PZj*PI-@0!jENwlWpI@zfd{eNK3AM4 z7x_b!5(Q%p2Hh&+&=<;O@ORW9a)w+Z8-0s-ID0A{W^{rTJ~SdrL(~n?W>0MPw2;B2 zabX-S4WHcwlncyww$!c2Loa7&V{EuGTS|y=w-TST(Vhs>%>iA_h*qGT9T6l0tXarsm{BLgX$y2~A?L&HY^N1_CI2u_ zYAxxBzYhiRNwLlf6=qjGptfj8ir>y26-K9S3l;<&&{N?-6&_Ugd3|^6ntWVdjvnFT z<9y}j@xTfXx+?mbMRLOGh)bmZ+=eUeQRCv(1$a1x3Me(S@ zgYpr&T#i(DP#$ZY2Ir{opq|$7lX0_qLd_~X2wa4H^zu(Xl=+TOK4-)mS9nm_SQ1YB5p?4pzONo1#Ipdp{S+{a;X^S)<%&Eod-#cYZIQGW?pIAhPom`>Ja# z#-GzDTKhiGDX(=&b(MGG>5cr1D}65@yk*LMILLNQ>+wC^v+t^A-~S$WE&!w* znw$Q!HC)}tZ)?=Gqrx74&wtw9?wTlqCa5UF)8|X2XiBlH$^vK3}r$Pna+-`^LUFJ2N+T z$^<=+ck=3@q^pyV>8UN#Fl2wAQg)2&4BX^h8+(s2-%}Bf^B4|$Rd+lkcJ>A;%Z}IA z=?{9!i1qNWI?IvW^|@-CpL~`Bsa0{loU_{@(%wXz2jea$&gbn7ji{nL9=(e4vpYpv zMR|Lpa+)ZA+x+vcvL3RsCH Date: Mon, 20 Apr 2026 18:31:10 -0400 Subject: [PATCH 3/3] fix(cli): resolve CI failures, style issues, and NameError in schemas --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/cli/cli_tools_click.py | 13 +-- .../sessions/migration/migration_runner.py | 2 +- src/google/adk/sessions/schemas/shared.py | 3 +- src/google/adk/sessions/schemas/v0.py | 3 +- src/google/adk/utils/serialization_utils.py | 2 + tests/unittests/cli/test_fast_api.py | 8 +- .../unittests/security/test_migration_cli.py | 85 ++++++++++++------- 9 files changed, 72 insertions(+), 46 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c41239dc3e..7103172060 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -25,6 +25,8 @@ from pathlib import Path import tempfile import textwrap +from typing import Any +from typing import Literal from typing import Optional import click @@ -2094,9 +2096,8 @@ def cli_migrate_session( logs.setup_adk_logger(getattr(logging, log_level.upper())) try: from ..sessions.migration import migration_runner - migration_runner.upgrade( - source_db_url, dest_db_url, force_untrusted_source - ) + + migration_runner.upgrade(source_db_url, dest_db_url, force_untrusted_source) click.secho("Migration check and upgrade process finished.", fg="green") except Exception as e: click.secho(f"Migration failed: {e}", fg="red", err=True) @@ -2482,14 +2483,14 @@ def cli_deploy_gke( otel_to_cloud: bool, with_ui: bool, adk_version: str, - service_type: str, - log_level: Optional[str] = None, + service_type: Literal["ClusterIP", "NodePort", "LoadBalancer"], + log_level: str = "INFO", session_service_uri: Optional[str] = None, artifact_service_uri: Optional[str] = None, memory_service_uri: Optional[str] = None, use_local_storage: bool = False, trigger_sources: Optional[str] = None, -): +) -> None: """Deploys an agent to GKE. AGENT: The path to the agent source code folder. diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py index 14afa51a78..16b70318be 100644 --- a/src/google/adk/sessions/migration/migration_runner.py +++ b/src/google/adk/sessions/migration/migration_runner.py @@ -41,7 +41,7 @@ } # The most recent schema version. The migration process stops once this version # is reached. -#Reached. +# Reached. LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION diff --git a/src/google/adk/sessions/schemas/shared.py b/src/google/adk/sessions/schemas/shared.py index 63bcf42648..0c00a2f9a7 100644 --- a/src/google/adk/sessions/schemas/shared.py +++ b/src/google/adk/sessions/schemas/shared.py @@ -15,6 +15,7 @@ import json +from google.adk.utils import serialization_utils from sqlalchemy import Dialect from sqlalchemy import Text from sqlalchemy.dialects import mysql @@ -22,8 +23,6 @@ from sqlalchemy.types import DateTime from sqlalchemy.types import TypeDecorator -from google.adk.utils import serialization_utils - DEFAULT_MAX_KEY_LENGTH = 128 DEFAULT_MAX_VARCHAR_LENGTH = 256 diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 521a2efb57..58eb0ccedd 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -57,6 +57,7 @@ from ...events.event import Event from ...events.event_actions import EventActions from ..session import Session +from .shared import DEFAULT_MAX_KEY_LENGTH from .shared import DEFAULT_MAX_VARCHAR_LENGTH from .shared import DynamicJSON from .shared import JsonEncodedType @@ -89,8 +90,6 @@ def _truncate_str(value: Optional[str], max_length: int) -> Optional[str]: return value - - class Base(DeclarativeBase): """Base class for v0 database tables.""" diff --git a/src/google/adk/utils/serialization_utils.py b/src/google/adk/utils/serialization_utils.py index 2ea03d1e86..f3d19609a0 100644 --- a/src/google/adk/utils/serialization_utils.py +++ b/src/google/adk/utils/serialization_utils.py @@ -14,6 +14,8 @@ """Utilities for secure serialization and deserialization.""" +from __future__ import annotations + import hashlib import hmac import os diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 3e63f31222..2a169fc72e 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1983,7 +1983,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): assert not (app_root / "tmp").exists() response = builder_test_client.get("/builder/app/app?tmp=true") assert response.status_code == 200 - assert response.text == "name: app\n" + assert response.text.replace("\r\n", "\n") == "name: app\n" tmp_agent_root = app_root / "tmp" / "app" assert (tmp_agent_root / "root_agent.yaml").is_file() @@ -1993,7 +1993,7 @@ def test_builder_get_tmp_true_recreates_tmp(builder_test_client, tmp_path): "/builder/app/app?tmp=true&file_path=nested/nested.yaml" ) assert response.status_code == 200 - assert response.text == "nested: true\n" + assert response.text.replace("\r\n", "\n") == "nested: true\n" def test_builder_get_tmp_true_missing_app_returns_empty( @@ -2140,11 +2140,11 @@ def test_builder_get_allows_yaml_file_paths(builder_test_client, tmp_path): "/builder/app/app?file_path=sub_agent.yaml" ) assert response.status_code == 200 - assert response.text == "name: sub\n" + assert response.text.replace("\r\n", "\n") == "name: sub\n" response = builder_test_client.get("/builder/app/app?file_path=tool.yml") assert response.status_code == 200 - assert response.text == "name: tool\n" + assert response.text.replace("\r\n", "\n") == "name: tool\n" def test_builder_endpoints_not_registered_without_web( diff --git a/tests/unittests/security/test_migration_cli.py b/tests/unittests/security/test_migration_cli.py index cc9c6b345f..cc93a57197 100644 --- a/tests/unittests/security/test_migration_cli.py +++ b/tests/unittests/security/test_migration_cli.py @@ -14,9 +14,11 @@ import unittest from unittest import mock + from click.testing import CliRunner from google.adk.cli.cli_tools_click import main + class TestMigrationCLISecurity(unittest.TestCase): def setUp(self): @@ -26,13 +28,19 @@ def test_migrate_session_blocks_remote_ip(self): """Verifies that the CLI blocks migration from a remote IP source.""" # Using an external IP address untrusted_url = "sqlite://1.2.3.4/malicious.db" - - result = self.runner.invoke(main, [ - "migrate", "session", - "--source_db_url", untrusted_url, - "--dest_db_url", "sqlite:///local.db" - ]) - + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + ], + ) + self.assertNotEqual(result.exit_code, 0) self.assertIn("Untrusted source database URL detected", result.output) self.assertIn("--force-untrusted-source", result.output) @@ -41,13 +49,19 @@ def test_migrate_session_blocks_unc_path(self): """Verifies that the CLI blocks migration from a Windows UNC path.""" # Using a Windows UNC path (Samba style) untrusted_url = "sqlite:///\\\\192.168.1.90\\lab_share\\malicious.db" - - result = self.runner.invoke(main, [ - "migrate", "session", - "--source_db_url", untrusted_url, - "--dest_db_url", "sqlite:///local.db" - ]) - + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + ], + ) + self.assertNotEqual(result.exit_code, 0) self.assertIn("Untrusted source database URL detected", result.output) @@ -55,13 +69,19 @@ def test_migrate_session_blocks_unc_path(self): def test_migrate_session_allows_localhost(self, mock_upgrade): """Verifies that localhost URLs are trusted by default.""" trusted_url = "sqlite:///local.db" - - result = self.runner.invoke(main, [ - "migrate", "session", - "--source_db_url", trusted_url, - "--dest_db_url", "sqlite:///dest.db" - ]) - + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + trusted_url, + "--dest_db_url", + "sqlite:///dest.db", + ], + ) + # It should call upgrade (we mock it call because we don't want to run real migration) mock_upgrade.assert_called_once() self.assertEqual(result.exit_code, 0) @@ -70,17 +90,24 @@ def test_migrate_session_allows_localhost(self, mock_upgrade): def test_migrate_session_force_flag_works(self, mock_upgrade): """Verifies that the --force-untrusted-source flag bypasses the block.""" untrusted_url = "sqlite://8.8.8.8/remote.db" - - result = self.runner.invoke(main, [ - "migrate", "session", - "--source_db_url", untrusted_url, - "--dest_db_url", "sqlite:///local.db", - "--force-untrusted-source" - ]) - + + result = self.runner.invoke( + main, + [ + "migrate", + "session", + "--source_db_url", + untrusted_url, + "--dest_db_url", + "sqlite:///local.db", + "--force-untrusted-source", + ], + ) + # Should call upgrade with force_untrusted_source=True mock_upgrade.assert_called_with(untrusted_url, "sqlite:///local.db", True) self.assertEqual(result.exit_code, 0) + if __name__ == "__main__": unittest.main()