Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
Binary file added local.db
Binary file not shown.
29 changes: 24 additions & 5 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from pathlib import Path
import tempfile
import textwrap
from typing import Any
from typing import Literal
from typing import Optional

import click
Expand Down Expand Up @@ -2071,18 +2073,35 @@ def migrate():
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--force-untrusted-source",
is_flag=True,
default=False,
help=(
"Optional. Force migration from untrusted or remote database sources "
"(e.g., SMB shares or external IPs). Use with CAUTION as it poses RCE "
"risks if the source is malicious."
),
)
@click.pass_context
def cli_migrate_session(
    ctx,
    *,
    source_db_url: str,
    dest_db_url: str,
    log_level: str,
    force_untrusted_source: bool,
):
  """Migrates a session database to the latest schema version.

  Args:
    ctx: Click context, used to exit with a non-zero status on failure.
    source_db_url: URL of the database to migrate from.
    dest_db_url: URL of the database to migrate into.
    log_level: Logging level name (e.g. "INFO"); upper-cased before lookup.
    force_untrusted_source: When True, proceed even if the source URL is
      remote/untrusted despite the RCE risk flagged by the migration runner.
  """
  logs.setup_adk_logger(getattr(logging, log_level.upper()))
  try:
    # Imported lazily so the import cost (and any import failure) is only
    # paid when this subcommand is actually invoked.
    from ..sessions.migration import migration_runner

    migration_runner.upgrade(source_db_url, dest_db_url, force_untrusted_source)
    click.secho("Migration check and upgrade process finished.", fg="green")
  except Exception as e:
    # Broad catch so any migration failure (including SecurityError from the
    # runner) is reported to the user and surfaced as exit code 1.
    click.secho(f"Migration failed: {e}", fg="red", err=True)
    ctx.exit(1)


@deploy.command("agent_engine")
Expand Down Expand Up @@ -2464,14 +2483,14 @@ def cli_deploy_gke(
otel_to_cloud: bool,
with_ui: bool,
adk_version: str,
service_type: str,
log_level: Optional[str] = None,
service_type: Literal["ClusterIP", "NodePort", "LoadBalancer"],
log_level: str = "INFO",
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
memory_service_uri: Optional[str] = None,
use_local_storage: bool = False,
trigger_sources: Optional[str] = None,
):
) -> None:
"""Deploys an agent to GKE.

AGENT: The path to the agent source code folder.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from datetime import timezone
import json
import logging
import pickle
import sys
from typing import Any

Expand All @@ -29,6 +28,7 @@
from google.adk.sessions import _session_util
from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.schemas import v1
from google.adk.utils import serialization_utils
from google.genai import types
import sqlalchemy
from sqlalchemy import create_engine
Expand Down Expand Up @@ -59,7 +59,7 @@ def _row_to_event(row: dict) -> Event:
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = serialization_utils.secure_loads(actions_val)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand Down
67 changes: 66 additions & 1 deletion src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

from __future__ import annotations

import ipaddress
import logging
import os
import tempfile
from urllib.parse import urlparse

from google.adk.sessions.migration import _schema_check_utils
from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle
Expand All @@ -39,10 +41,64 @@
}
# The most recent schema version. The migration process stops once this version
# is reached.
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
class SecurityError(Exception):
  """Raised when a security policy is violated during migration.

  Raised by ``upgrade`` when the source database URL is judged untrusted
  and migration has not been explicitly forced.
  """


def _is_trusted_url(db_url: str) -> bool:
r"""Checks if a database URL points to a trusted local source.

Trusted sources include:
- Localhost (127.0.0.1, ::1, 'localhost')
- Private network addresses (if explicitly allowed in future, but blocked now for safety)
- Local file paths (sqlite:///path/to/db)

Untrusted sources include:
- External IPs.
- Remote hostnames.
- Windows UNC paths (\\host\share).
"""
try:
parsed = urlparse(db_url)
host = parsed.hostname

# SQLite local paths (sqlite:///path) have no hostname in urlparse usually
if not host:
# Check for Windows UNC paths in the path component
# sqlite:///\\host\share -> path starts with /\\
if parsed.path.startswith("/\\\\") or parsed.path.startswith("//"):
return False
return True

# Check for localhost/loopback
if host.lower() in ("localhost", "127.0.0.1", "::1"):
return True

# Check if host is an IP and if it's a loopback IP
try:
ip = ipaddress.ip_address(host)
return ip.is_loopback
except ValueError:
# Not an IP address, probably a hostname.
# If it's not 'localhost', we treat it as untrusted for safety.
return False

except Exception:
# On parsing error, fail closed
return False


def upgrade(
source_db_url: str,
dest_db_url: str,
force_untrusted_source: bool = False,
):
"""Migrates a database from its current version to the latest version.

If the source database schema is older than the latest version, this
Expand Down Expand Up @@ -72,6 +128,15 @@ def upgrade(source_db_url: str, dest_db_url: str):
"Please provide a different URL for dest_db_url."
)

if not _is_trusted_url(source_db_url) and not force_untrusted_source:
raise SecurityError(
f"Untrusted source database URL detected: {source_db_url}\n"
"Migrating from remote or untrusted sources (e.g., SMB shares or "
"external IPs) poses a SIGNIFICANT Remote Code Execution (RCE) "
"risk if the source data is malicious.\n"
"To proceed anyway, use the --force-untrusted-source flag."
)

current_version = _schema_check_utils.get_db_schema_version(source_db_url)
if current_version == LATEST_VERSION:
logger.info(
Expand Down
36 changes: 36 additions & 0 deletions src/google/adk/sessions/schemas/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import json

from google.adk.utils import serialization_utils
from sqlalchemy import Dialect
from sqlalchemy import Text
from sqlalchemy.dialects import mysql
Expand Down Expand Up @@ -55,6 +56,41 @@ def process_result_value(self, value, dialect: Dialect):
return value


class JsonEncodedType(DynamicJSON):
  """A JSON-encoded type with hybrid support for secure legacy pickles.

  New data is always stored as JSON. When reading, this first attempts to
  decode JSON. If that fails and the value is binary, it attempts to
  deserialize using serialization_utils.secure_loads (HMAC-verified), so
  legacy pickled rows remain readable without trusting arbitrary pickle
  payloads.
  """

  def process_result_value(self, value, dialect: Dialect):
    """Decodes a raw database value into a Python object.

    Args:
      value: Raw driver value; may be None, str, bytes, or (on PostgreSQL)
        an already-decoded object.
      dialect: The SQLAlchemy dialect in use.

    Returns:
      The decoded Python object, or None if the stored value was NULL.

    Raises:
      serialization_utils.SecurityError: If a binary legacy value fails
        HMAC verification in secure_loads.
    """
    if value is None:
      return None

    # PostgreSQL JSONB values are decoded by the driver already.
    if dialect.name == "postgresql":
      return value

    # Try JSON first: the format used for all newly written rows.
    if isinstance(value, str):
      try:
        return json.loads(value)
      except json.JSONDecodeError:
        # Not JSON (corrupted or unexpected format); fall through to the
        # generic base-class handling below.
        pass

    # Binary values are legacy pickled data; secure_loads verifies an HMAC
    # before unpickling and raises SecurityError on verification failure.
    # (The previous `try/except SecurityError: raise` wrapper was a no-op
    # and has been removed; the exception propagates unchanged.)
    if isinstance(value, bytes):
      return serialization_utils.secure_loads(value)

    return super().process_result_value(value, dialect)


class PreciseTimestamp(TypeDecorator):
"""Represents a timestamp precise to the microsecond."""

Expand Down
34 changes: 3 additions & 31 deletions src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
from datetime import timezone
import json
import logging
import pickle
from typing import Any
from typing import Optional

from google.adk.platform import uuid as platform_uuid
from google.adk.utils import serialization_utils
from google.genai import types
from sqlalchemy import Boolean
from sqlalchemy import desc
Expand All @@ -60,6 +60,7 @@
from .shared import DEFAULT_MAX_KEY_LENGTH
from .shared import DEFAULT_MAX_VARCHAR_LENGTH
from .shared import DynamicJSON
from .shared import JsonEncodedType
from .shared import PreciseTimestamp

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -89,35 +90,6 @@ def _truncate_str(value: Optional[str], max_length: int) -> Optional[str]:
return value


class DynamicPickleType(TypeDecorator):
  """Represents a type that can be pickled.

  Delegates to SQLAlchemy's PickleType by default, switching to
  dialect-specific binary handling for MySQL (LONGBLOB) and Spanner
  (SpannerPickleType).

  NOTE(review): pickle.loads on database contents executes arbitrary code
  if the stored data is attacker-controlled; only use with trusted
  databases.
  """

  impl = PickleType

  def load_dialect_impl(self, dialect):
    # MySQL uses LONGBLOB so large pickled payloads fit in the column.
    if dialect.name == "mysql":
      return dialect.type_descriptor(mysql.LONGBLOB)
    if dialect.name == "spanner+spanner":
      # Imported lazily so the Spanner dependency is only required when the
      # Spanner dialect is actually in use.
      from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType

      return dialect.type_descriptor(SpannerPickleType)
    return self.impl

  def process_bind_param(self, value, dialect):
    """Ensures the pickled value is a bytes object before passing it to the database dialect."""
    if value is not None:
      # Spanner and MySQL receive raw pickled bytes; other dialects let
      # PickleType handle serialization itself.
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.dumps(value)
      return value

  def process_result_value(self, value, dialect):
    """Ensures the raw bytes from the database are unpickled back into a Python object."""
    if value is not None:
      # Mirror of process_bind_param: only these dialects stored raw bytes.
      if dialect.name in ("spanner+spanner", "mysql"):
        return pickle.loads(value)
      return value


class Base(DeclarativeBase):
"""Base class for v0 database tables."""

Expand Down Expand Up @@ -234,7 +206,7 @@ class StorageEvent(Base):

invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType)
actions: Mapped[MutableDict[str, Any]] = mapped_column(JsonEncodedType)
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
Text, nullable=True
)
Expand Down
Loading