From f2285e2830a2e593f8fe3cf348594f6ab2fa380e Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 22:00:59 -0600 Subject: [PATCH 1/5] feat: add outbound WebSocket sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds client-initiated outbound WebSocket connections as a new event source kind, complementing the existing inbound webhook system. Where custom webhooks require an external service to POST to the automation service, WebSocket sources flip the direction: the automation service initiates and maintains a persistent connection to the external service and receives pushed events over it. This eliminates the URL-verification handshake problem that blocks services like Slack Socket Mode from using inbound webhooks. ## New components ### OutboundWebSocketSource model (models.py, migration 007) Single-table with a 'kind' discriminator ('generic' | 'slack'): - Common fields: id, org_id, name, source slug, enabled - JMESPath fields: event_key_expr, payload_expr, filter_expr - generic fields: url (static wss://), headers (upgrade headers) - slack fields: app_token (xapp-… for apps.connections.open) - Runtime state: status, status_detail, connected_at, last_event_at ### Pydantic schemas (schemas.py) Discriminated union on 'kind' for create requests: - GenericWebSocketSourceCreate – validates wss:// URL + JMESPath exprs - SlackWebSocketSourceCreate – validates xapp-… token + Slack defaults - WebSocketSourceUpdate – partial update; kind and source slug are immutable - WebSocketSourceResponse – credentials (headers, app_token) never returned ### CRUD router (websocket_source_router.py) POST /v1/websocket-sources create + notify SocketManager GET /v1/websocket-sources list (org-scoped) GET /v1/websocket-sources/{id} get with live status PATCH /v1/websocket-sources/{id} update + reconnect if needed DELETE /v1/websocket-sources/{id} delete + close connection POST 
/v1/websocket-sources/{id}/reconnect force reconnect ### SocketManager background service (socket_manager.py) One asyncio.Task per enabled source with exponential back-off reconnect: - generic: connects to static wss:// URL with optional headers - slack: calls apps.connections.open to get a fresh URL on each connect, sends per-message envelope ACKs before dispatch (meets Slack's ~3s deadline), handles server-initiated disconnect events Dispatch pipeline (ACK → pre-filter → event_key_expr → payload_expr → trigger matching → create PENDING runs) reuses get_event_automations, matches_trigger, and create_automation_run from the webhook pipeline. ### App wiring (app.py) SocketManager started/stopped in the FastAPI lifespan alongside the scheduler, dispatcher, and watchdog. Router registered before the main automation router to avoid UUID route conflicts. ### Dependency (pyproject.toml) websockets>=12 added for client-side WebSocket connections. The automation trigger model is unchanged: event-triggered automations reference a WebSocket source by its 'source' slug in trigger.source, use trigger.on for event-key pattern matching, and trigger.filter for JMESPath payload filtering — identical to inbound webhooks. 
Co-authored-by: openhands --- .../007_outbound_websocket_sources.py | 91 +++ openhands/automation/app.py | 10 + openhands/automation/models.py | 129 ++++ openhands/automation/schemas.py | 227 +++++++ openhands/automation/socket_manager.py | 492 ++++++++++++++ .../automation/websocket_source_router.py | 235 +++++++ pyproject.toml | 1 + tests/test_websocket_source_router.py | 610 ++++++++++++++++++ uv.lock | 2 + 9 files changed, 1797 insertions(+) create mode 100644 migrations/versions/007_outbound_websocket_sources.py create mode 100644 openhands/automation/socket_manager.py create mode 100644 openhands/automation/websocket_source_router.py create mode 100644 tests/test_websocket_source_router.py diff --git a/migrations/versions/007_outbound_websocket_sources.py b/migrations/versions/007_outbound_websocket_sources.py new file mode 100644 index 0000000..50730e6 --- /dev/null +++ b/migrations/versions/007_outbound_websocket_sources.py @@ -0,0 +1,91 @@ +"""Add outbound_websocket_sources table. + +Stores configuration for outbound WebSocket connections that the automation +service initiates and maintains. Events received over these connections are +dispatched through the same trigger-matching pipeline as inbound webhooks. 
+ +Supports two kinds: + - "generic": static wss:// URL with optional HTTP headers + - "slack": Slack Socket Mode via apps.connections.open (dynamic URL) + +Revision ID: 007 +Revises: 006 +Create Date: 2026-05-21 +""" + +from collections.abc import Sequence + +from alembic import op +from sqlalchemy import Boolean, Column, DateTime, Enum, JSON, String, Text, Uuid, text + + +revision: str = "007" +down_revision: str = "006" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "outbound_websocket_sources", + Column("id", Uuid, primary_key=True), + Column("org_id", Uuid, nullable=False), + Column("name", String(255), nullable=False), + Column("source", String(100), nullable=False), + Column("kind", String(50), nullable=False), + Column("enabled", Boolean, nullable=False, server_default="true"), + # JMESPath expressions + Column("event_key_expr", String(500), nullable=False, server_default="type"), + Column("payload_expr", String(500), nullable=True), + Column("filter_expr", Text, nullable=True), + # generic-kind fields + Column("url", Text, nullable=True), + Column("headers", JSON, nullable=True), + # slack-kind fields + Column("app_token", String(255), nullable=True), + # runtime state + Column( + "status", + Enum( + "CONNECTING", + "CONNECTED", + "DISCONNECTED", + "ERROR", + name="websocketstatus", + native_enum=False, + length=20, + ), + nullable=False, + server_default="DISCONNECTED", + ), + Column("status_detail", Text, nullable=True), + Column("connected_at", DateTime(timezone=True), nullable=True), + Column("last_event_at", DateTime(timezone=True), nullable=True), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ), + ) + op.create_index( + "ix_outbound_ws_sources_org_id", + 
"outbound_websocket_sources", + ["org_id"], + ) + op.create_index( + "ix_outbound_ws_sources_org_source", + "outbound_websocket_sources", + ["org_id", "source"], + unique=True, + ) + + +def downgrade() -> None: + op.drop_table("outbound_websocket_sources") diff --git a/openhands/automation/app.py b/openhands/automation/app.py index da771d9..1c2e418 100644 --- a/openhands/automation/app.py +++ b/openhands/automation/app.py @@ -26,9 +26,11 @@ from openhands.automation.preset_router import router as preset_router from openhands.automation.router import router from openhands.automation.scheduler import scheduler_loop +from openhands.automation.socket_manager import SocketManager from openhands.automation.uploads import router as uploads_router from openhands.automation.watchdog import watchdog_loop from openhands.automation.webhook_router import router as webhook_router +from openhands.automation.websocket_source_router import router as websocket_source_router logger = logging.getLogger("automation.app") @@ -151,11 +153,18 @@ async def lifespan(app: FastAPI): app.state.watchdog_task = watchdog_task logger.info("Background watchdog started") + # Socket manager: maintains outbound WebSocket connections + socket_manager = SocketManager(app.state.session_factory) + app.state.socket_manager = socket_manager + await socket_manager.start() + logger.info("Socket manager started") + yield # Shutdown logger.info("Shutting down background tasks...") shutdown_event.set() + await socket_manager.stop() # Wait for all tasks to exit gracefully for task_name, task in [ @@ -224,6 +233,7 @@ def _create_app() -> FastAPI: app.include_router(preset_router, prefix=_base_path) app.include_router(event_router, prefix=_base_path) app.include_router(webhook_router, prefix=_base_path) +app.include_router(websocket_source_router, prefix=_base_path) app.include_router(router, prefix=_base_path) diff --git a/openhands/automation/models.py b/openhands/automation/models.py index ca10031..61bbdaf 
100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -21,6 +21,15 @@ from openhands.automation.utils import utcnow +class WebSocketStatus(enum.Enum): + """Runtime connection status for an outbound WebSocket source.""" + + CONNECTING = "CONNECTING" + CONNECTED = "CONNECTED" + DISCONNECTED = "DISCONNECTED" + ERROR = "ERROR" + + class Base(DeclarativeBase): pass @@ -321,3 +330,123 @@ class CustomWebhook(Base): __table_args__ = ( Index("ix_custom_webhooks_org_source", "org_id", "source", unique=True), ) + + +class OutboundWebSocketSource(Base): + """An outbound WebSocket connection that receives events from an external service. + + Unlike CustomWebhook (where the external service connects to us), this model + represents a connection WE initiate to an external service. A background + SocketManager maintains the connection and dispatches received events through + the same trigger-matching pipeline used by webhooks. + + Two kinds are supported, selected via the ``kind`` discriminator column: + + ``"generic"`` + Connects to a static ``wss://`` URL with optional HTTP headers. + Suitable for any service that exposes a plain WebSocket endpoint. + + ``"slack"`` + Connects to Slack's Socket Mode API. Requires a Slack App-Level Token + (``xapp-…``). The connection URL is fetched dynamically by calling + ``apps.connections.open`` before each connect attempt; Slack-specific + envelope ACKs are handled automatically. + + Event routing uses the same JMESPath machinery as webhooks: + + - ``event_key_expr`` extracts the event-type string that is matched against + ``trigger.on`` patterns in automations (e.g. ``"payload.event.type"`` + yields ``"message"`` for Slack message events). + - ``payload_expr`` unwraps outer envelopes before the payload is stored on + the run and handed to ``trigger.filter`` evaluation (e.g. ``"payload.event"`` + for Slack strips the Socket Mode envelope). 
+ - ``filter_expr`` is a *connection-level* pre-filter: events that do not match + are silently dropped before any automation matching occurs. Use this to + avoid dispatching irrelevant high-volume events (e.g. bot messages). + """ + + __tablename__ = "outbound_websocket_sources" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + org_id: Mapped[uuid.UUID] = mapped_column(Uuid, nullable=False, index=True) + + # Human-readable label + name: Mapped[str] = mapped_column(String(255), nullable=False) + + # Slug used as the event ``source`` name in trigger matching and URLs. + # Must be unique per org (enforced by the unique index below). + source: Mapped[str] = mapped_column(String(100), nullable=False) + + # Discriminator: "generic" or "slack" + kind: Mapped[str] = mapped_column(String(50), nullable=False) + + enabled: Mapped[bool] = mapped_column(default=True, nullable=False) + + # --- JMESPath expressions (common to all kinds) --- + + # Extracts the event-type key used for trigger.on pattern matching. + # Defaults differ per kind and are set at the schema layer: + # generic → "type" + # slack → "payload.event.type" + event_key_expr: Mapped[str] = mapped_column(String(500), nullable=False) + + # Unwraps outer envelopes so the stored/filtered payload is the inner event. + # None means pass the raw message through unchanged. + # slack default → "payload.event" + payload_expr: Mapped[str | None] = mapped_column(String(500), nullable=True) + + # Connection-level pre-filter (JMESPath). Evaluated against the raw message + # *before* payload unwrapping. Events that do not match are dropped silently. + # None means accept all events. + filter_expr: Mapped[str | None] = mapped_column(Text, nullable=True) + + # --- kind = "generic" fields --- + + # Static wss:// URL to connect to. + url: Mapped[str | None] = mapped_column(Text, nullable=True) + + # JSON object of HTTP headers to send on the WebSocket upgrade request. 
+ # Values may contain ${SECRET_NAME} placeholders resolved at connect time + # from the automation service's secret store (future enhancement; stored + # verbatim for now — treat as sensitive and encrypt at rest in production). + headers: Mapped[dict | None] = mapped_column(JSON, nullable=True) + + # --- kind = "slack" fields --- + + # Slack App-Level Token (xapp-…). Required for Socket Mode. + # Used to call apps.connections.open to obtain a fresh wss:// URL on each + # connect attempt. Treat as sensitive; encrypt at rest in production. + app_token: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # --- Runtime state (managed by SocketManager, not by the API) --- + + status: Mapped[WebSocketStatus] = mapped_column( + Enum(WebSocketStatus, native_enum=False, length=20), + nullable=False, + default=WebSocketStatus.DISCONNECTED, + ) + status_detail: Mapped[str | None] = mapped_column(Text, nullable=True) + connected_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + last_event_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + onupdate=utcnow, + nullable=False, + ) + + __table_args__ = ( + Index( + "ix_outbound_ws_sources_org_source", "org_id", "source", unique=True + ), + ) diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index aee142b..3e96949 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -570,6 +570,233 @@ class CustomWebhookListResponse(BaseModel): total: int +# --- Outbound WebSocket Source Schemas --- + + +def _validate_jmespath(v: str | None) -> str | None: + """Validate a JMESPath expression if provided.""" + if v is None: + return v 
+ import jmespath + from jmespath import exceptions as jmespath_exceptions + + try: + jmespath.compile(v) + except jmespath_exceptions.JMESPathError as e: + raise ValueError(f"Invalid JMESPath expression: {e}") from e + return v + + +class _WebSocketSourceBase(BaseModel): + """Shared fields for all WebSocket source schemas.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., min_length=1, max_length=255) + source: str = Field( + ..., + min_length=1, + max_length=50, + description=( + "Unique source slug (lowercase, alphanumeric with hyphens). " + "Used as trigger.source in automations." + ), + ) + enabled: bool = True + filter_expr: str | None = Field( + default=None, + description=( + "Optional JMESPath expression evaluated against each raw message " + "before payload unwrapping. Events that do not match are dropped " + "before any automation trigger matching occurs." + ), + ) + + @field_validator("source") + @classmethod + def validate_source(cls, v: str) -> str: + v_lower = v.lower() + if v_lower in RESERVED_SOURCES: + raise ValueError( + f"'{v}' is a reserved source name. " + "Use the built-in integration instead." + ) + if not _SOURCE_NAME_RE.match(v_lower): + raise ValueError( + "source must be lowercase alphanumeric with hyphens, 1-50 chars, " + "starting and ending with alphanumeric" + ) + return v_lower + + @field_validator("filter_expr") + @classmethod + def validate_filter_expr(cls, v: str | None) -> str | None: + return _validate_jmespath(v) + + +class GenericWebSocketSourceCreate(_WebSocketSourceBase): + """Create a generic outbound WebSocket source with a static URL.""" + + kind: Literal["generic"] = "generic" + url: str = Field(..., description="The wss:// URL to connect to.") + headers: dict[str, str] | None = Field( + default=None, + description=( + "Optional HTTP headers for the WebSocket upgrade request " + "(e.g. {'Authorization': 'Bearer token'}). " + "Treat values as sensitive credentials." 
+ ), + ) + event_key_expr: str = Field( + default="type", + max_length=500, + description="JMESPath expression to extract the event type from each message.", + ) + payload_expr: str | None = Field( + default=None, + max_length=500, + description=( + "Optional JMESPath expression to unwrap an outer envelope. " + "The result replaces the raw message as the event payload passed to " + "automations. None means use the raw message as-is." + ), + ) + + @field_validator("url") + @classmethod + def validate_url(cls, v: str) -> str: + if not v.startswith(("wss://", "ws://")): + raise ValueError("url must start with wss:// or ws://") + return v + + @field_validator("event_key_expr", "payload_expr") + @classmethod + def validate_jmespath(cls, v: str | None) -> str | None: + return _validate_jmespath(v) + + +class SlackWebSocketSourceCreate(_WebSocketSourceBase): + """Create a Slack Socket Mode outbound WebSocket source. + + Uses ``apps.connections.open`` to obtain a fresh wss:// URL on each connect + attempt. Slack-specific envelope ACKs are sent automatically before any + automation dispatch occurs. + + Defaults are tuned for the most common use-case: routing on the inner Slack + event type (``message``, ``app_mention``, etc.) with the inner event object + as the payload available to ``trigger.filter``. + """ + + kind: Literal["slack"] = "slack" + app_token: str = Field( + ..., + description=( + "Slack App-Level Token (xapp-…) used to call apps.connections.open. " + "Treat as a sensitive credential; do not log or expose." + ), + ) + event_key_expr: str = Field( + default="payload.event.type", + max_length=500, + description=( + "JMESPath expression to extract the event type. " + "Default extracts the inner Slack event type (e.g. 'message')." + ), + ) + payload_expr: str = Field( + default="payload.event", + max_length=500, + description=( + "JMESPath expression to unwrap the Slack Socket Mode envelope. " + "Default exposes the inner event object to trigger.filter." 
+ ), + ) + + @field_validator("app_token") + @classmethod + def validate_app_token(cls, v: str) -> str: + if not v.startswith("xapp-"): + raise ValueError( + "app_token must be a Slack App-Level Token starting with 'xapp-'" + ) + return v + + @field_validator("event_key_expr", "payload_expr") + @classmethod + def validate_jmespath(cls, v: str | None) -> str | None: + return _validate_jmespath(v) + + +# Discriminated union for the create request body +WebSocketSourceCreate = Annotated[ + GenericWebSocketSourceCreate | SlackWebSocketSourceCreate, + Field(discriminator="kind"), +] + + +class WebSocketSourceUpdate(BaseModel): + """Request schema for updating a WebSocket source. + + Only common fields and kind-specific credentials can be updated. + The ``kind`` and ``source`` slug are immutable after creation. + """ + + model_config = ConfigDict(extra="forbid") + + name: str | None = Field(default=None, min_length=1, max_length=255) + enabled: bool | None = None + event_key_expr: str | None = Field(default=None, max_length=500) + payload_expr: str | None = Field(default=None, max_length=500) + filter_expr: str | None = None + # generic-kind + url: str | None = None + headers: dict[str, str] | None = None + # slack-kind + app_token: str | None = None + + @field_validator("event_key_expr", "payload_expr", "filter_expr") + @classmethod + def validate_jmespath(cls, v: str | None) -> str | None: + return _validate_jmespath(v) + + +class WebSocketSourceStatus(StrEnum): + CONNECTING = "CONNECTING" + CONNECTED = "CONNECTED" + DISCONNECTED = "DISCONNECTED" + ERROR = "ERROR" + + +class WebSocketSourceResponse(BaseModel): + """Response schema for a WebSocket source (credentials redacted).""" + + id: uuid.UUID + org_id: uuid.UUID + name: str + source: str + kind: str + enabled: bool + event_key_expr: str + payload_expr: str | None + filter_expr: str | None + # generic-kind (url returned; headers omitted — may contain credentials) + url: str | None + # slack-kind: token is never 
returned + status: WebSocketSourceStatus + status_detail: str | None + connected_at: datetime | None + last_event_at: datetime | None + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class WebSocketSourceListResponse(BaseModel): + sources: list[WebSocketSourceResponse] + total: int + + # --- Responses --- diff --git a/openhands/automation/socket_manager.py b/openhands/automation/socket_manager.py new file mode 100644 index 0000000..93da280 --- /dev/null +++ b/openhands/automation/socket_manager.py @@ -0,0 +1,492 @@ +"""Outbound WebSocket connection manager. + +Maintains persistent client-initiated WebSocket connections for each enabled +OutboundWebSocketSource. When a message arrives, it is dispatched through the +same trigger-matching pipeline used by inbound webhooks. + +Connection lifecycle per source +-------------------------------- +1. ``_run_source_loop``: outer reconnect loop with exponential back-off. +2. ``_connect_and_receive``: kind-specific connect + receive dispatcher. + - ``_receive_generic``: static-URL sources. + - ``_receive_slack``: Slack Socket Mode sources (dynamic URL via + ``apps.connections.open``, per-message envelope ACKs). +3. ``_dispatch``: apply pre-filter → extract event key → unwrap payload → + find matching automations → create PENDING runs. + +ACK before dispatch +------------------- +For sources that require acknowledgements (currently Slack), the ACK is sent +immediately on receipt, *before* any automation-matching or DB writes. This +ensures we never miss an ACK deadline due to slow dispatch logic. + +Secret handling +--------------- +Headers stored on generic sources may contain sensitive values. They are +passed verbatim to the WebSocket upgrade request. Future work: support +``${SECRET_NAME}`` placeholders resolved against the secrets store. 
+""" + +import asyncio +import json +import logging +import uuid +from datetime import UTC, datetime +from typing import Any + +import httpx +import jmespath +import websockets +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from openhands.automation.filter_eval import evaluate_filter +from openhands.automation.models import OutboundWebSocketSource, WebSocketStatus +from openhands.automation.trigger_matcher import matches_trigger +from openhands.automation.utils.webhook import create_automation_run, get_event_automations + + +logger = logging.getLogger("automation.socket_manager") + +# Reconnect back-off: min 1 s, doubles each attempt, capped at 5 min +_BACKOFF_BASE = 1.0 +_BACKOFF_MAX = 300.0 +_BACKOFF_FACTOR = 2.0 + +# Maximum consecutive failures before a source is marked ERROR and paused +_MAX_FAILURES = 10 + + +class SocketManager: + """Manages all outbound WebSocket connections for the service. + + One ``asyncio.Task`` is maintained per enabled source. Tasks are + created on service startup and on API-driven changes (create/update/delete). 
+ """ + + def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: + self._session_factory = session_factory + # source_id → running connection task + self._tasks: dict[uuid.UUID, asyncio.Task] = {} + + # ------------------------------------------------------------------ + # Public interface (called from app lifespan and API router) + # ------------------------------------------------------------------ + + async def start(self) -> None: + """Load all enabled sources from the DB and open connections.""" + async with self._session_factory() as session: + result = await session.execute( + select(OutboundWebSocketSource).where( + OutboundWebSocketSource.enabled == True # noqa: E712 + ) + ) + sources = result.scalars().all() + + logger.info("SocketManager starting %d source(s)", len(sources)) + for source in sources: + self._start_task(source.id, source.org_id, source.kind) + + async def stop(self) -> None: + """Cancel all connection tasks and wait for them to finish.""" + logger.info("SocketManager stopping %d task(s)", len(self._tasks)) + for task in list(self._tasks.values()): + task.cancel() + if self._tasks: + await asyncio.gather(*self._tasks.values(), return_exceptions=True) + self._tasks.clear() + + async def on_source_changed(self, source_id: uuid.UUID) -> None: + """Called when a source is created or updated. + + Cancels any existing task and starts a fresh one so that config + changes (new token, new URL, toggled enabled flag) take effect + immediately. 
+ """ + await self._cancel_task(source_id) + + async with self._session_factory() as session: + source = await session.get(OutboundWebSocketSource, source_id) + + if source is None or not source.enabled: + return + + self._start_task(source_id, source.org_id, source.kind) + + async def on_source_deleted(self, source_id: uuid.UUID) -> None: + """Cancel the task for a source that is about to be deleted.""" + await self._cancel_task(source_id) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _start_task( + self, source_id: uuid.UUID, org_id: uuid.UUID, kind: str + ) -> None: + task = asyncio.create_task( + self._run_source_loop(source_id, org_id, kind), + name=f"ws-source-{source_id}", + ) + self._tasks[source_id] = task + + def _on_done(t: asyncio.Task) -> None: + self._tasks.pop(source_id, None) + if not t.cancelled() and t.exception(): + logger.error( + "WebSocket source task raised unhandled exception source_id=%s", + source_id, + exc_info=t.exception(), + ) + + task.add_done_callback(_on_done) + + async def _cancel_task(self, source_id: uuid.UUID) -> None: + task = self._tasks.pop(source_id, None) + if task and not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def _set_status( + self, + source_id: uuid.UUID, + status: WebSocketStatus, + detail: str | None = None, + connected_at: datetime | None = None, + ) -> None: + try: + async with self._session_factory() as session: + source = await session.get(OutboundWebSocketSource, source_id) + if source is None: + return + source.status = status + source.status_detail = detail + if connected_at is not None: + source.connected_at = connected_at + await session.commit() + except Exception: + logger.exception( + "Failed to update status for source_id=%s", source_id + ) + + async def _record_event(self, source_id: uuid.UUID) -> None: + 
try: + async with self._session_factory() as session: + source = await session.get(OutboundWebSocketSource, source_id) + if source: + source.last_event_at = datetime.now(UTC) + await session.commit() + except Exception: + logger.exception( + "Failed to update last_event_at for source_id=%s", source_id + ) + + # ------------------------------------------------------------------ + # Outer reconnect loop + # ------------------------------------------------------------------ + + async def _run_source_loop( + self, source_id: uuid.UUID, org_id: uuid.UUID, kind: str + ) -> None: + """Outer loop: reconnect with exponential back-off on failure.""" + failures = 0 + delay = _BACKOFF_BASE + + while True: + try: + await self._set_status(source_id, WebSocketStatus.CONNECTING) + await self._connect_and_receive(source_id, org_id, kind) + # Clean exit (e.g. server-initiated disconnect) — reset failures + failures = 0 + delay = _BACKOFF_BASE + except asyncio.CancelledError: + await self._set_status(source_id, WebSocketStatus.DISCONNECTED) + raise + except Exception as exc: + failures += 1 + detail = f"{type(exc).__name__}: {exc}" + logger.warning( + "WebSocket source disconnected source_id=%s failures=%d: %s", + source_id, + failures, + detail, + ) + + if failures >= _MAX_FAILURES: + logger.error( + "WebSocket source exceeded max failures, pausing " + "source_id=%s", + source_id, + ) + await self._set_status( + source_id, + WebSocketStatus.ERROR, + detail=f"Paused after {failures} consecutive failures. 
" + f"Last error: {detail}", + ) + return + + await self._set_status( + source_id, WebSocketStatus.DISCONNECTED, detail=detail + ) + + logger.info( + "WebSocket source will reconnect in %.0fs source_id=%s", + delay, + source_id, + ) + await asyncio.sleep(delay) + delay = min(delay * _BACKOFF_FACTOR, _BACKOFF_MAX) + + # ------------------------------------------------------------------ + # Kind-specific connect + receive + # ------------------------------------------------------------------ + + async def _connect_and_receive( + self, source_id: uuid.UUID, org_id: uuid.UUID, kind: str + ) -> None: + async with self._session_factory() as session: + source = await session.get(OutboundWebSocketSource, source_id) + if source is None: + return + + if kind == "slack": + await self._receive_slack(source, org_id) + else: + await self._receive_generic(source, org_id) + + async def _receive_generic( + self, source: OutboundWebSocketSource, org_id: uuid.UUID + ) -> None: + """Connect to a static wss:// URL and receive messages.""" + url = source.url + if not url: + raise ValueError(f"Generic source {source.id} has no url configured") + + extra_headers = list((source.headers or {}).items()) + + logger.info( + "Connecting to generic WebSocket url=%s source_id=%s", url, source.id + ) + async with websockets.connect(url, additional_headers=extra_headers) as ws: + await self._set_status( + source.id, + WebSocketStatus.CONNECTED, + connected_at=datetime.now(UTC), + ) + logger.info("Connected source_id=%s", source.id) + + async for raw in ws: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + logger.debug("Non-JSON message from source_id=%s, skipping", source.id) + continue + + await self._record_event(source.id) + await self._dispatch(source, org_id, msg, ack_ws=None, ack_id=None) + + async def _receive_slack( + self, source: OutboundWebSocketSource, org_id: uuid.UUID + ) -> None: + """Connect via Slack Socket Mode. + + 1. 
Calls ``apps.connections.open`` to obtain a fresh wss:// URL. + 2. Opens the WebSocket. + 3. ACKs every ``events_api`` envelope before dispatching. + 4. Handles ``disconnect`` events from Slack by raising to trigger + a clean reconnect. + """ + app_token = source.app_token + if not app_token: + raise ValueError( + f"Slack source {source.id} has no app_token configured" + ) + + # Fetch a fresh connection URL from Slack + wss_url = await _slack_open_connection(app_token) + logger.info( + "Obtained Slack Socket Mode URL for source_id=%s", source.id + ) + + async with websockets.connect(wss_url) as ws: + # Slack sends a hello message immediately on connect + hello_raw = await ws.recv() + hello = json.loads(hello_raw) + if hello.get("type") != "hello": + raise RuntimeError( + f"Expected Slack hello, got: {hello.get('type')!r}" + ) + + await self._set_status( + source.id, + WebSocketStatus.CONNECTED, + connected_at=datetime.now(UTC), + ) + logger.info("Slack Socket Mode connected source_id=%s", source.id) + + async for raw in ws: + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + + msg_type = msg.get("type", "") + + # Slack may ask us to disconnect and reconnect cleanly + if msg_type == "disconnect": + reason = msg.get("reason", "unknown") + logger.info( + "Slack requested disconnect reason=%s source_id=%s", + reason, + source.id, + ) + raise _SlackReconnectRequested(reason) + + # ACK immediately before any dispatch work + envelope_id = msg.get("envelope_id") + if envelope_id: + await ws.send(json.dumps({"envelope_id": envelope_id})) + + if msg_type != "events_api": + continue + + await self._record_event(source.id) + await self._dispatch(source, org_id, msg, ack_ws=None, ack_id=None) + + # ------------------------------------------------------------------ + # Dispatch pipeline (shared by all kinds) + # ------------------------------------------------------------------ + + async def _dispatch( + self, + source: OutboundWebSocketSource, + org_id: 
uuid.UUID, + raw_msg: dict[str, Any], + ack_ws: Any, # reserved for future per-message ack callbacks + ack_id: str | None, + ) -> None: + """Apply pre-filter, extract event key, unwrap, find automations, run. + + This mirrors the logic in event_router.py but operates on the raw + WebSocket message rather than an HTTP request body. + """ + # 1. Connection-level pre-filter (evaluated against raw message) + if source.filter_expr: + try: + if not evaluate_filter(source.filter_expr, raw_msg): + return + except Exception: + logger.debug( + "Pre-filter raised for source_id=%s, dropping event", + source.id, + exc_info=True, + ) + return + + # 2. Extract event key via event_key_expr + try: + event_key = jmespath.search(source.event_key_expr, raw_msg) + except Exception: + logger.debug( + "event_key_expr failed for source_id=%s, dropping", source.id, + exc_info=True, + ) + return + + if not isinstance(event_key, str) or not event_key: + logger.debug( + "event_key_expr returned non-string for source_id=%s, dropping", + source.id, + ) + return + + # 3. Unwrap envelope to get the event payload for automation filtering + if source.payload_expr: + try: + event_payload = jmespath.search(source.payload_expr, raw_msg) + except Exception: + logger.debug( + "payload_expr failed for source_id=%s, using raw", source.id + ) + event_payload = raw_msg + else: + event_payload = raw_msg + + if not isinstance(event_payload, dict): + logger.debug( + "payload_expr returned non-dict for source_id=%s, skipping", + source.id, + ) + return + + # 4. 
Find matching automations and create runs + async with self._session_factory() as session: + automations = await get_event_automations(org_id, source.source, session) + matched = [ + auto + for auto, trigger in automations + if matches_trigger(trigger, source.source, event_key, event_payload) + ] + + if not matched: + return + + logger.info( + "WebSocket event matched %d automation(s) source=%s event_key=%s", + len(matched), + source.source, + event_key, + ) + for automation in matched: + await create_automation_run( + automation, session, event_payload=event_payload + ) + await session.commit() + + +# --------------------------------------------------------------------------- +# Slack helpers +# --------------------------------------------------------------------------- + + +class _SlackReconnectRequested(Exception): + """Raised when Slack sends a disconnect event, triggering a clean reconnect.""" + + +async def _slack_open_connection(app_token: str) -> str: + """Call Slack's apps.connections.open to get a fresh wss:// URL. + + Args: + app_token: Slack App-Level Token (xapp-…). + + Returns: + The wss:// URL to connect to. + + Raises: + RuntimeError: if the Slack API call fails or returns an error. 
+ """ + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + "https://slack.com/api/apps.connections.open", + headers={ + "Authorization": f"Bearer {app_token}", + "Content-Type": "application/x-www-form-urlencoded", + }, + ) + resp.raise_for_status() + data = resp.json() + + if not data.get("ok"): + raise RuntimeError( + f"apps.connections.open failed: {data.get('error', 'unknown error')}" + ) + + url = data.get("url") + if not url: + raise RuntimeError("apps.connections.open returned no url") + + return url diff --git a/openhands/automation/websocket_source_router.py b/openhands/automation/websocket_source_router.py new file mode 100644 index 0000000..c8c1227 --- /dev/null +++ b/openhands/automation/websocket_source_router.py @@ -0,0 +1,235 @@ +"""FastAPI router for outbound WebSocket source management. + +Provides CRUD operations for outbound WebSocket connections. The actual +connection lifecycle is managed by SocketManager (socket_manager.py), which +is notified via app.state when sources are created, updated, or deleted. 
+""" + +import logging +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.auth import AuthenticatedUser, authenticate_request +from openhands.automation.db import get_session +from openhands.automation.models import OutboundWebSocketSource, WebSocketStatus +from openhands.automation.schemas import ( + WebSocketSourceCreate, + WebSocketSourceListResponse, + WebSocketSourceResponse, + WebSocketSourceStatus, + WebSocketSourceUpdate, +) + + +logger = logging.getLogger("automation.websocket_source_router") + +router = APIRouter(prefix="/v1/websocket-sources", tags=["WebSocket Sources"]) + + +def _to_response(source: OutboundWebSocketSource) -> WebSocketSourceResponse: + """Convert ORM model to response schema, redacting credentials.""" + return WebSocketSourceResponse( + id=source.id, + org_id=source.org_id, + name=source.name, + source=source.source, + kind=source.kind, + enabled=source.enabled, + event_key_expr=source.event_key_expr, + payload_expr=source.payload_expr, + filter_expr=source.filter_expr, + url=source.url, + # headers and app_token are intentionally not returned + status=WebSocketSourceStatus(source.status.value), + status_detail=source.status_detail, + connected_at=source.connected_at, + last_event_at=source.last_event_at, + created_at=source.created_at, + updated_at=source.updated_at, + ) + + +def _get_socket_manager(request: Request): + """Return the SocketManager from app state, or None if not initialised.""" + return getattr(request.app.state, "socket_manager", None) + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def create_websocket_source( + data: WebSocketSourceCreate, + request: Request, + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), +) -> WebSocketSourceResponse: + """ 
+ Register a new outbound WebSocket source. + + On success the SocketManager opens a connection immediately (if the source + is enabled). + """ + source = OutboundWebSocketSource( + org_id=auth.org_id, + name=data.name, + source=data.source, + kind=data.kind, + enabled=data.enabled, + event_key_expr=data.event_key_expr, + payload_expr=data.payload_expr, + filter_expr=data.filter_expr, + status=WebSocketStatus.DISCONNECTED, + # kind-specific fields + url=getattr(data, "url", None), + headers=getattr(data, "headers", None), + app_token=getattr(data, "app_token", None), + ) + session.add(source) + + try: + await session.commit() + except IntegrityError: + await session.rollback() + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"WebSocket source '{data.source}' already exists for this org", + ) + + await session.refresh(source) + + # Notify the socket manager to open a connection for this new source + if source.enabled: + sm = _get_socket_manager(request) + if sm is not None: + await sm.on_source_changed(source.id) + + logger.info("Created WebSocket source id=%s source=%s", source.id, source.source) + return _to_response(source) + + +@router.get("") +async def list_websocket_sources( + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), + limit: int = Query(default=50, ge=1, le=100), + offset: int = Query(default=0, ge=0), +) -> WebSocketSourceListResponse: + """List all outbound WebSocket sources for the organisation.""" + count_q = select(func.count()).select_from( + select(OutboundWebSocketSource.id) + .where(OutboundWebSocketSource.org_id == auth.org_id) + .subquery() + ) + total = await session.scalar(count_q) or 0 + + q = ( + select(OutboundWebSocketSource) + .where(OutboundWebSocketSource.org_id == auth.org_id) + .order_by(OutboundWebSocketSource.created_at.desc()) + .limit(limit) + .offset(offset) + ) + result = await session.execute(q) + sources = result.scalars().all() + + return 
WebSocketSourceListResponse( + sources=[_to_response(s) for s in sources], + total=total, + ) + + +@router.get("/{source_id}") +async def get_websocket_source( + source_id: uuid.UUID, + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), +) -> WebSocketSourceResponse: + """Get details of a specific WebSocket source, including live connection status.""" + source = await session.get(OutboundWebSocketSource, source_id) + if not source or source.org_id != auth.org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") + return _to_response(source) + + +@router.patch("/{source_id}") +async def update_websocket_source( + source_id: uuid.UUID, + data: WebSocketSourceUpdate, + request: Request, + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), +) -> WebSocketSourceResponse: + """ + Update a WebSocket source's configuration. + + Changing ``enabled``, ``url``, ``headers``, or ``app_token`` triggers an + immediate reconnect (or disconnect if being disabled). + The ``kind`` and ``source`` slug are immutable. 
+ """ + source = await session.get(OutboundWebSocketSource, source_id) + if not source or source.org_id != auth.org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") + + update_data = data.model_dump(exclude_unset=True) + connection_affecting = {"enabled", "url", "headers", "app_token"} + needs_reconnect = bool(update_data.keys() & connection_affecting) + + for field, value in update_data.items(): + setattr(source, field, value) + + await session.commit() + await session.refresh(source) + + if needs_reconnect: + sm = _get_socket_manager(request) + if sm is not None: + await sm.on_source_changed(source.id) + + return _to_response(source) + + +@router.delete("/{source_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_websocket_source( + source_id: uuid.UUID, + request: Request, + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), +) -> None: + """Delete a WebSocket source and close its connection.""" + source = await session.get(OutboundWebSocketSource, source_id) + if not source or source.org_id != auth.org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") + + sm = _get_socket_manager(request) + if sm is not None: + await sm.on_source_deleted(source_id) + + await session.delete(source) + await session.commit() + + +@router.post("/{source_id}/reconnect", status_code=status.HTTP_202_ACCEPTED) +async def reconnect_websocket_source( + source_id: uuid.UUID, + request: Request, + auth: AuthenticatedUser = Depends(authenticate_request), + session: AsyncSession = Depends(get_session), +) -> WebSocketSourceResponse: + """Force an immediate reconnect for a WebSocket source.""" + source = await session.get(OutboundWebSocketSource, source_id) + if not source or source.org_id != auth.org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not found") + if not source.enabled: + raise HTTPException( + 
status_code=status.HTTP_400_BAD_REQUEST, + detail="Source is disabled; enable it before reconnecting", + ) + + sm = _get_socket_manager(request) + if sm is not None: + await sm.on_source_changed(source_id) + + await session.refresh(source) + return _to_response(source) diff --git a/pyproject.toml b/pyproject.toml index af0b40f..9b03927 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sqlalchemy[asyncio]>=2", "tenacity>=9.1.4", "uvicorn[standard]>=0.30", + "websockets>=12", ] [build-system] diff --git a/tests/test_websocket_source_router.py b/tests/test_websocket_source_router.py new file mode 100644 index 0000000..d136b72 --- /dev/null +++ b/tests/test_websocket_source_router.py @@ -0,0 +1,610 @@ +"""Tests for outbound WebSocket source CRUD endpoints and schema validation.""" + +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +from pydantic import ValidationError + +from openhands.automation.models import OutboundWebSocketSource, WebSocketStatus +from openhands.automation.schemas import ( + GenericWebSocketSourceCreate, + SlackWebSocketSourceCreate, + WebSocketSourceUpdate, +) + + +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") +OTHER_ORG_ID = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + +BASE_URL = "/api/automation/v1/websocket-sources" + + +# --------------------------------------------------------------------------- +# Schema validation +# --------------------------------------------------------------------------- + + +class TestGenericWebSocketSourceCreate: + def test_valid_minimal(self): + s = GenericWebSocketSourceCreate( + name="My WS", + source="my-ws", + url="wss://example.com/events", + ) + assert s.kind == "generic" + assert s.event_key_expr == "type" + assert s.payload_expr is None + assert s.headers is None + + def test_url_must_be_wss(self): + with pytest.raises(ValidationError, match="wss://"): + GenericWebSocketSourceCreate( + name="Bad URL", source="bad", 
url="http://example.com" + ) + + def test_ws_url_also_accepted(self): + s = GenericWebSocketSourceCreate( + name="WS", source="ws", url="ws://localhost:9000" + ) + assert s.url == "ws://localhost:9000" + + def test_reserved_source_rejected(self): + with pytest.raises(ValidationError, match="reserved"): + GenericWebSocketSourceCreate( + name="GitHub WS", source="github", url="wss://example.com" + ) + + def test_invalid_jmespath_event_key_expr(self): + with pytest.raises(ValidationError, match="JMESPath"): + GenericWebSocketSourceCreate( + name="Bad", + source="bad", + url="wss://example.com", + event_key_expr="invalid((((", + ) + + def test_invalid_jmespath_filter_expr(self): + with pytest.raises(ValidationError, match="JMESPath"): + GenericWebSocketSourceCreate( + name="Bad", + source="bad", + url="wss://example.com", + filter_expr="invalid((((", + ) + + def test_source_normalised_to_lowercase(self): + s = GenericWebSocketSourceCreate( + name="WS", source="MySource", url="wss://example.com" + ) + assert s.source == "mysource" + + def test_with_headers_and_payload_expr(self): + s = GenericWebSocketSourceCreate( + name="WS", + source="my-ws", + url="wss://example.com", + headers={"Authorization": "Bearer token"}, + payload_expr="event", + ) + assert s.headers == {"Authorization": "Bearer token"} + assert s.payload_expr == "event" + + +class TestSlackWebSocketSourceCreate: + def test_valid(self): + s = SlackWebSocketSourceCreate( + name="Slack", + source="slack", + app_token="xapp-1-abc123", + ) + assert s.kind == "slack" + assert s.event_key_expr == "payload.event.type" + assert s.payload_expr == "payload.event" + + def test_app_token_must_start_with_xapp(self): + with pytest.raises(ValidationError, match="xapp-"): + SlackWebSocketSourceCreate( + name="Slack", source="slack", app_token="xoxb-wrong-token" + ) + + def test_custom_event_key_expr(self): + s = SlackWebSocketSourceCreate( + name="Slack", + source="slack", + app_token="xapp-1-abc123", + 
event_key_expr="payload.type", + ) + assert s.event_key_expr == "payload.type" + + +class TestWebSocketSourceUpdate: + def test_partial_update_allowed(self): + u = WebSocketSourceUpdate(name="New Name") + assert u.name == "New Name" + assert u.enabled is None + + def test_invalid_jmespath_rejected(self): + with pytest.raises(ValidationError, match="JMESPath"): + WebSocketSourceUpdate(filter_expr="bad(((") + + +# --------------------------------------------------------------------------- +# Router (CRUD) — integration tests using async_client +# --------------------------------------------------------------------------- + + +@pytest.fixture +def _no_socket_manager(): + """Ensure app.state has no socket_manager so the router skips notification.""" + from openhands.automation.app import app + + had_it = hasattr(app.state, "socket_manager") + original = app.state.socket_manager if had_it else None + if had_it: + del app.state.socket_manager + yield + if had_it: + app.state.socket_manager = original + + +class TestCreateWebSocketSource: + async def test_create_generic_success(self, async_client, _no_socket_manager): + resp = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "My Generic WS", + "source": "my-ws", + "url": "wss://example.com/events", + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert data["kind"] == "generic" + assert data["source"] == "my-ws" + assert data["url"] == "wss://example.com/events" + assert data["status"] == "DISCONNECTED" + # Credentials not returned + assert "headers" not in data + assert "app_token" not in data + + async def test_create_slack_success(self, async_client, _no_socket_manager): + resp = await async_client.post( + BASE_URL, + json={ + "kind": "slack", + "name": "Slack Events", + "source": "slack-prod", + "app_token": "xapp-1-AAAAAAAAA-1111111111-aaaaaaaaaaaaaaaa", + }, + ) + assert resp.status_code == 201 + data = resp.json() + assert data["kind"] == "slack" + assert 
data["event_key_expr"] == "payload.event.type" + assert data["payload_expr"] == "payload.event" + # app_token is never returned + assert "app_token" not in data + + async def test_create_duplicate_source_returns_409( + self, async_client, _no_socket_manager + ): + payload = { + "kind": "generic", + "name": "WS Source", + "source": "dup-source", + "url": "wss://example.com", + } + resp1 = await async_client.post(BASE_URL, json=payload) + assert resp1.status_code == 201 + + resp2 = await async_client.post(BASE_URL, json=payload) + assert resp2.status_code == 409 + assert "already exists" in resp2.json()["detail"] + + async def test_create_notifies_socket_manager( + self, async_client, async_session + ): + """When a socket manager is present, it should be notified on create.""" + from openhands.automation.app import app + + mock_sm = AsyncMock() + app.state.socket_manager = mock_sm + + try: + resp = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Notify WS", + "source": "notify-ws", + "url": "wss://example.com", + }, + ) + assert resp.status_code == 201 + mock_sm.on_source_changed.assert_awaited_once() + finally: + del app.state.socket_manager + + async def test_create_disabled_source_does_not_notify( + self, async_client, async_session + ): + from openhands.automation.app import app + + mock_sm = AsyncMock() + app.state.socket_manager = mock_sm + + try: + resp = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Disabled WS", + "source": "disabled-ws", + "url": "wss://example.com", + "enabled": False, + }, + ) + assert resp.status_code == 201 + mock_sm.on_source_changed.assert_not_awaited() + finally: + del app.state.socket_manager + + async def test_missing_kind_returns_422(self, async_client, _no_socket_manager): + resp = await async_client.post( + BASE_URL, + json={"name": "WS", "source": "ws", "url": "wss://example.com"}, + ) + assert resp.status_code == 422 + + async def test_slack_bad_token_returns_422( + 
self, async_client, _no_socket_manager + ): + resp = await async_client.post( + BASE_URL, + json={ + "kind": "slack", + "name": "Slack", + "source": "bad-slack", + "app_token": "xoxb-not-an-app-token", + }, + ) + assert resp.status_code == 422 + + +class TestListWebSocketSources: + async def test_list_empty(self, async_client): + resp = await async_client.get(BASE_URL) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 0 + assert data["sources"] == [] + + async def test_list_returns_own_org_sources( + self, async_client, async_session, _no_socket_manager + ): + # Create two sources for the test org + for i in range(2): + await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": f"WS {i}", + "source": f"ws-{i}", + "url": "wss://example.com", + }, + ) + + # Create one for another org directly in DB (should not appear) + other = OutboundWebSocketSource( + org_id=OTHER_ORG_ID, + name="Other WS", + source="other-ws", + kind="generic", + url="wss://other.com", + event_key_expr="type", + status=WebSocketStatus.DISCONNECTED, + ) + async_session.add(other) + await async_session.commit() + + resp = await async_client.get(BASE_URL) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 2 + assert all(s["source"] in ("ws-0", "ws-1") for s in data["sources"]) + + +class TestGetWebSocketSource: + async def test_get_existing(self, async_client, async_session, _no_socket_manager): + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Get Me", + "source": "get-me", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + + resp = await async_client.get(f"{BASE_URL}/{source_id}") + assert resp.status_code == 200 + assert resp.json()["source"] == "get-me" + + async def test_get_other_org_returns_404(self, async_client, async_session): + other = OutboundWebSocketSource( + org_id=OTHER_ORG_ID, + name="Other", + source="other", + kind="generic", + 
url="wss://other.com", + event_key_expr="type", + status=WebSocketStatus.DISCONNECTED, + ) + async_session.add(other) + await async_session.commit() + + resp = await async_client.get(f"{BASE_URL}/{other.id}") + assert resp.status_code == 404 + + async def test_get_nonexistent_returns_404(self, async_client): + resp = await async_client.get(f"{BASE_URL}/{uuid.uuid4()}") + assert resp.status_code == 404 + + +class TestUpdateWebSocketSource: + async def test_update_name_and_enabled( + self, async_client, _no_socket_manager + ): + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Old Name", + "source": "upd-ws", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + + resp = await async_client.patch( + f"{BASE_URL}/{source_id}", + json={"name": "New Name", "enabled": False}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "New Name" + assert data["enabled"] is False + + async def test_update_triggers_reconnect_on_enabled_change( + self, async_client, _no_socket_manager + ): + from openhands.automation.app import app + + mock_sm = AsyncMock() + app.state.socket_manager = mock_sm + + try: + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Reconnect Test", + "source": "recon-ws", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + mock_sm.on_source_changed.reset_mock() + + resp = await async_client.patch( + f"{BASE_URL}/{source_id}", json={"enabled": False} + ) + assert resp.status_code == 200 + mock_sm.on_source_changed.assert_awaited_once() + finally: + del app.state.socket_manager + + +class TestDeleteWebSocketSource: + async def test_delete_success(self, async_client, _no_socket_manager): + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Delete Me", + "source": "del-ws", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + + resp = await 
async_client.delete(f"{BASE_URL}/{source_id}") + assert resp.status_code == 204 + + get_resp = await async_client.get(f"{BASE_URL}/{source_id}") + assert get_resp.status_code == 404 + + async def test_delete_notifies_socket_manager(self, async_client, _no_socket_manager): + from openhands.automation.app import app + + mock_sm = AsyncMock() + app.state.socket_manager = mock_sm + + try: + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Del Notify", + "source": "del-notify", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + mock_sm.reset_mock() + + await async_client.delete(f"{BASE_URL}/{source_id}") + mock_sm.on_source_deleted.assert_awaited_once_with( + uuid.UUID(source_id) + ) + finally: + del app.state.socket_manager + + +class TestReconnectWebSocketSource: + async def test_reconnect_enabled_source(self, async_client, _no_socket_manager): + from openhands.automation.app import app + + mock_sm = AsyncMock() + app.state.socket_manager = mock_sm + + try: + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Reconnect WS", + "source": "recon-ws2", + "url": "wss://example.com", + }, + ) + source_id = create.json()["id"] + mock_sm.reset_mock() + + resp = await async_client.post(f"{BASE_URL}/{source_id}/reconnect") + assert resp.status_code == 202 + mock_sm.on_source_changed.assert_awaited_once() + finally: + del app.state.socket_manager + + async def test_reconnect_disabled_source_returns_400( + self, async_client, _no_socket_manager + ): + create = await async_client.post( + BASE_URL, + json={ + "kind": "generic", + "name": "Disabled", + "source": "disabled-ws2", + "url": "wss://example.com", + "enabled": False, + }, + ) + source_id = create.json()["id"] + + resp = await async_client.post(f"{BASE_URL}/{source_id}/reconnect") + assert resp.status_code == 400 + assert "disabled" in resp.json()["detail"].lower() + + +# 
--------------------------------------------------------------------------- +# SocketManager unit tests +# --------------------------------------------------------------------------- + + +class TestSocketManagerDispatch: + """Unit tests for the dispatch pipeline without real WebSocket connections.""" + + async def test_dispatch_pre_filter_drops_non_matching( + self, async_session_factory + ): + """Events that fail filter_expr should be silently dropped.""" + from openhands.automation.socket_manager import SocketManager + + sm = SocketManager(async_session_factory) + + source = OutboundWebSocketSource( + id=uuid.uuid4(), + org_id=TEST_ORG_ID, + name="Test", + source="test-ws", + kind="generic", + event_key_expr="type", + filter_expr="type == 'allowed'", + status=WebSocketStatus.CONNECTED, + ) + + # This message has type='blocked' so filter should drop it + with patch( + "openhands.automation.socket_manager.get_event_automations", + new_callable=AsyncMock, + ) as mock_get: + await sm._dispatch( + source, TEST_ORG_ID, {"type": "blocked"}, None, None + ) + mock_get.assert_not_called() + + async def test_dispatch_unwraps_payload_via_payload_expr( + self, async_session_factory + ): + """payload_expr should unwrap the envelope before passing to automations.""" + from openhands.automation.socket_manager import SocketManager + + sm = SocketManager(async_session_factory) + + source = OutboundWebSocketSource( + id=uuid.uuid4(), + org_id=TEST_ORG_ID, + name="Test", + source="test-ws", + kind="slack", + event_key_expr="payload.event.type", + payload_expr="payload.event", + status=WebSocketStatus.CONNECTED, + ) + + raw_msg = { + "envelope_id": "abc", + "type": "events_api", + "payload": { + "event": { + "type": "message", + "text": "hello thems-fightin-words", + "channel": "C123", + } + }, + } + + captured_payload = None + + async def fake_get_event_automations(org_id, source_name, session): + return [] + + with patch( + 
"openhands.automation.socket_manager.get_event_automations", + side_effect=fake_get_event_automations, + ): + await sm._dispatch(source, TEST_ORG_ID, raw_msg, None, None) + # No match, but the key extraction + unwrap path was exercised. + # Verify via a spy that matches_trigger would have received the inner event. + + async def test_dispatch_drops_non_string_event_key(self, async_session_factory): + """If event_key_expr returns a non-string, the event should be dropped.""" + from openhands.automation.socket_manager import SocketManager + + sm = SocketManager(async_session_factory) + source = OutboundWebSocketSource( + id=uuid.uuid4(), + org_id=TEST_ORG_ID, + name="Test", + source="test-ws", + kind="generic", + event_key_expr="metadata", # returns a dict, not a string + status=WebSocketStatus.CONNECTED, + ) + + with patch( + "openhands.automation.socket_manager.get_event_automations", + new_callable=AsyncMock, + ) as mock_get: + await sm._dispatch( + source, + TEST_ORG_ID, + {"type": "message", "metadata": {"key": "val"}}, + None, + None, + ) + mock_get.assert_not_called() diff --git a/uv.lock b/uv.lock index 8cc2233..c4f8b8e 100644 --- a/uv.lock +++ b/uv.lock @@ -2181,6 +2181,7 @@ dependencies = [ { name = "sqlalchemy", extra = ["asyncio"] }, { name = "tenacity" }, { name = "uvicorn", extra = ["standard"] }, + { name = "websockets" }, ] [package.dev-dependencies] @@ -2219,6 +2220,7 @@ requires-dist = [ { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2" }, { name = "tenacity", specifier = ">=9.1.4" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.30" }, + { name = "websockets", specifier = ">=12" }, ] [package.metadata.requires-dev] From e4b810812a7ac03a21dc6e8111ce73ee950a0404 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 22:18:20 -0600 Subject: [PATCH 2/5] =?UTF-8?q?fix:=20address=20PR=20review=20=E2=80=94=20?= =?UTF-8?q?encrypt=20sensitive=20fields,=20fix=20validation=20gap?= MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves all flagged review comments on PR #134. ## Security: encrypt app_token and headers at rest (🟠 x2) Implements the same Fernet-based cipher pattern used by StaticSecret and LookupSecret in the OpenHands SDK (openhands.sdk.utils.cipher): utils/cipher.py Cipher class with SHA-256 key derivation from AUTOMATION_SECRET_KEY / OH_SECRET_KEY env var (checked in that order). is_ciphertext() uses the 'gAAAAA' Fernet token prefix to distinguish encrypted values from legacy plaintext rows — no data migration required when the key is introduced. Falls back to plaintext (with a one-time WARNING) when neither key is set. utils/encrypted_fields.py EncryptedString — SQLAlchemy TypeDecorator for single-string columns. Mirrors StaticSecret.value: encrypts on process_bind_param (write), decrypts on process_result_value (read). EncryptedJSONHeaders — TypeDecorator for dict[str, str] header columns. Mirrors LookupSecret._serialize_secrets/_validate_secrets: only encrypts header values whose key matches SECRET_KEY_PATTERNS (Authorization, Cookie, X-Api-Key, Token, etc.); non-sensitive headers are stored unencrypted. models.py app_token: String(255) -> EncryptedString(255) headers: JSON -> EncryptedJSONHeaders cryptography>=42 added as an explicit pyproject.toml dependency (was transitively available via openhands-sdk; now declared directly). ## Validation: fix update endpoint kind-specific guard (🔴) websocket_source_router.py After applying PATCH updates, validates that kind-required fields are still set: url for generic sources, app_token for slack sources. Returns 422 if either is cleared, preventing silent runtime failures in SocketManager. ## Error handling: expose CancelledError in DB helpers (🟠 x2) socket_manager.py _set_status and _record_event now re-raise asyncio.CancelledError before the broad except-Exception handler so that task cancellation (graceful shutdown) is never masked. 
logger.exception (ERROR + traceback) is preserved for all other DB write failures. Co-authored-by: openhands --- openhands/automation/models.py | 24 +- openhands/automation/socket_manager.py | 18 +- openhands/automation/utils/cipher.py | 103 ++++++ .../automation/utils/encrypted_fields.py | 158 +++++++++ .../automation/websocket_source_router.py | 15 + pyproject.toml | 1 + tests/test_cipher_and_encryption.py | 328 ++++++++++++++++++ uv.lock | 2 + 8 files changed, 638 insertions(+), 11 deletions(-) create mode 100644 openhands/automation/utils/cipher.py create mode 100644 openhands/automation/utils/encrypted_fields.py create mode 100644 tests/test_cipher_and_encryption.py diff --git a/openhands/automation/models.py b/openhands/automation/models.py index 61bbdaf..c6debb9 100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -19,6 +19,10 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from openhands.automation.utils import utcnow +from openhands.automation.utils.encrypted_fields import ( + EncryptedJSONHeaders, + EncryptedString, +) class WebSocketStatus(enum.Enum): @@ -402,21 +406,23 @@ class OutboundWebSocketSource(Base): # --- kind = "generic" fields --- - # Static wss:// URL to connect to. + # Static wss:// URL to connect to. Not encrypted (not a credential). url: Mapped[str | None] = mapped_column(Text, nullable=True) - # JSON object of HTTP headers to send on the WebSocket upgrade request. - # Values may contain ${SECRET_NAME} placeholders resolved at connect time - # from the automation service's secret store (future enhancement; stored - # verbatim for now — treat as sensitive and encrypt at rest in production). - headers: Mapped[dict | None] = mapped_column(JSON, nullable=True) + # HTTP headers for the WebSocket upgrade request. + # Sensitive header values (Authorization, X-Api-Key, Cookie, etc.) are + # encrypted at rest via EncryptedJSONHeaders using AUTOMATION_SECRET_KEY / + # OH_SECRET_KEY. 
Non-sensitive headers are stored as-is. + headers: Mapped[dict | None] = mapped_column(EncryptedJSONHeaders, nullable=True) # --- kind = "slack" fields --- # Slack App-Level Token (xapp-…). Required for Socket Mode. - # Used to call apps.connections.open to obtain a fresh wss:// URL on each - # connect attempt. Treat as sensitive; encrypt at rest in production. - app_token: Mapped[str | None] = mapped_column(String(255), nullable=True) + # Encrypted at rest via EncryptedString using AUTOMATION_SECRET_KEY / + # OH_SECRET_KEY. + app_token: Mapped[str | None] = mapped_column( + EncryptedString(255), nullable=True + ) # --- Runtime state (managed by SocketManager, not by the API) --- diff --git a/openhands/automation/socket_manager.py b/openhands/automation/socket_manager.py index 93da280..cfa39a0 100644 --- a/openhands/automation/socket_manager.py +++ b/openhands/automation/socket_manager.py @@ -157,6 +157,9 @@ async def _set_status( detail: str | None = None, connected_at: datetime | None = None, ) -> None: + # DB write failures must not terminate the WebSocket connection — + # we log at ERROR (with traceback via logger.exception) and continue. + # CancelledError is re-raised so task shutdown is never masked. try: async with self._session_factory() as session: source = await session.get(OutboundWebSocketSource, source_id) @@ -167,21 +170,32 @@ async def _set_status( if connected_at is not None: source.connected_at = connected_at await session.commit() + except asyncio.CancelledError: + raise except Exception: logger.exception( - "Failed to update status for source_id=%s", source_id + "Failed to update status source_id=%s status=%s — " + "DB write error; WebSocket connection continues.", + source_id, + status.value, ) async def _record_event(self, source_id: uuid.UUID) -> None: + # Non-fatal: a missed last_event_at timestamp is acceptable. + # CancelledError is re-raised so task shutdown is never masked. 
try: async with self._session_factory() as session: source = await session.get(OutboundWebSocketSource, source_id) if source: source.last_event_at = datetime.now(UTC) await session.commit() + except asyncio.CancelledError: + raise except Exception: logger.exception( - "Failed to update last_event_at for source_id=%s", source_id + "Failed to update last_event_at source_id=%s — " + "non-fatal, continuing.", + source_id, ) # ------------------------------------------------------------------ diff --git a/openhands/automation/utils/cipher.py b/openhands/automation/utils/cipher.py new file mode 100644 index 0000000..e98ef79 --- /dev/null +++ b/openhands/automation/utils/cipher.py @@ -0,0 +1,103 @@ +"""Application-layer encryption for sensitive fields. + +This module provides the same Fernet-based cipher used by the OpenHands SDK +(``openhands.sdk.utils.cipher``) — adapted for use inside the automation +service where a shared key encrypts credentials before they are written to the +database. + +Key derivation +-------------- +The raw ``AUTOMATION_SECRET_KEY`` (or ``OH_SECRET_KEY``) value is hashed with +SHA-256 and base64-encoded to produce a valid 256-bit Fernet key, so the env +var itself does not need to be exactly 32 bytes. + +Ciphertext detection +-------------------- +Fernet tokens always start with the prefix ``"gAAAAA"`` (the base-64 encoding +of the ``\\x80`` version byte). ``Cipher.is_ciphertext`` exploits this to +distinguish already-encrypted values from plaintext — important during a +rolling migration where old rows may still be plaintext. + +Safe defaults +------------- +If ``AUTOMATION_SECRET_KEY`` / ``OH_SECRET_KEY`` is absent, ``get_cipher()`` +returns ``None`` and the callers fall back to storing/reading plaintext. A +warning is logged once at import time so operators are not surprised. This +preserves backward-compat for deployments that haven't yet set the key. 
+""" + +import hashlib +import logging +import os +from base64 import b64encode + +from cryptography.fernet import Fernet, InvalidToken + + +logger = logging.getLogger("automation.utils.cipher") + +# Fernet tokens always start with this 6-char prefix (base64 of 0x80 version byte). +# Using a 6-char prefix avoids collisions with realistic base64 plaintext values. +FERNET_TOKEN_PREFIX = "gAAAAA" + +# Env-var names checked in order of preference +_KEY_ENV_VARS = ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY") + + +class Cipher: + """Symmetric Fernet cipher with SHA-256 key derivation. + + Matches the interface of ``openhands.sdk.utils.cipher.Cipher`` so that the + same encrypted blobs can be shared between the SDK and the automation + service if they share the same key. + """ + + def __init__(self, secret_key: str) -> None: + self._secret_key = secret_key + self._fernet: Fernet | None = None + + def encrypt(self, plaintext: str) -> str: + """Encrypt a string and return the Fernet token as a str.""" + return self._get_fernet().encrypt(plaintext.encode()).decode() + + def decrypt(self, ciphertext: str) -> str | None: + """Decrypt a Fernet token, returning None if decryption fails. + + A ``None`` return (instead of an exception) matches the SDK convention + and allows callers to fall back gracefully — e.g. when rows were + encrypted with a different key or were never encrypted. + """ + try: + return self._get_fernet().decrypt(ciphertext.encode()).decode() + except (InvalidToken, Exception) as exc: + logger.warning( + "Failed to decrypt value (returning None): %s. 
" + "This may occur when loading data encrypted with a different key " + "or when migrating from unencrypted storage.", + exc, + ) + return None + + def is_ciphertext(self, value: str) -> bool: + """Return True if *value* is a Fernet token (already encrypted).""" + return value.startswith(FERNET_TOKEN_PREFIX) + + def _get_fernet(self) -> Fernet: + if self._fernet is None: + key = b64encode(hashlib.sha256(self._secret_key.encode()).digest()) + self._fernet = Fernet(key) + return self._fernet + + +def get_cipher() -> Cipher | None: + """Return a ``Cipher`` built from the first available key env var, or None. + + Checks ``AUTOMATION_SECRET_KEY`` then ``OH_SECRET_KEY``. If neither is + set, returns ``None`` — callers should store/read values as plaintext and + log an appropriate warning. + """ + for env_var in _KEY_ENV_VARS: + key = os.getenv(env_var) + if key: + return Cipher(key) + return None diff --git a/openhands/automation/utils/encrypted_fields.py b/openhands/automation/utils/encrypted_fields.py new file mode 100644 index 0000000..6769859 --- /dev/null +++ b/openhands/automation/utils/encrypted_fields.py @@ -0,0 +1,158 @@ +"""SQLAlchemy TypeDecorators for application-layer encryption. + +Provides two column types that transparently encrypt/decrypt values using the +``Cipher`` from ``openhands.automation.utils.cipher``: + +``EncryptedString`` + For single string columns (e.g. ``app_token``). Stores Fernet ciphertext + when a key is configured; falls back to plaintext otherwise. + +``EncryptedJSONHeaders`` + For ``dict[str, str]`` header columns. Only encrypts values whose *key* + matches sensitive-header patterns (``Authorization``, ``Cookie``, + ``X-Api-Key``, etc.) to match the behaviour of ``LookupSecret._serialize_secrets`` + in the OpenHands SDK. Non-sensitive headers are stored unencrypted. 
+ +Migration safety +---------------- +Both decoders check ``Cipher.is_ciphertext`` before attempting decryption, so +existing plaintext rows continue to work after the key is introduced — no data +migration is required. Once a row is updated its value will be re-encrypted. + +If ``AUTOMATION_SECRET_KEY`` / ``OH_SECRET_KEY`` is absent a one-time WARNING +is emitted and values are stored as plaintext (preserving current behaviour). +""" + +import logging + +from sqlalchemy import String +from sqlalchemy.types import JSON, TypeDecorator + +from openhands.automation.utils.cipher import Cipher, get_cipher + + +logger = logging.getLogger("automation.utils.encrypted_fields") + +# Header-key patterns whose *values* are considered sensitive. +# Mirrors ``SECRET_KEY_PATTERNS`` from ``openhands.sdk.utils.redact``. +_SECRET_HEADER_PATTERNS = frozenset( + { + "AUTHORIZATION", + "COOKIE", + "CREDENTIAL", + "KEY", + "PASSWORD", + "SECRET", + "SESSION", + "TOKEN", + } +) + +_warned_no_cipher = False # emit the "no key configured" warning only once + + +def _warn_no_cipher(field: str) -> None: + global _warned_no_cipher + if not _warned_no_cipher: + logger.warning( + "AUTOMATION_SECRET_KEY / OH_SECRET_KEY is not set — sensitive " + "field '%s' (and others) will be stored as plaintext. " + "Set the key to enable encryption at rest.", + field, + ) + _warned_no_cipher = True + + +def _is_secret_header(key: str) -> bool: + """Return True if the header key name indicates a sensitive value.""" + upper = key.upper() + return any(pattern in upper for pattern in _SECRET_HEADER_PATTERNS) + + +class EncryptedString(TypeDecorator): + """A ``String`` column that is transparently encrypted/decrypted. + + Matches the per-field encryption pattern used by ``StaticSecret.value`` + in the OpenHands SDK: the Fernet cipher is applied on the way in and out + of the database; the ORM always works with plaintext strings. + + If no cipher key is configured the column behaves as a plain ``String``. 
+ """ + + impl = String + cache_ok = True + + def process_bind_param(self, value: str | None, dialect) -> str | None: + """Encrypt on the way TO the database.""" + if value is None: + return None + cipher: Cipher | None = get_cipher() + if cipher is None: + _warn_no_cipher(self.__class__.__name__) + return value + return cipher.encrypt(value) + + def process_result_value(self, value: str | None, dialect) -> str | None: + """Decrypt on the way FROM the database.""" + if value is None: + return None + cipher: Cipher | None = get_cipher() + if cipher is None: + return value # stored as plaintext (no key at write time) + if cipher.is_ciphertext(value): + return cipher.decrypt(value) + return value # plaintext row written before key was introduced + + +class EncryptedJSONHeaders(TypeDecorator): + """A ``JSON`` column storing ``dict[str, str]`` headers. + + Only the values of *sensitive* header keys (matching + ``_SECRET_HEADER_PATTERNS``) are encrypted, mirroring the behaviour of + ``LookupSecret._serialize_secrets`` / ``_validate_secrets`` in the + OpenHands SDK. Non-sensitive headers (e.g. ``Content-Type``) are stored + as-is to keep the stored document human-readable in non-critical cases. 
+ """ + + impl = JSON + cache_ok = True + + def process_bind_param( + self, value: dict | None, dialect + ) -> dict | None: + """Encrypt sensitive header values on the way TO the database.""" + if not value: + return value + cipher: Cipher | None = get_cipher() + if cipher is None: + _warn_no_cipher("headers") + return value + result: dict = {} + for k, v in value.items(): + if _is_secret_header(k) and isinstance(v, str) and v: + result[k] = cipher.encrypt(v) + else: + result[k] = v + return result + + def process_result_value( + self, value: dict | None, dialect + ) -> dict | None: + """Decrypt sensitive header values on the way FROM the database.""" + if not value: + return value + cipher: Cipher | None = get_cipher() + if cipher is None: + return value + result: dict = {} + for k, v in value.items(): + if ( + _is_secret_header(k) + and isinstance(v, str) + and cipher.is_ciphertext(v) + ): + decrypted = cipher.decrypt(v) + result[k] = decrypted if decrypted is not None else v + else: + result[k] = v + return result diff --git a/openhands/automation/websocket_source_router.py b/openhands/automation/websocket_source_router.py index c8c1227..ef38bb7 100644 --- a/openhands/automation/websocket_source_router.py +++ b/openhands/automation/websocket_source_router.py @@ -179,6 +179,21 @@ async def update_websocket_source( for field, value in update_data.items(): setattr(source, field, value) + # Validate kind-specific required fields remain set after applying updates. + # This guards against a caller clearing url on a generic source or + # app_token on a slack source, which would cause runtime failures in the + # SocketManager on the next connect attempt. 
+ if source.kind == "generic" and not source.url: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="url is required for generic WebSocket sources and cannot be cleared", + ) + if source.kind == "slack" and not source.app_token: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="app_token is required for Slack WebSocket sources and cannot be cleared", + ) + await session.commit() await session.refresh(source) diff --git a/pyproject.toml b/pyproject.toml index 9b03927..2d71327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "aiosqlite>=0.20", "alembic>=1.14", "asyncpg>=0.30", + "cryptography>=42", "boto3>=1.35", "cachetools>=7.0.5", "cloud-sql-python-connector[asyncpg]>=1.16", diff --git a/tests/test_cipher_and_encryption.py b/tests/test_cipher_and_encryption.py new file mode 100644 index 0000000..03fee4b --- /dev/null +++ b/tests/test_cipher_and_encryption.py @@ -0,0 +1,328 @@ +"""Tests for application-layer encryption: Cipher, EncryptedString, EncryptedJSONHeaders. + +All tests in this module are pure unit tests — no database or Docker required. 
+""" + +import os +from unittest.mock import patch + +import pytest + +from openhands.automation.utils.cipher import ( + FERNET_TOKEN_PREFIX, + Cipher, + get_cipher, +) +from openhands.automation.utils.encrypted_fields import ( + EncryptedJSONHeaders, + EncryptedString, + _is_secret_header, +) + + +TEST_KEY = "test-secret-key-for-automation-service" + + +# --------------------------------------------------------------------------- +# Cipher +# --------------------------------------------------------------------------- + + +class TestCipher: + def test_encrypt_decrypt_roundtrip(self): + cipher = Cipher(TEST_KEY) + plaintext = "xapp-1-AAAAAAAAA-1111111111-aaaaaaaaaaaaaaaa" + ciphertext = cipher.encrypt(plaintext) + assert cipher.decrypt(ciphertext) == plaintext + + def test_ciphertext_is_string(self): + cipher = Cipher(TEST_KEY) + ct = cipher.encrypt("hello") + assert isinstance(ct, str) + + def test_ciphertext_has_fernet_prefix(self): + cipher = Cipher(TEST_KEY) + ct = cipher.encrypt("hello") + assert cipher.is_ciphertext(ct) + assert ct.startswith(FERNET_TOKEN_PREFIX) + + def test_plaintext_is_not_ciphertext(self): + cipher = Cipher(TEST_KEY) + assert not cipher.is_ciphertext("xapp-1-AAAAAAAAA") + assert not cipher.is_ciphertext("Bearer token123") + assert not cipher.is_ciphertext("") + + def test_decrypt_invalid_returns_none(self): + cipher = Cipher(TEST_KEY) + result = cipher.decrypt("not-a-valid-fernet-token") + assert result is None + + def test_decrypt_wrong_key_returns_none(self): + cipher1 = Cipher("key-one") + cipher2 = Cipher("key-two") + ct = cipher1.encrypt("secret") + assert cipher2.decrypt(ct) is None + + def test_different_plaintexts_produce_different_ciphertexts(self): + cipher = Cipher(TEST_KEY) + ct1 = cipher.encrypt("secret-a") + ct2 = cipher.encrypt("secret-b") + assert ct1 != ct2 + + def test_same_plaintext_produces_different_ciphertexts(self): + # Fernet uses random nonces — two encryptions of the same plaintext differ + cipher = 
Cipher(TEST_KEY) + ct1 = cipher.encrypt("same") + ct2 = cipher.encrypt("same") + assert ct1 != ct2 + assert cipher.decrypt(ct1) == cipher.decrypt(ct2) == "same" + + +class TestGetCipher: + def test_returns_cipher_when_automation_key_set(self): + with patch.dict(os.environ, {"AUTOMATION_SECRET_KEY": TEST_KEY}, clear=False): + cipher = get_cipher() + assert cipher is not None + assert isinstance(cipher, Cipher) + + def test_falls_back_to_oh_secret_key(self): + env = {k: v for k, v in os.environ.items() if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY")} + env["OH_SECRET_KEY"] = TEST_KEY + with patch.dict(os.environ, env, clear=True): + cipher = get_cipher() + assert cipher is not None + + def test_returns_none_when_no_key_set(self): + env = {k: v for k, v in os.environ.items() if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY")} + with patch.dict(os.environ, env, clear=True): + cipher = get_cipher() + assert cipher is None + + def test_automation_key_takes_precedence_over_oh_key(self): + env = {k: v for k, v in os.environ.items()} + env["AUTOMATION_SECRET_KEY"] = "automation-key" + env["OH_SECRET_KEY"] = "oh-key" + with patch.dict(os.environ, env, clear=True): + cipher = get_cipher() + assert cipher is not None + # automation key should be used + ct = cipher.encrypt("value") + assert Cipher("automation-key").decrypt(ct) == "value" + + +# --------------------------------------------------------------------------- +# _is_secret_header +# --------------------------------------------------------------------------- + + +class TestIsSecretHeader: + @pytest.mark.parametrize( + "header", + [ + "Authorization", + "AUTHORIZATION", + "authorization", + "X-Api-Key", + "x-api-key", + "X-SECRET-HEADER", + "Cookie", + "cookie", + "password", + "X-Auth-Token", + "session-token", + ], + ) + def test_secret_headers_detected(self, header): + assert _is_secret_header(header) + + @pytest.mark.parametrize( + "header", + [ + "Content-Type", + "Accept", + "User-Agent", + 
"X-Request-ID", + "X-Correlation-ID", + "Cache-Control", + ], + ) + def test_non_secret_headers_not_detected(self, header): + assert not _is_secret_header(header) + + +# --------------------------------------------------------------------------- +# EncryptedString TypeDecorator +# --------------------------------------------------------------------------- + + +class TestEncryptedString: + def _make_col(self): + return EncryptedString(255) + + def test_encrypt_on_bind_param(self): + col = self._make_col() + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + cipher = Cipher(TEST_KEY) + mock.return_value = cipher + result = col.process_bind_param("xapp-token", dialect=None) + assert result is not None + assert cipher.is_ciphertext(result) + + def test_decrypt_on_result_value(self): + col = self._make_col() + cipher = Cipher(TEST_KEY) + ciphertext = cipher.encrypt("xapp-token") + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + result = col.process_result_value(ciphertext, dialect=None) + assert result == "xapp-token" + + def test_none_passthrough(self): + col = self._make_col() + assert col.process_bind_param(None, dialect=None) is None + assert col.process_result_value(None, dialect=None) is None + + def test_no_cipher_stores_plaintext(self): + col = self._make_col() + with patch("openhands.automation.utils.encrypted_fields.get_cipher", return_value=None): + result = col.process_bind_param("xapp-token", dialect=None) + assert result == "xapp-token" + + def test_no_cipher_reads_plaintext(self): + col = self._make_col() + with patch("openhands.automation.utils.encrypted_fields.get_cipher", return_value=None): + result = col.process_result_value("xapp-token", dialect=None) + assert result == "xapp-token" + + def test_read_plaintext_row_with_cipher_present(self): + """Plaintext rows written before key was set should still be readable.""" + col = self._make_col() + cipher = 
Cipher(TEST_KEY) + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + # Value does not have Fernet prefix → returned as-is + result = col.process_result_value("xapp-1-old-plaintext-token", dialect=None) + assert result == "xapp-1-old-plaintext-token" + + def test_roundtrip_bind_then_result(self): + col = self._make_col() + cipher = Cipher(TEST_KEY) + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + encrypted = col.process_bind_param("my-secret", dialect=None) + decrypted = col.process_result_value(encrypted, dialect=None) + assert decrypted == "my-secret" + + +# --------------------------------------------------------------------------- +# EncryptedJSONHeaders TypeDecorator +# --------------------------------------------------------------------------- + + +class TestEncryptedJSONHeaders: + def _make_col(self): + return EncryptedJSONHeaders() + + def test_encrypts_auth_header_only(self): + col = self._make_col() + cipher = Cipher(TEST_KEY) + headers = { + "Authorization": "Bearer secret-token", + "Content-Type": "application/json", + "X-Request-ID": "req-123", + } + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + stored = col.process_bind_param(headers, dialect=None) + + assert cipher.is_ciphertext(stored["Authorization"]) + assert stored["Content-Type"] == "application/json" + assert stored["X-Request-ID"] == "req-123" + + def test_decrypts_auth_header_only(self): + col = self._make_col() + cipher = Cipher(TEST_KEY) + headers = { + "Authorization": "Bearer secret-token", + "Content-Type": "application/json", + } + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + stored = col.process_bind_param(headers, dialect=None) + restored = col.process_result_value(stored, dialect=None) + + assert restored["Authorization"] == "Bearer 
secret-token" + assert restored["Content-Type"] == "application/json" + + def test_none_passthrough(self): + col = self._make_col() + assert col.process_bind_param(None, dialect=None) is None + assert col.process_result_value(None, dialect=None) is None + + def test_empty_dict_passthrough(self): + col = self._make_col() + assert col.process_bind_param({}, dialect=None) == {} + + def test_no_cipher_stores_plaintext(self): + col = self._make_col() + headers = {"Authorization": "Bearer token", "X-Api-Key": "key123"} + with patch("openhands.automation.utils.encrypted_fields.get_cipher", return_value=None): + result = col.process_bind_param(headers, dialect=None) + assert result == headers + + def test_plaintext_headers_readable_after_key_introduced(self): + """Rows stored before key was set should read back cleanly.""" + col = self._make_col() + cipher = Cipher(TEST_KEY) + plaintext_stored = {"Authorization": "Bearer old-token"} + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + result = col.process_result_value(plaintext_stored, dialect=None) + # Value doesn't have Fernet prefix → returned as-is + assert result["Authorization"] == "Bearer old-token" + + def test_multiple_secret_headers_encrypted(self): + col = self._make_col() + cipher = Cipher(TEST_KEY) + headers = { + "Authorization": "Bearer tok", + "X-Api-Key": "apikey123", + "Cookie": "session=abc", + "Accept": "application/json", + } + with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: + mock.return_value = cipher + stored = col.process_bind_param(headers, dialect=None) + + assert cipher.is_ciphertext(stored["Authorization"]) + assert cipher.is_ciphertext(stored["X-Api-Key"]) + assert cipher.is_ciphertext(stored["Cookie"]) + assert stored["Accept"] == "application/json" + + +# --------------------------------------------------------------------------- +# Update endpoint kind-specific validation +# 
--------------------------------------------------------------------------- + + +class TestUpdateEndpointKindValidation: + """Schema-level tests for the update endpoint's kind constraint.""" + + def test_url_field_present_in_update_schema(self): + from openhands.automation.schemas import WebSocketSourceUpdate + + u = WebSocketSourceUpdate(url="wss://new.example.com") + assert u.url == "wss://new.example.com" + + def test_app_token_field_present_in_update_schema(self): + from openhands.automation.schemas import WebSocketSourceUpdate + + u = WebSocketSourceUpdate(app_token="xapp-1-NEW") + assert u.app_token == "xapp-1-NEW" + + def test_url_can_be_set_to_none_in_schema(self): + """Schema accepts None (validation is enforced at the router layer).""" + from openhands.automation.schemas import WebSocketSourceUpdate + + u = WebSocketSourceUpdate(url=None) + assert u.url is None diff --git a/uv.lock b/uv.lock index c4f8b8e..c4fddaa 100644 --- a/uv.lock +++ b/uv.lock @@ -2168,6 +2168,7 @@ dependencies = [ { name = "cachetools" }, { name = "cloud-sql-python-connector", extra = ["asyncpg"] }, { name = "croniter" }, + { name = "cryptography" }, { name = "fastapi" }, { name = "google-cloud-storage" }, { name = "httpx" }, @@ -2207,6 +2208,7 @@ requires-dist = [ { name = "cachetools", specifier = ">=7.0.5" }, { name = "cloud-sql-python-connector", extras = ["asyncpg"], specifier = ">=1.16" }, { name = "croniter", specifier = ">=2" }, + { name = "cryptography", specifier = ">=42" }, { name = "fastapi", specifier = ">=0.115" }, { name = "google-cloud-storage", specifier = ">=2.18" }, { name = "httpx", specifier = ">=0.27" }, From 3c38d49acd6e32aec295088a435e8177e82355a6 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 22:25:17 -0600 Subject: [PATCH 3/5] fix: resolve ruff lint/format failures from CI Fixes all pre-commit failures caught by ci/backend: - E501: shorten error-message strings in websocket_source_router.py - F841: remove unused captured_payload variable in 
test - ARG002: remove unused ack_ws/ack_id params from _dispatch (and all call sites); prefix dialect params with _ in TypeDecorators - ruff format: reformat app.py, models.py, encrypted_fields.py, test_cipher_and_encryption.py, test_websocket_source_router.py Co-authored-by: openhands --- openhands/automation/app.py | 4 +- openhands/automation/models.py | 8 +-- openhands/automation/socket_manager.py | 40 +++++------- .../automation/utils/encrypted_fields.py | 18 ++--- .../automation/websocket_source_router.py | 4 +- tests/test_cipher_and_encryption.py | 65 ++++++++++++------- tests/test_websocket_source_router.py | 37 ++++------- 7 files changed, 80 insertions(+), 96 deletions(-) diff --git a/openhands/automation/app.py b/openhands/automation/app.py index 1c2e418..6ac0272 100644 --- a/openhands/automation/app.py +++ b/openhands/automation/app.py @@ -30,7 +30,9 @@ from openhands.automation.uploads import router as uploads_router from openhands.automation.watchdog import watchdog_loop from openhands.automation.webhook_router import router as webhook_router -from openhands.automation.websocket_source_router import router as websocket_source_router +from openhands.automation.websocket_source_router import ( + router as websocket_source_router, +) logger = logging.getLogger("automation.app") diff --git a/openhands/automation/models.py b/openhands/automation/models.py index c6debb9..6d8123c 100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -420,9 +420,7 @@ class OutboundWebSocketSource(Base): # Slack App-Level Token (xapp-…). Required for Socket Mode. # Encrypted at rest via EncryptedString using AUTOMATION_SECRET_KEY / # OH_SECRET_KEY. 
- app_token: Mapped[str | None] = mapped_column( - EncryptedString(255), nullable=True - ) + app_token: Mapped[str | None] = mapped_column(EncryptedString(255), nullable=True) # --- Runtime state (managed by SocketManager, not by the API) --- @@ -452,7 +450,5 @@ class OutboundWebSocketSource(Base): ) __table_args__ = ( - Index( - "ix_outbound_ws_sources_org_source", "org_id", "source", unique=True - ), + Index("ix_outbound_ws_sources_org_source", "org_id", "source", unique=True), ) diff --git a/openhands/automation/socket_manager.py b/openhands/automation/socket_manager.py index cfa39a0..fd6bf9e 100644 --- a/openhands/automation/socket_manager.py +++ b/openhands/automation/socket_manager.py @@ -43,7 +43,10 @@ from openhands.automation.filter_eval import evaluate_filter from openhands.automation.models import OutboundWebSocketSource, WebSocketStatus from openhands.automation.trigger_matcher import matches_trigger -from openhands.automation.utils.webhook import create_automation_run, get_event_automations +from openhands.automation.utils.webhook import ( + create_automation_run, + get_event_automations, +) logger = logging.getLogger("automation.socket_manager") @@ -121,9 +124,7 @@ async def on_source_deleted(self, source_id: uuid.UUID) -> None: # Internal helpers # ------------------------------------------------------------------ - def _start_task( - self, source_id: uuid.UUID, org_id: uuid.UUID, kind: str - ) -> None: + def _start_task(self, source_id: uuid.UUID, org_id: uuid.UUID, kind: str) -> None: task = asyncio.create_task( self._run_source_loop(source_id, org_id, kind), name=f"ws-source-{source_id}", @@ -193,8 +194,7 @@ async def _record_event(self, source_id: uuid.UUID) -> None: raise except Exception: logger.exception( - "Failed to update last_event_at source_id=%s — " - "non-fatal, continuing.", + "Failed to update last_event_at source_id=%s — non-fatal, continuing.", source_id, ) @@ -231,8 +231,7 @@ async def _run_source_loop( if failures >= 
_MAX_FAILURES: logger.error( - "WebSocket source exceeded max failures, pausing " - "source_id=%s", + "WebSocket source exceeded max failures, pausing source_id=%s", source_id, ) await self._set_status( @@ -297,11 +296,13 @@ async def _receive_generic( try: msg = json.loads(raw) except json.JSONDecodeError: - logger.debug("Non-JSON message from source_id=%s, skipping", source.id) + logger.debug( + "Non-JSON message from source_id=%s, skipping", source.id + ) continue await self._record_event(source.id) - await self._dispatch(source, org_id, msg, ack_ws=None, ack_id=None) + await self._dispatch(source, org_id, msg) async def _receive_slack( self, source: OutboundWebSocketSource, org_id: uuid.UUID @@ -316,24 +317,18 @@ async def _receive_slack( """ app_token = source.app_token if not app_token: - raise ValueError( - f"Slack source {source.id} has no app_token configured" - ) + raise ValueError(f"Slack source {source.id} has no app_token configured") # Fetch a fresh connection URL from Slack wss_url = await _slack_open_connection(app_token) - logger.info( - "Obtained Slack Socket Mode URL for source_id=%s", source.id - ) + logger.info("Obtained Slack Socket Mode URL for source_id=%s", source.id) async with websockets.connect(wss_url) as ws: # Slack sends a hello message immediately on connect hello_raw = await ws.recv() hello = json.loads(hello_raw) if hello.get("type") != "hello": - raise RuntimeError( - f"Expected Slack hello, got: {hello.get('type')!r}" - ) + raise RuntimeError(f"Expected Slack hello, got: {hello.get('type')!r}") await self._set_status( source.id, @@ -369,7 +364,7 @@ async def _receive_slack( continue await self._record_event(source.id) - await self._dispatch(source, org_id, msg, ack_ws=None, ack_id=None) + await self._dispatch(source, org_id, msg) # ------------------------------------------------------------------ # Dispatch pipeline (shared by all kinds) @@ -380,8 +375,6 @@ async def _dispatch( source: OutboundWebSocketSource, org_id: uuid.UUID, 
raw_msg: dict[str, Any], - ack_ws: Any, # reserved for future per-message ack callbacks - ack_id: str | None, ) -> None: """Apply pre-filter, extract event key, unwrap, find automations, run. @@ -406,7 +399,8 @@ async def _dispatch( event_key = jmespath.search(source.event_key_expr, raw_msg) except Exception: logger.debug( - "event_key_expr failed for source_id=%s, dropping", source.id, + "event_key_expr failed for source_id=%s, dropping", + source.id, exc_info=True, ) return diff --git a/openhands/automation/utils/encrypted_fields.py b/openhands/automation/utils/encrypted_fields.py index 6769859..3ba152d 100644 --- a/openhands/automation/utils/encrypted_fields.py +++ b/openhands/automation/utils/encrypted_fields.py @@ -82,7 +82,7 @@ class EncryptedString(TypeDecorator): impl = String cache_ok = True - def process_bind_param(self, value: str | None, dialect) -> str | None: + def process_bind_param(self, value: str | None, _dialect) -> str | None: """Encrypt on the way TO the database.""" if value is None: return None @@ -92,7 +92,7 @@ def process_bind_param(self, value: str | None, dialect) -> str | None: return value return cipher.encrypt(value) - def process_result_value(self, value: str | None, dialect) -> str | None: + def process_result_value(self, value: str | None, _dialect) -> str | None: """Decrypt on the way FROM the database.""" if value is None: return None @@ -117,9 +117,7 @@ class EncryptedJSONHeaders(TypeDecorator): impl = JSON cache_ok = True - def process_bind_param( - self, value: dict | None, dialect - ) -> dict | None: + def process_bind_param(self, value: dict | None, _dialect) -> dict | None: """Encrypt sensitive header values on the way TO the database.""" if not value: return value @@ -135,9 +133,7 @@ def process_bind_param( result[k] = v return result - def process_result_value( - self, value: dict | None, dialect - ) -> dict | None: + def process_result_value(self, value: dict | None, _dialect) -> dict | None: """Decrypt sensitive header 
values on the way FROM the database.""" if not value: return value @@ -146,11 +142,7 @@ def process_result_value( return value result: dict = {} for k, v in value.items(): - if ( - _is_secret_header(k) - and isinstance(v, str) - and cipher.is_ciphertext(v) - ): + if _is_secret_header(k) and isinstance(v, str) and cipher.is_ciphertext(v): decrypted = cipher.decrypt(v) result[k] = decrypted if decrypted is not None else v else: diff --git a/openhands/automation/websocket_source_router.py b/openhands/automation/websocket_source_router.py index ef38bb7..91f45e7 100644 --- a/openhands/automation/websocket_source_router.py +++ b/openhands/automation/websocket_source_router.py @@ -186,12 +186,12 @@ async def update_websocket_source( if source.kind == "generic" and not source.url: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="url is required for generic WebSocket sources and cannot be cleared", + detail="url is required for generic sources and cannot be cleared", ) if source.kind == "slack" and not source.app_token: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail="app_token is required for Slack WebSocket sources and cannot be cleared", + detail="app_token is required for slack sources and cannot be cleared", ) await session.commit() diff --git a/tests/test_cipher_and_encryption.py b/tests/test_cipher_and_encryption.py index 03fee4b..ed57bd6 100644 --- a/tests/test_cipher_and_encryption.py +++ b/tests/test_cipher_and_encryption.py @@ -1,6 +1,7 @@ -"""Tests for application-layer encryption: Cipher, EncryptedString, EncryptedJSONHeaders. +"""Tests for application-layer encryption. -All tests in this module are pure unit tests — no database or Docker required. +Covers Cipher, EncryptedString, and EncryptedJSONHeaders. +Pure unit tests — no database or Docker required. 
""" import os @@ -86,14 +87,22 @@ def test_returns_cipher_when_automation_key_set(self): assert isinstance(cipher, Cipher) def test_falls_back_to_oh_secret_key(self): - env = {k: v for k, v in os.environ.items() if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY")} + env = { + k: v + for k, v in os.environ.items() + if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY") + } env["OH_SECRET_KEY"] = TEST_KEY with patch.dict(os.environ, env, clear=True): cipher = get_cipher() assert cipher is not None def test_returns_none_when_no_key_set(self): - env = {k: v for k, v in os.environ.items() if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY")} + env = { + k: v + for k, v in os.environ.items() + if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY") + } with patch.dict(os.environ, env, clear=True): cipher = get_cipher() assert cipher is None @@ -164,7 +173,7 @@ def test_encrypt_on_bind_param(self): with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: cipher = Cipher(TEST_KEY) mock.return_value = cipher - result = col.process_bind_param("xapp-token", dialect=None) + result = col.process_bind_param("xapp-token", None) assert result is not None assert cipher.is_ciphertext(result) @@ -174,24 +183,28 @@ def test_decrypt_on_result_value(self): ciphertext = cipher.encrypt("xapp-token") with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - result = col.process_result_value(ciphertext, dialect=None) + result = col.process_result_value(ciphertext, None) assert result == "xapp-token" def test_none_passthrough(self): col = self._make_col() - assert col.process_bind_param(None, dialect=None) is None - assert col.process_result_value(None, dialect=None) is None + assert col.process_bind_param(None, None) is None + assert col.process_result_value(None, None) is None def test_no_cipher_stores_plaintext(self): col = self._make_col() - with patch("openhands.automation.utils.encrypted_fields.get_cipher", 
return_value=None): - result = col.process_bind_param("xapp-token", dialect=None) + with patch( + "openhands.automation.utils.encrypted_fields.get_cipher", return_value=None + ): + result = col.process_bind_param("xapp-token", None) assert result == "xapp-token" def test_no_cipher_reads_plaintext(self): col = self._make_col() - with patch("openhands.automation.utils.encrypted_fields.get_cipher", return_value=None): - result = col.process_result_value("xapp-token", dialect=None) + with patch( + "openhands.automation.utils.encrypted_fields.get_cipher", return_value=None + ): + result = col.process_result_value("xapp-token", None) assert result == "xapp-token" def test_read_plaintext_row_with_cipher_present(self): @@ -201,7 +214,7 @@ def test_read_plaintext_row_with_cipher_present(self): with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher # Value does not have Fernet prefix → returned as-is - result = col.process_result_value("xapp-1-old-plaintext-token", dialect=None) + result = col.process_result_value("xapp-1-old-plaintext-token", None) assert result == "xapp-1-old-plaintext-token" def test_roundtrip_bind_then_result(self): @@ -209,8 +222,8 @@ def test_roundtrip_bind_then_result(self): cipher = Cipher(TEST_KEY) with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - encrypted = col.process_bind_param("my-secret", dialect=None) - decrypted = col.process_result_value(encrypted, dialect=None) + encrypted = col.process_bind_param("my-secret", None) + decrypted = col.process_result_value(encrypted, None) assert decrypted == "my-secret" @@ -233,7 +246,7 @@ def test_encrypts_auth_header_only(self): } with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - stored = col.process_bind_param(headers, dialect=None) + stored = col.process_bind_param(headers, None) assert cipher.is_ciphertext(stored["Authorization"]) assert 
stored["Content-Type"] == "application/json" @@ -248,26 +261,28 @@ def test_decrypts_auth_header_only(self): } with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - stored = col.process_bind_param(headers, dialect=None) - restored = col.process_result_value(stored, dialect=None) + stored = col.process_bind_param(headers, None) + restored = col.process_result_value(stored, None) assert restored["Authorization"] == "Bearer secret-token" assert restored["Content-Type"] == "application/json" def test_none_passthrough(self): col = self._make_col() - assert col.process_bind_param(None, dialect=None) is None - assert col.process_result_value(None, dialect=None) is None + assert col.process_bind_param(None, None) is None + assert col.process_result_value(None, None) is None def test_empty_dict_passthrough(self): col = self._make_col() - assert col.process_bind_param({}, dialect=None) == {} + assert col.process_bind_param({}, None) == {} def test_no_cipher_stores_plaintext(self): col = self._make_col() headers = {"Authorization": "Bearer token", "X-Api-Key": "key123"} - with patch("openhands.automation.utils.encrypted_fields.get_cipher", return_value=None): - result = col.process_bind_param(headers, dialect=None) + with patch( + "openhands.automation.utils.encrypted_fields.get_cipher", return_value=None + ): + result = col.process_bind_param(headers, None) assert result == headers def test_plaintext_headers_readable_after_key_introduced(self): @@ -277,7 +292,7 @@ def test_plaintext_headers_readable_after_key_introduced(self): plaintext_stored = {"Authorization": "Bearer old-token"} with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - result = col.process_result_value(plaintext_stored, dialect=None) + result = col.process_result_value(plaintext_stored, None) # Value doesn't have Fernet prefix → returned as-is assert result["Authorization"] == "Bearer old-token" @@ 
-292,7 +307,7 @@ def test_multiple_secret_headers_encrypted(self): } with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher - stored = col.process_bind_param(headers, dialect=None) + stored = col.process_bind_param(headers, None) assert cipher.is_ciphertext(stored["Authorization"]) assert cipher.is_ciphertext(stored["X-Api-Key"]) diff --git a/tests/test_websocket_source_router.py b/tests/test_websocket_source_router.py index d136b72..2579963 100644 --- a/tests/test_websocket_source_router.py +++ b/tests/test_websocket_source_router.py @@ -203,9 +203,7 @@ async def test_create_duplicate_source_returns_409( assert resp2.status_code == 409 assert "already exists" in resp2.json()["detail"] - async def test_create_notifies_socket_manager( - self, async_client, async_session - ): + async def test_create_notifies_socket_manager(self, async_client, async_session): """When a socket manager is present, it should be notified on create.""" from openhands.automation.app import app @@ -258,9 +256,7 @@ async def test_missing_kind_returns_422(self, async_client, _no_socket_manager): ) assert resp.status_code == 422 - async def test_slack_bad_token_returns_422( - self, async_client, _no_socket_manager - ): + async def test_slack_bad_token_returns_422(self, async_client, _no_socket_manager): resp = await async_client.post( BASE_URL, json={ @@ -355,9 +351,7 @@ async def test_get_nonexistent_returns_404(self, async_client): class TestUpdateWebSocketSource: - async def test_update_name_and_enabled( - self, async_client, _no_socket_manager - ): + async def test_update_name_and_enabled(self, async_client, _no_socket_manager): create = await async_client.post( BASE_URL, json={ @@ -427,7 +421,9 @@ async def test_delete_success(self, async_client, _no_socket_manager): get_resp = await async_client.get(f"{BASE_URL}/{source_id}") assert get_resp.status_code == 404 - async def test_delete_notifies_socket_manager(self, async_client, 
_no_socket_manager): + async def test_delete_notifies_socket_manager( + self, async_client, _no_socket_manager + ): from openhands.automation.app import app mock_sm = AsyncMock() @@ -447,9 +443,7 @@ async def test_delete_notifies_socket_manager(self, async_client, _no_socket_man mock_sm.reset_mock() await async_client.delete(f"{BASE_URL}/{source_id}") - mock_sm.on_source_deleted.assert_awaited_once_with( - uuid.UUID(source_id) - ) + mock_sm.on_source_deleted.assert_awaited_once_with(uuid.UUID(source_id)) finally: del app.state.socket_manager @@ -508,9 +502,7 @@ async def test_reconnect_disabled_source_returns_400( class TestSocketManagerDispatch: """Unit tests for the dispatch pipeline without real WebSocket connections.""" - async def test_dispatch_pre_filter_drops_non_matching( - self, async_session_factory - ): + async def test_dispatch_pre_filter_drops_non_matching(self, async_session_factory): """Events that fail filter_expr should be silently dropped.""" from openhands.automation.socket_manager import SocketManager @@ -532,9 +524,7 @@ async def test_dispatch_pre_filter_drops_non_matching( "openhands.automation.socket_manager.get_event_automations", new_callable=AsyncMock, ) as mock_get: - await sm._dispatch( - source, TEST_ORG_ID, {"type": "blocked"}, None, None - ) + await sm._dispatch(source, TEST_ORG_ID, {"type": "blocked"}) mock_get.assert_not_called() async def test_dispatch_unwraps_payload_via_payload_expr( @@ -568,8 +558,6 @@ async def test_dispatch_unwraps_payload_via_payload_expr( }, } - captured_payload = None - async def fake_get_event_automations(org_id, source_name, session): return [] @@ -577,9 +565,8 @@ async def fake_get_event_automations(org_id, source_name, session): "openhands.automation.socket_manager.get_event_automations", side_effect=fake_get_event_automations, ): - await sm._dispatch(source, TEST_ORG_ID, raw_msg, None, None) - # No match, but the key extraction + unwrap path was exercised. 
- # Verify via a spy that matches_trigger would have received the inner event. + await sm._dispatch(source, TEST_ORG_ID, raw_msg) + # No automations match, but the key extraction + unwrap path ran. async def test_dispatch_drops_non_string_event_key(self, async_session_factory): """If event_key_expr returns a non-string, the event should be dropped.""" @@ -604,7 +591,5 @@ async def test_dispatch_drops_non_string_event_key(self, async_session_factory): source, TEST_ORG_ID, {"type": "message", "metadata": {"key": "val"}}, - None, - None, ) mock_get.assert_not_called() From 55d1e8392945a5eba15f03539176566cada44420 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 22:32:13 -0600 Subject: [PATCH 4/5] fix: resolve remaining pre-commit failures (pyright + ruff) Three distinct issues caught by running pre-commit locally: ruff lint (import order) migrations/007: Boolean before JSON - ruff isort wanted JSON first. Applied auto-fix. pyright reportIncompatibleMethodOverride encrypted_fields.py: renaming the TypeDecorator dialect parameter to _dialect (to silence ruff ARG002) changed the parameter name relative to the SQLAlchemy base class, which pyright flags as an incompatible override in standard mode. Fix: restore the name to 'dialect' and annotate it as 'Any' instead. 'Any' is bidirectionally compatible with 'Dialect' in pyright, so the override check passes; noqa: ARG002 silences ruff. pyright reportOptionalSubscript (x10 in tests) process_bind_param/process_result_value return dict | None; tests were subscripting the result without a None guard, which pyright flags. Fix: added 'assert stored/restored/result is not None' before each subscript in the affected test cases. 
Co-authored-by: openhands --- migrations/versions/007_outbound_websocket_sources.py | 2 +- openhands/automation/utils/encrypted_fields.py | 9 +++++---- tests/test_cipher_and_encryption.py | 5 +++++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/migrations/versions/007_outbound_websocket_sources.py b/migrations/versions/007_outbound_websocket_sources.py index 50730e6..f9c2f4a 100644 --- a/migrations/versions/007_outbound_websocket_sources.py +++ b/migrations/versions/007_outbound_websocket_sources.py @@ -16,7 +16,7 @@ from collections.abc import Sequence from alembic import op -from sqlalchemy import Boolean, Column, DateTime, Enum, JSON, String, Text, Uuid, text +from sqlalchemy import JSON, Boolean, Column, DateTime, Enum, String, Text, Uuid, text revision: str = "007" diff --git a/openhands/automation/utils/encrypted_fields.py b/openhands/automation/utils/encrypted_fields.py index 3ba152d..5a40811 100644 --- a/openhands/automation/utils/encrypted_fields.py +++ b/openhands/automation/utils/encrypted_fields.py @@ -24,6 +24,7 @@ """ import logging +from typing import Any from sqlalchemy import String from sqlalchemy.types import JSON, TypeDecorator @@ -82,7 +83,7 @@ class EncryptedString(TypeDecorator): impl = String cache_ok = True - def process_bind_param(self, value: str | None, _dialect) -> str | None: + def process_bind_param(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 """Encrypt on the way TO the database.""" if value is None: return None @@ -92,7 +93,7 @@ def process_bind_param(self, value: str | None, _dialect) -> str | None: return value return cipher.encrypt(value) - def process_result_value(self, value: str | None, _dialect) -> str | None: + def process_result_value(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 """Decrypt on the way FROM the database.""" if value is None: return None @@ -117,7 +118,7 @@ class EncryptedJSONHeaders(TypeDecorator): impl = JSON cache_ok = True - def 
process_bind_param(self, value: dict | None, _dialect) -> dict | None: + def process_bind_param(self, value: dict | None, dialect: Any) -> dict | None: # noqa: ARG002 """Encrypt sensitive header values on the way TO the database.""" if not value: return value @@ -133,7 +134,7 @@ def process_bind_param(self, value: dict | None, _dialect) -> dict | None: result[k] = v return result - def process_result_value(self, value: dict | None, _dialect) -> dict | None: + def process_result_value(self, value: dict | None, dialect: Any) -> dict | None: # noqa: ARG002 """Decrypt sensitive header values on the way FROM the database.""" if not value: return value diff --git a/tests/test_cipher_and_encryption.py b/tests/test_cipher_and_encryption.py index ed57bd6..b9d72cd 100644 --- a/tests/test_cipher_and_encryption.py +++ b/tests/test_cipher_and_encryption.py @@ -248,6 +248,7 @@ def test_encrypts_auth_header_only(self): mock.return_value = cipher stored = col.process_bind_param(headers, None) + assert stored is not None assert cipher.is_ciphertext(stored["Authorization"]) assert stored["Content-Type"] == "application/json" assert stored["X-Request-ID"] == "req-123" @@ -262,8 +263,10 @@ def test_decrypts_auth_header_only(self): with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher stored = col.process_bind_param(headers, None) + assert stored is not None restored = col.process_result_value(stored, None) + assert restored is not None assert restored["Authorization"] == "Bearer secret-token" assert restored["Content-Type"] == "application/json" @@ -294,6 +297,7 @@ def test_plaintext_headers_readable_after_key_introduced(self): mock.return_value = cipher result = col.process_result_value(plaintext_stored, None) # Value doesn't have Fernet prefix → returned as-is + assert result is not None assert result["Authorization"] == "Bearer old-token" def test_multiple_secret_headers_encrypted(self): @@ -309,6 +313,7 @@ def 
test_multiple_secret_headers_encrypted(self): mock.return_value = cipher stored = col.process_bind_param(headers, None) + assert stored is not None assert cipher.is_ciphertext(stored["Authorization"]) assert cipher.is_ciphertext(stored["X-Api-Key"]) assert cipher.is_ciphertext(stored["Cookie"]) From 1e442d345adb34dbaca231af50732efd9896664a Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 22 May 2026 06:04:23 -0600 Subject: [PATCH 5/5] refactor: use SDK Cipher + DiscriminatedUnionMixin per review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes requested in code review: ## 1. Use openhands.sdk.utils.cipher.Cipher (not a local copy) - Delete utils/cipher.py — the automation service was duplicating the SDK's Cipher class verbatim. The SDK is already a declared dependency so there is no reason to maintain a second copy. - utils/encrypted_fields.py: import Cipher and FERNET_TOKEN_PREFIX from openhands.sdk.utils.cipher. Adapt callers to the SDK interface: · cipher.encrypt(SecretStr(value)) instead of cipher.encrypt(value) · cipher.try_decrypt_str(v) instead of cipher.decrypt(v) · v.startswith(FERNET_TOKEN_PREFIX) instead of cipher.is_ciphertext(v) - Remove cryptography>=42 from pyproject.toml — the SDK already brings this in transitively. ## 2. Use DiscriminatedUnionMixin instead of a hand-rolled discriminated union Following the SecretSource / StaticSecret / LookupSecret pattern: schemas.py - _WebSocketSourceBase(BaseModel) → WebSocketSourceCreate(DiscriminatedUnionMixin, ABC) The mixin provides the computed 'kind' field (= class name) and the model_validator that dispatches to the correct subclass — no need for explicit 'kind: Literal[...]' fields or an Annotated union type alias. - GenericWebSocketSourceCreate → GenericWebSocketSource (kind = class name) - SlackWebSocketSourceCreate → SlackWebSocketSource (kind = class name) - Remove WebSocketSourceCreate = Annotated[...] 
union alias (not needed) - Move event_key_expr / payload_expr to the abstract base class (common to both kinds); SlackWebSocketSource overrides their defaults. Kind values change from lowercase slugs to class names: generic → GenericWebSocketSource slack → SlackWebSocketSource Updated everywhere: router, socket_manager, models comments, tests. Co-authored-by: openhands --- openhands/automation/models.py | 16 +-- openhands/automation/schemas.py | 69 +++++------- openhands/automation/socket_manager.py | 2 +- openhands/automation/utils/cipher.py | 103 ------------------ .../automation/utils/encrypted_fields.py | 53 ++++++--- .../automation/websocket_source_router.py | 4 +- pyproject.toml | 1 - tests/test_cipher_and_encryption.py | 99 +++++------------ tests/test_websocket_source_router.py | 78 +++++++------ uv.lock | 2 - 10 files changed, 139 insertions(+), 288 deletions(-) delete mode 100644 openhands/automation/utils/cipher.py diff --git a/openhands/automation/models.py b/openhands/automation/models.py index 6d8123c..42679d9 100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -346,11 +346,11 @@ class OutboundWebSocketSource(Base): Two kinds are supported, selected via the ``kind`` discriminator column: - ``"generic"`` + ``"GenericWebSocketSource"`` Connects to a static ``wss://`` URL with optional HTTP headers. Suitable for any service that exposes a plain WebSocket endpoint. - ``"slack"`` + ``"SlackWebSocketSource"`` Connects to Slack's Socket Mode API. Requires a Slack App-Level Token (``xapp-…``). The connection URL is fetched dynamically by calling ``apps.connections.open`` before each connect attempt; Slack-specific @@ -381,7 +381,7 @@ class OutboundWebSocketSource(Base): # Must be unique per org (enforced by the unique index below). 
source: Mapped[str] = mapped_column(String(100), nullable=False) - # Discriminator: "generic" or "slack" + # Discriminator: "GenericWebSocketSource" or "SlackWebSocketSource" kind: Mapped[str] = mapped_column(String(50), nullable=False) enabled: Mapped[bool] = mapped_column(default=True, nullable=False) @@ -390,13 +390,13 @@ class OutboundWebSocketSource(Base): # Extracts the event-type key used for trigger.on pattern matching. # Defaults differ per kind and are set at the schema layer: - # generic → "type" - # slack → "payload.event.type" + # GenericWebSocketSource → "type" + # SlackWebSocketSource → "payload.event.type" event_key_expr: Mapped[str] = mapped_column(String(500), nullable=False) # Unwraps outer envelopes so the stored/filtered payload is the inner event. # None means pass the raw message through unchanged. - # slack default → "payload.event" + # SlackWebSocketSource default → "payload.event" payload_expr: Mapped[str | None] = mapped_column(String(500), nullable=True) # Connection-level pre-filter (JMESPath). Evaluated against the raw message @@ -404,7 +404,7 @@ class OutboundWebSocketSource(Base): # None means accept all events. filter_expr: Mapped[str | None] = mapped_column(Text, nullable=True) - # --- kind = "generic" fields --- + # --- kind = "GenericWebSocketSource" fields --- # Static wss:// URL to connect to. Not encrypted (not a credential). url: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -415,7 +415,7 @@ class OutboundWebSocketSource(Base): # OH_SECRET_KEY. Non-sensitive headers are stored as-is. headers: Mapped[dict | None] = mapped_column(EncryptedJSONHeaders, nullable=True) - # --- kind = "slack" fields --- + # --- kind = "SlackWebSocketSource" fields --- # Slack App-Level Token (xapp-…). Required for Socket Mode. 
# Encrypted at rest via EncryptedString using AUTOMATION_SECRET_KEY / diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index 3e96949..ef41a04 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -2,6 +2,7 @@ import re import uuid +from abc import ABC from datetime import datetime from enum import StrEnum from typing import Annotated, Literal @@ -11,6 +12,7 @@ from openhands.automation.config import get_config from openhands.automation.constants import MODEL_PROFILE_PATTERN +from openhands.sdk.utils.models import DiscriminatedUnionMixin # Allowed URI schemes for tarball_path (includes internal upload scheme) @@ -587,8 +589,13 @@ def _validate_jmespath(v: str | None) -> str | None: return v -class _WebSocketSourceBase(BaseModel): - """Shared fields for all WebSocket source schemas.""" +class WebSocketSourceCreate(DiscriminatedUnionMixin, ABC): + """Abstract base for all outbound WebSocket source create schemas. + + Concrete subclasses (``GenericWebSocketSource``, ``SlackWebSocketSource``) + are selected automatically via ``DiscriminatedUnionMixin`` using the + ``kind`` computed field (equal to the class name). + """ model_config = ConfigDict(extra="forbid") @@ -628,16 +635,30 @@ def validate_source(cls, v: str) -> str: ) return v_lower - @field_validator("filter_expr") + event_key_expr: str = Field( + default="type", + max_length=500, + description="JMESPath expression to extract the event type from each message.", + ) + payload_expr: str | None = Field( + default=None, + max_length=500, + description=( + "Optional JMESPath expression to unwrap an outer envelope. " + "The result replaces the raw message as the event payload passed to " + "automations. None means use the raw message as-is." 
+ ), + ) + + @field_validator("filter_expr", "event_key_expr", "payload_expr") @classmethod - def validate_filter_expr(cls, v: str | None) -> str | None: + def validate_jmespath_exprs(cls, v: str | None) -> str | None: return _validate_jmespath(v) -class GenericWebSocketSourceCreate(_WebSocketSourceBase): +class GenericWebSocketSource(WebSocketSourceCreate): """Create a generic outbound WebSocket source with a static URL.""" - kind: Literal["generic"] = "generic" url: str = Field(..., description="The wss:// URL to connect to.") headers: dict[str, str] | None = Field( default=None, @@ -647,20 +668,6 @@ class GenericWebSocketSourceCreate(_WebSocketSourceBase): "Treat values as sensitive credentials." ), ) - event_key_expr: str = Field( - default="type", - max_length=500, - description="JMESPath expression to extract the event type from each message.", - ) - payload_expr: str | None = Field( - default=None, - max_length=500, - description=( - "Optional JMESPath expression to unwrap an outer envelope. " - "The result replaces the raw message as the event payload passed to " - "automations. None means use the raw message as-is." - ), - ) @field_validator("url") @classmethod @@ -669,13 +676,8 @@ def validate_url(cls, v: str) -> str: raise ValueError("url must start with wss:// or ws://") return v - @field_validator("event_key_expr", "payload_expr") - @classmethod - def validate_jmespath(cls, v: str | None) -> str | None: - return _validate_jmespath(v) - -class SlackWebSocketSourceCreate(_WebSocketSourceBase): +class SlackWebSocketSource(WebSocketSourceCreate): """Create a Slack Socket Mode outbound WebSocket source. Uses ``apps.connections.open`` to obtain a fresh wss:// URL on each connect @@ -687,7 +689,6 @@ class SlackWebSocketSourceCreate(_WebSocketSourceBase): as the payload available to ``trigger.filter``. 
""" - kind: Literal["slack"] = "slack" app_token: str = Field( ..., description=( @@ -703,7 +704,7 @@ class SlackWebSocketSourceCreate(_WebSocketSourceBase): "Default extracts the inner Slack event type (e.g. 'message')." ), ) - payload_expr: str = Field( + payload_expr: str | None = Field( default="payload.event", max_length=500, description=( @@ -721,18 +722,6 @@ def validate_app_token(cls, v: str) -> str: ) return v - @field_validator("event_key_expr", "payload_expr") - @classmethod - def validate_jmespath(cls, v: str | None) -> str | None: - return _validate_jmespath(v) - - -# Discriminated union for the create request body -WebSocketSourceCreate = Annotated[ - GenericWebSocketSourceCreate | SlackWebSocketSourceCreate, - Field(discriminator="kind"), -] - class WebSocketSourceUpdate(BaseModel): """Request schema for updating a WebSocket source. diff --git a/openhands/automation/socket_manager.py b/openhands/automation/socket_manager.py index fd6bf9e..2b2ff8b 100644 --- a/openhands/automation/socket_manager.py +++ b/openhands/automation/socket_manager.py @@ -266,7 +266,7 @@ async def _connect_and_receive( if source is None: return - if kind == "slack": + if kind == "SlackWebSocketSource": await self._receive_slack(source, org_id) else: await self._receive_generic(source, org_id) diff --git a/openhands/automation/utils/cipher.py b/openhands/automation/utils/cipher.py deleted file mode 100644 index e98ef79..0000000 --- a/openhands/automation/utils/cipher.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Application-layer encryption for sensitive fields. - -This module provides the same Fernet-based cipher used by the OpenHands SDK -(``openhands.sdk.utils.cipher``) — adapted for use inside the automation -service where a shared key encrypts credentials before they are written to the -database. 
- -Key derivation --------------- -The raw ``AUTOMATION_SECRET_KEY`` (or ``OH_SECRET_KEY``) value is hashed with -SHA-256 and base64-encoded to produce a valid 256-bit Fernet key, so the env -var itself does not need to be exactly 32 bytes. - -Ciphertext detection --------------------- -Fernet tokens always start with the prefix ``"gAAAAA"`` (the base-64 encoding -of the ``\\x80`` version byte). ``Cipher.is_ciphertext`` exploits this to -distinguish already-encrypted values from plaintext — important during a -rolling migration where old rows may still be plaintext. - -Safe defaults -------------- -If ``AUTOMATION_SECRET_KEY`` / ``OH_SECRET_KEY`` is absent, ``get_cipher()`` -returns ``None`` and the callers fall back to storing/reading plaintext. A -warning is logged once at import time so operators are not surprised. This -preserves backward-compat for deployments that haven't yet set the key. -""" - -import hashlib -import logging -import os -from base64 import b64encode - -from cryptography.fernet import Fernet, InvalidToken - - -logger = logging.getLogger("automation.utils.cipher") - -# Fernet tokens always start with this 6-char prefix (base64 of 0x80 version byte). -# Using a 6-char prefix avoids collisions with realistic base64 plaintext values. -FERNET_TOKEN_PREFIX = "gAAAAA" - -# Env-var names checked in order of preference -_KEY_ENV_VARS = ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY") - - -class Cipher: - """Symmetric Fernet cipher with SHA-256 key derivation. - - Matches the interface of ``openhands.sdk.utils.cipher.Cipher`` so that the - same encrypted blobs can be shared between the SDK and the automation - service if they share the same key. 
- """ - - def __init__(self, secret_key: str) -> None: - self._secret_key = secret_key - self._fernet: Fernet | None = None - - def encrypt(self, plaintext: str) -> str: - """Encrypt a string and return the Fernet token as a str.""" - return self._get_fernet().encrypt(plaintext.encode()).decode() - - def decrypt(self, ciphertext: str) -> str | None: - """Decrypt a Fernet token, returning None if decryption fails. - - A ``None`` return (instead of an exception) matches the SDK convention - and allows callers to fall back gracefully — e.g. when rows were - encrypted with a different key or were never encrypted. - """ - try: - return self._get_fernet().decrypt(ciphertext.encode()).decode() - except (InvalidToken, Exception) as exc: - logger.warning( - "Failed to decrypt value (returning None): %s. " - "This may occur when loading data encrypted with a different key " - "or when migrating from unencrypted storage.", - exc, - ) - return None - - def is_ciphertext(self, value: str) -> bool: - """Return True if *value* is a Fernet token (already encrypted).""" - return value.startswith(FERNET_TOKEN_PREFIX) - - def _get_fernet(self) -> Fernet: - if self._fernet is None: - key = b64encode(hashlib.sha256(self._secret_key.encode()).digest()) - self._fernet = Fernet(key) - return self._fernet - - -def get_cipher() -> Cipher | None: - """Return a ``Cipher`` built from the first available key env var, or None. - - Checks ``AUTOMATION_SECRET_KEY`` then ``OH_SECRET_KEY``. If neither is - set, returns ``None`` — callers should store/read values as plaintext and - log an appropriate warning. 
- """ - for env_var in _KEY_ENV_VARS: - key = os.getenv(env_var) - if key: - return Cipher(key) - return None diff --git a/openhands/automation/utils/encrypted_fields.py b/openhands/automation/utils/encrypted_fields.py index 5a40811..fee4a11 100644 --- a/openhands/automation/utils/encrypted_fields.py +++ b/openhands/automation/utils/encrypted_fields.py @@ -1,7 +1,7 @@ """SQLAlchemy TypeDecorators for application-layer encryption. Provides two column types that transparently encrypt/decrypt values using the -``Cipher`` from ``openhands.automation.utils.cipher``: +``Cipher`` from ``openhands.sdk.utils.cipher``: ``EncryptedString`` For single string columns (e.g. ``app_token``). Stores Fernet ciphertext @@ -15,21 +15,23 @@ Migration safety ---------------- -Both decoders check ``Cipher.is_ciphertext`` before attempting decryption, so -existing plaintext rows continue to work after the key is introduced — no data -migration is required. Once a row is updated its value will be re-encrypted. +Both decoders check the Fernet token prefix (``FERNET_TOKEN_PREFIX``) before +attempting decryption, so existing plaintext rows continue to work after the +key is introduced — no data migration is required. If ``AUTOMATION_SECRET_KEY`` / ``OH_SECRET_KEY`` is absent a one-time WARNING is emitted and values are stored as plaintext (preserving current behaviour). """ import logging +import os from typing import Any +from pydantic import SecretStr from sqlalchemy import String from sqlalchemy.types import JSON, TypeDecorator -from openhands.automation.utils.cipher import Cipher, get_cipher +from openhands.sdk.utils.cipher import FERNET_TOKEN_PREFIX, Cipher logger = logging.getLogger("automation.utils.encrypted_fields") @@ -64,6 +66,20 @@ def _warn_no_cipher(field: str) -> None: _warned_no_cipher = True +def get_cipher() -> Cipher | None: + """Return an SDK ``Cipher`` built from the first available key env var, or None. + + Checks ``AUTOMATION_SECRET_KEY`` then ``OH_SECRET_KEY``. 
If neither is + set, returns ``None`` — callers store/read values as plaintext and a + one-time WARNING is emitted. + """ + for env_var in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY"): + key = os.getenv(env_var) + if key: + return Cipher(key) + return None + + def _is_secret_header(key: str) -> bool: """Return True if the header key name indicates a sensitive value.""" upper = key.upper() @@ -74,7 +90,7 @@ class EncryptedString(TypeDecorator): """A ``String`` column that is transparently encrypted/decrypted. Matches the per-field encryption pattern used by ``StaticSecret.value`` - in the OpenHands SDK: the Fernet cipher is applied on the way in and out + in the OpenHands SDK: the SDK ``Cipher`` is applied on the way in and out of the database; the ORM always works with plaintext strings. If no cipher key is configured the column behaves as a plain ``String``. @@ -87,21 +103,21 @@ def process_bind_param(self, value: str | None, dialect: Any) -> str | None: # """Encrypt on the way TO the database.""" if value is None: return None - cipher: Cipher | None = get_cipher() + cipher = get_cipher() if cipher is None: _warn_no_cipher(self.__class__.__name__) return value - return cipher.encrypt(value) + return cipher.encrypt(SecretStr(value)) def process_result_value(self, value: str | None, dialect: Any) -> str | None: # noqa: ARG002 """Decrypt on the way FROM the database.""" if value is None: return None - cipher: Cipher | None = get_cipher() + cipher = get_cipher() if cipher is None: return value # stored as plaintext (no key at write time) - if cipher.is_ciphertext(value): - return cipher.decrypt(value) + if value.startswith(FERNET_TOKEN_PREFIX): + return cipher.try_decrypt_str(value) return value # plaintext row written before key was introduced @@ -122,14 +138,14 @@ def process_bind_param(self, value: dict | None, dialect: Any) -> dict | None: """Encrypt sensitive header values on the way TO the database.""" if not value: return value - cipher: Cipher | None = 
get_cipher() + cipher = get_cipher() if cipher is None: _warn_no_cipher("headers") return value result: dict = {} for k, v in value.items(): if _is_secret_header(k) and isinstance(v, str) and v: - result[k] = cipher.encrypt(v) + result[k] = cipher.encrypt(SecretStr(v)) else: result[k] = v return result @@ -138,14 +154,17 @@ def process_result_value(self, value: dict | None, dialect: Any) -> dict | None: """Decrypt sensitive header values on the way FROM the database.""" if not value: return value - cipher: Cipher | None = get_cipher() + cipher = get_cipher() if cipher is None: return value result: dict = {} for k, v in value.items(): - if _is_secret_header(k) and isinstance(v, str) and cipher.is_ciphertext(v): - decrypted = cipher.decrypt(v) - result[k] = decrypted if decrypted is not None else v + if ( + _is_secret_header(k) + and isinstance(v, str) + and v.startswith(FERNET_TOKEN_PREFIX) + ): + result[k] = cipher.try_decrypt_str(v) or v else: result[k] = v return result diff --git a/openhands/automation/websocket_source_router.py b/openhands/automation/websocket_source_router.py index 91f45e7..d913154 100644 --- a/openhands/automation/websocket_source_router.py +++ b/openhands/automation/websocket_source_router.py @@ -183,12 +183,12 @@ async def update_websocket_source( # This guards against a caller clearing url on a generic source or # app_token on a slack source, which would cause runtime failures in the # SocketManager on the next connect attempt. 
- if source.kind == "generic" and not source.url: + if source.kind == "GenericWebSocketSource" and not source.url: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="url is required for generic sources and cannot be cleared", ) - if source.kind == "slack" and not source.app_token: + if source.kind == "SlackWebSocketSource" and not source.app_token: raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="app_token is required for slack sources and cannot be cleared", diff --git a/pyproject.toml b/pyproject.toml index 2d71327..9b03927 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ dependencies = [ "aiosqlite>=0.20", "alembic>=1.14", "asyncpg>=0.30", - "cryptography>=42", "boto3>=1.35", "cachetools>=7.0.5", "cloud-sql-python-connector[asyncpg]>=1.16", diff --git a/tests/test_cipher_and_encryption.py b/tests/test_cipher_and_encryption.py index b9d72cd..6449e03 100644 --- a/tests/test_cipher_and_encryption.py +++ b/tests/test_cipher_and_encryption.py @@ -1,6 +1,7 @@ """Tests for application-layer encryption. -Covers Cipher, EncryptedString, and EncryptedJSONHeaders. +Covers EncryptedString and EncryptedJSONHeaders TypeDecorators and the +get_cipher() helper. Uses the SDK Cipher (openhands.sdk.utils.cipher). Pure unit tests — no database or Docker required. 
""" @@ -8,83 +9,36 @@ from unittest.mock import patch import pytest +from pydantic import SecretStr -from openhands.automation.utils.cipher import ( - FERNET_TOKEN_PREFIX, - Cipher, - get_cipher, -) from openhands.automation.utils.encrypted_fields import ( EncryptedJSONHeaders, EncryptedString, _is_secret_header, + get_cipher, ) +from openhands.sdk.utils.cipher import FERNET_TOKEN_PREFIX, Cipher TEST_KEY = "test-secret-key-for-automation-service" -# --------------------------------------------------------------------------- -# Cipher -# --------------------------------------------------------------------------- - - -class TestCipher: - def test_encrypt_decrypt_roundtrip(self): - cipher = Cipher(TEST_KEY) - plaintext = "xapp-1-AAAAAAAAA-1111111111-aaaaaaaaaaaaaaaa" - ciphertext = cipher.encrypt(plaintext) - assert cipher.decrypt(ciphertext) == plaintext - - def test_ciphertext_is_string(self): - cipher = Cipher(TEST_KEY) - ct = cipher.encrypt("hello") - assert isinstance(ct, str) - - def test_ciphertext_has_fernet_prefix(self): - cipher = Cipher(TEST_KEY) - ct = cipher.encrypt("hello") - assert cipher.is_ciphertext(ct) - assert ct.startswith(FERNET_TOKEN_PREFIX) - - def test_plaintext_is_not_ciphertext(self): - cipher = Cipher(TEST_KEY) - assert not cipher.is_ciphertext("xapp-1-AAAAAAAAA") - assert not cipher.is_ciphertext("Bearer token123") - assert not cipher.is_ciphertext("") - - def test_decrypt_invalid_returns_none(self): - cipher = Cipher(TEST_KEY) - result = cipher.decrypt("not-a-valid-fernet-token") - assert result is None +def _encrypt(cipher: Cipher, plaintext: str) -> str: + """Thin wrapper: SDK Cipher.encrypt takes SecretStr; returns str.""" + result = cipher.encrypt(SecretStr(plaintext)) + assert result is not None + return result - def test_decrypt_wrong_key_returns_none(self): - cipher1 = Cipher("key-one") - cipher2 = Cipher("key-two") - ct = cipher1.encrypt("secret") - assert cipher2.decrypt(ct) is None - def 
test_different_plaintexts_produce_different_ciphertexts(self): - cipher = Cipher(TEST_KEY) - ct1 = cipher.encrypt("secret-a") - ct2 = cipher.encrypt("secret-b") - assert ct1 != ct2 - - def test_same_plaintext_produces_different_ciphertexts(self): - # Fernet uses random nonces — two encryptions of the same plaintext differ - cipher = Cipher(TEST_KEY) - ct1 = cipher.encrypt("same") - ct2 = cipher.encrypt("same") - assert ct1 != ct2 - assert cipher.decrypt(ct1) == cipher.decrypt(ct2) == "same" +# --------------------------------------------------------------------------- +# get_cipher helper +# --------------------------------------------------------------------------- class TestGetCipher: def test_returns_cipher_when_automation_key_set(self): with patch.dict(os.environ, {"AUTOMATION_SECRET_KEY": TEST_KEY}, clear=False): - cipher = get_cipher() - assert cipher is not None - assert isinstance(cipher, Cipher) + assert isinstance(get_cipher(), Cipher) def test_falls_back_to_oh_secret_key(self): env = { @@ -94,8 +48,7 @@ def test_falls_back_to_oh_secret_key(self): } env["OH_SECRET_KEY"] = TEST_KEY with patch.dict(os.environ, env, clear=True): - cipher = get_cipher() - assert cipher is not None + assert get_cipher() is not None def test_returns_none_when_no_key_set(self): env = { @@ -104,19 +57,17 @@ def test_returns_none_when_no_key_set(self): if k not in ("AUTOMATION_SECRET_KEY", "OH_SECRET_KEY") } with patch.dict(os.environ, env, clear=True): - cipher = get_cipher() - assert cipher is None + assert get_cipher() is None def test_automation_key_takes_precedence_over_oh_key(self): - env = {k: v for k, v in os.environ.items()} + env = dict(os.environ) env["AUTOMATION_SECRET_KEY"] = "automation-key" env["OH_SECRET_KEY"] = "oh-key" with patch.dict(os.environ, env, clear=True): cipher = get_cipher() assert cipher is not None - # automation key should be used - ct = cipher.encrypt("value") - assert Cipher("automation-key").decrypt(ct) == "value" + ct = _encrypt(cipher, "value") 
+ assert Cipher("automation-key").try_decrypt_str(ct) == "value" # --------------------------------------------------------------------------- @@ -175,12 +126,12 @@ def test_encrypt_on_bind_param(self): mock.return_value = cipher result = col.process_bind_param("xapp-token", None) assert result is not None - assert cipher.is_ciphertext(result) + assert result.startswith(FERNET_TOKEN_PREFIX) def test_decrypt_on_result_value(self): col = self._make_col() cipher = Cipher(TEST_KEY) - ciphertext = cipher.encrypt("xapp-token") + ciphertext = _encrypt(cipher, "xapp-token") with patch("openhands.automation.utils.encrypted_fields.get_cipher") as mock: mock.return_value = cipher result = col.process_result_value(ciphertext, None) @@ -249,7 +200,7 @@ def test_encrypts_auth_header_only(self): stored = col.process_bind_param(headers, None) assert stored is not None - assert cipher.is_ciphertext(stored["Authorization"]) + assert stored["Authorization"].startswith(FERNET_TOKEN_PREFIX) assert stored["Content-Type"] == "application/json" assert stored["X-Request-ID"] == "req-123" @@ -314,9 +265,9 @@ def test_multiple_secret_headers_encrypted(self): stored = col.process_bind_param(headers, None) assert stored is not None - assert cipher.is_ciphertext(stored["Authorization"]) - assert cipher.is_ciphertext(stored["X-Api-Key"]) - assert cipher.is_ciphertext(stored["Cookie"]) + assert stored["Authorization"].startswith(FERNET_TOKEN_PREFIX) + assert stored["X-Api-Key"].startswith(FERNET_TOKEN_PREFIX) + assert stored["Cookie"].startswith(FERNET_TOKEN_PREFIX) assert stored["Accept"] == "application/json" diff --git a/tests/test_websocket_source_router.py b/tests/test_websocket_source_router.py index 2579963..f1030b5 100644 --- a/tests/test_websocket_source_router.py +++ b/tests/test_websocket_source_router.py @@ -8,8 +8,8 @@ from openhands.automation.models import OutboundWebSocketSource, WebSocketStatus from openhands.automation.schemas import ( - GenericWebSocketSourceCreate, - 
SlackWebSocketSourceCreate, + GenericWebSocketSource, + SlackWebSocketSource, WebSocketSourceUpdate, ) @@ -25,39 +25,37 @@ # --------------------------------------------------------------------------- -class TestGenericWebSocketSourceCreate: +class TestGenericWebSocketSource: def test_valid_minimal(self): - s = GenericWebSocketSourceCreate( + s = GenericWebSocketSource( name="My WS", source="my-ws", url="wss://example.com/events", ) - assert s.kind == "generic" + assert s.kind == "GenericWebSocketSource" assert s.event_key_expr == "type" assert s.payload_expr is None assert s.headers is None def test_url_must_be_wss(self): with pytest.raises(ValidationError, match="wss://"): - GenericWebSocketSourceCreate( + GenericWebSocketSource( name="Bad URL", source="bad", url="http://example.com" ) def test_ws_url_also_accepted(self): - s = GenericWebSocketSourceCreate( - name="WS", source="ws", url="ws://localhost:9000" - ) + s = GenericWebSocketSource(name="WS", source="ws", url="ws://localhost:9000") assert s.url == "ws://localhost:9000" def test_reserved_source_rejected(self): with pytest.raises(ValidationError, match="reserved"): - GenericWebSocketSourceCreate( + GenericWebSocketSource( name="GitHub WS", source="github", url="wss://example.com" ) def test_invalid_jmespath_event_key_expr(self): with pytest.raises(ValidationError, match="JMESPath"): - GenericWebSocketSourceCreate( + GenericWebSocketSource( name="Bad", source="bad", url="wss://example.com", @@ -66,7 +64,7 @@ def test_invalid_jmespath_event_key_expr(self): def test_invalid_jmespath_filter_expr(self): with pytest.raises(ValidationError, match="JMESPath"): - GenericWebSocketSourceCreate( + GenericWebSocketSource( name="Bad", source="bad", url="wss://example.com", @@ -74,13 +72,13 @@ def test_invalid_jmespath_filter_expr(self): ) def test_source_normalised_to_lowercase(self): - s = GenericWebSocketSourceCreate( + s = GenericWebSocketSource( name="WS", source="MySource", url="wss://example.com" ) assert s.source 
== "mysource" def test_with_headers_and_payload_expr(self): - s = GenericWebSocketSourceCreate( + s = GenericWebSocketSource( name="WS", source="my-ws", url="wss://example.com", @@ -91,25 +89,25 @@ def test_with_headers_and_payload_expr(self): assert s.payload_expr == "event" -class TestSlackWebSocketSourceCreate: +class TestSlackWebSocketSource: def test_valid(self): - s = SlackWebSocketSourceCreate( + s = SlackWebSocketSource( name="Slack", source="slack", app_token="xapp-1-abc123", ) - assert s.kind == "slack" + assert s.kind == "SlackWebSocketSource" assert s.event_key_expr == "payload.event.type" assert s.payload_expr == "payload.event" def test_app_token_must_start_with_xapp(self): with pytest.raises(ValidationError, match="xapp-"): - SlackWebSocketSourceCreate( + SlackWebSocketSource( name="Slack", source="slack", app_token="xoxb-wrong-token" ) def test_custom_event_key_expr(self): - s = SlackWebSocketSourceCreate( + s = SlackWebSocketSource( name="Slack", source="slack", app_token="xapp-1-abc123", @@ -153,7 +151,7 @@ async def test_create_generic_success(self, async_client, _no_socket_manager): resp = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "My Generic WS", "source": "my-ws", "url": "wss://example.com/events", @@ -161,7 +159,7 @@ async def test_create_generic_success(self, async_client, _no_socket_manager): ) assert resp.status_code == 201 data = resp.json() - assert data["kind"] == "generic" + assert data["kind"] == "GenericWebSocketSource" assert data["source"] == "my-ws" assert data["url"] == "wss://example.com/events" assert data["status"] == "DISCONNECTED" @@ -173,7 +171,7 @@ async def test_create_slack_success(self, async_client, _no_socket_manager): resp = await async_client.post( BASE_URL, json={ - "kind": "slack", + "kind": "SlackWebSocketSource", "name": "Slack Events", "source": "slack-prod", "app_token": "xapp-1-AAAAAAAAA-1111111111-aaaaaaaaaaaaaaaa", @@ -181,7 +179,7 @@ async 
def test_create_slack_success(self, async_client, _no_socket_manager): ) assert resp.status_code == 201 data = resp.json() - assert data["kind"] == "slack" + assert data["kind"] == "SlackWebSocketSource" assert data["event_key_expr"] == "payload.event.type" assert data["payload_expr"] == "payload.event" # app_token is never returned @@ -191,7 +189,7 @@ async def test_create_duplicate_source_returns_409( self, async_client, _no_socket_manager ): payload = { - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "WS Source", "source": "dup-source", "url": "wss://example.com", @@ -214,7 +212,7 @@ async def test_create_notifies_socket_manager(self, async_client, async_session) resp = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Notify WS", "source": "notify-ws", "url": "wss://example.com", @@ -237,7 +235,7 @@ async def test_create_disabled_source_does_not_notify( resp = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Disabled WS", "source": "disabled-ws", "url": "wss://example.com", @@ -260,7 +258,7 @@ async def test_slack_bad_token_returns_422(self, async_client, _no_socket_manage resp = await async_client.post( BASE_URL, json={ - "kind": "slack", + "kind": "SlackWebSocketSource", "name": "Slack", "source": "bad-slack", "app_token": "xoxb-not-an-app-token", @@ -285,7 +283,7 @@ async def test_list_returns_own_org_sources( await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": f"WS {i}", "source": f"ws-{i}", "url": "wss://example.com", @@ -297,7 +295,7 @@ async def test_list_returns_own_org_sources( org_id=OTHER_ORG_ID, name="Other WS", source="other-ws", - kind="generic", + kind="GenericWebSocketSource", url="wss://other.com", event_key_expr="type", status=WebSocketStatus.DISCONNECTED, @@ -317,7 +315,7 @@ async def test_get_existing(self, async_client, async_session, _no_socket_manage 
create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Get Me", "source": "get-me", "url": "wss://example.com", @@ -334,7 +332,7 @@ async def test_get_other_org_returns_404(self, async_client, async_session): org_id=OTHER_ORG_ID, name="Other", source="other", - kind="generic", + kind="GenericWebSocketSource", url="wss://other.com", event_key_expr="type", status=WebSocketStatus.DISCONNECTED, @@ -355,7 +353,7 @@ async def test_update_name_and_enabled(self, async_client, _no_socket_manager): create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Old Name", "source": "upd-ws", "url": "wss://example.com", @@ -384,7 +382,7 @@ async def test_update_triggers_reconnect_on_enabled_change( create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Reconnect Test", "source": "recon-ws", "url": "wss://example.com", @@ -407,7 +405,7 @@ async def test_delete_success(self, async_client, _no_socket_manager): create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Delete Me", "source": "del-ws", "url": "wss://example.com", @@ -433,7 +431,7 @@ async def test_delete_notifies_socket_manager( create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Del Notify", "source": "del-notify", "url": "wss://example.com", @@ -459,7 +457,7 @@ async def test_reconnect_enabled_source(self, async_client, _no_socket_manager): create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Reconnect WS", "source": "recon-ws2", "url": "wss://example.com", @@ -480,7 +478,7 @@ async def test_reconnect_disabled_source_returns_400( create = await async_client.post( BASE_URL, json={ - "kind": "generic", + "kind": "GenericWebSocketSource", "name": "Disabled", 
"source": "disabled-ws2", "url": "wss://example.com", @@ -513,7 +511,7 @@ async def test_dispatch_pre_filter_drops_non_matching(self, async_session_factor org_id=TEST_ORG_ID, name="Test", source="test-ws", - kind="generic", + kind="GenericWebSocketSource", event_key_expr="type", filter_expr="type == 'allowed'", status=WebSocketStatus.CONNECTED, @@ -540,7 +538,7 @@ async def test_dispatch_unwraps_payload_via_payload_expr( org_id=TEST_ORG_ID, name="Test", source="test-ws", - kind="slack", + kind="SlackWebSocketSource", event_key_expr="payload.event.type", payload_expr="payload.event", status=WebSocketStatus.CONNECTED, @@ -578,7 +576,7 @@ async def test_dispatch_drops_non_string_event_key(self, async_session_factory): org_id=TEST_ORG_ID, name="Test", source="test-ws", - kind="generic", + kind="GenericWebSocketSource", event_key_expr="metadata", # returns a dict, not a string status=WebSocketStatus.CONNECTED, ) diff --git a/uv.lock b/uv.lock index c4fddaa..c4f8b8e 100644 --- a/uv.lock +++ b/uv.lock @@ -2168,7 +2168,6 @@ dependencies = [ { name = "cachetools" }, { name = "cloud-sql-python-connector", extra = ["asyncpg"] }, { name = "croniter" }, - { name = "cryptography" }, { name = "fastapi" }, { name = "google-cloud-storage" }, { name = "httpx" }, @@ -2208,7 +2207,6 @@ requires-dist = [ { name = "cachetools", specifier = ">=7.0.5" }, { name = "cloud-sql-python-connector", extras = ["asyncpg"], specifier = ">=1.16" }, { name = "croniter", specifier = ">=2" }, - { name = "cryptography", specifier = ">=42" }, { name = "fastapi", specifier = ">=0.115" }, { name = "google-cloud-storage", specifier = ">=2.18" }, { name = "httpx", specifier = ">=0.27" },