diff --git a/migrations/versions/005_session_tables.py b/migrations/versions/005_session_tables.py new file mode 100644 index 0000000..a6c85f5 --- /dev/null +++ b/migrations/versions/005_session_tables.py @@ -0,0 +1,113 @@ +"""Add session-based event routing tables. + +This migration adds two tables to support session-based conversation reuse for +event-triggered automations: + +1. ``automation_sessions`` — tracks active sessions (sandbox + session key). + An ACTIVE session routes incoming events to its running sandbox rather than + creating a new run for each event. + +2. ``pending_session_events`` — queues events when a session's sandbox is alive + or when a sandbox has died and ``on_sandbox_death`` is set to "queue"/"restart". + +Cross-database compatible: works with both PostgreSQL and SQLite. + +Revision ID: 005 +Revises: 004 +Create Date: 2026-05-15 +""" + +from collections.abc import Sequence + +from alembic import op +from sqlalchemy import JSON, Column, DateTime, String, Uuid, text + + +revision: str = "005" +down_revision: str = "004" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # 1. Create automation_sessions table + op.create_table( + "automation_sessions", + Column("id", Uuid, primary_key=True), + Column("automation_id", Uuid, nullable=False), + Column("session_key", String(255), nullable=False), + Column("run_id", Uuid, nullable=False), + # Sandbox identifier (populated by the dispatcher after sandbox creation) + Column("sandbox_id", String(255), nullable=True), + # Status: ACTIVE, EXPIRED, or DEAD + # Using String instead of Enum for cross-database compatibility + Column("status", String(20), nullable=False, server_default="ACTIVE"), + Column( + "started_at", + DateTime(timezone=True), + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ), + # Pre-computed expiry deadline: started_at + session_timeout_seconds + Column("expires_at", DateTime(timezone=True), nullable=False), + Column( + "last_event_at", + DateTime(timezone=True), + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ), + ) + op.create_index( + "ix_automation_sessions_automation_id", + "automation_sessions", + ["automation_id"], + ) + op.create_index( + "ix_automation_sessions_status", + "automation_sessions", + ["status"], + ) + # Compound index for the primary lookup pattern: + # SELECT ... WHERE automation_id = ? AND session_key = ? AND status = 'ACTIVE' + op.create_index( + "ix_session_lookup", + "automation_sessions", + ["automation_id", "session_key", "status"], + ) + + # 2. Create pending_session_events table + op.create_table( + "pending_session_events", + Column("id", Uuid, primary_key=True), + Column("automation_id", Uuid, nullable=False), + Column("session_key", String(255), nullable=False), + # Event payload (same format as automation_runs.event_payload) + Column("event_payload", JSON, nullable=False), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=text("CURRENT_TIMESTAMP"), + ), + ) + op.create_index( + "ix_pending_session_events_automation_id", + "pending_session_events", + ["automation_id"], + ) + op.create_index( + "ix_pending_session_events_session_key", + "pending_session_events", + ["session_key"], + ) + # Compound index for fetching queued events for a specific session + op.create_index( + "ix_pending_session_events_lookup", + "pending_session_events", + ["automation_id", "session_key"], + ) + + +def downgrade() -> None: + op.drop_table("pending_session_events") + op.drop_table("automation_sessions") diff --git a/openhands/automation/dispatcher.py b/openhands/automation/dispatcher.py index aae6a53..5cb5624 100644 --- a/openhands/automation/dispatcher.py +++ b/openhands/automation/dispatcher.py @@ -44,6 +44,7 @@ mark_run_terminal, update_sandbox_id, ) +from openhands.automation.utils.session import create_session, extract_session_key from openhands.automation.utils.tarball_validation import ( is_http_url, parse_internal_upload_id, @@ -114,6 +115,72 @@ async def _poll_pending_runs( return list(result.scalars().all()) +async def _maybe_create_session( + run: AutomationRun, + sandbox_id: str | None, + session_factory: async_sessionmaker[AsyncSession], +) -> None: + """Create an AutomationSession if the run's trigger has session config. + + Called after a run is successfully dispatched. If the trigger has a + ``session`` configuration, extract the session key from the event payload + and record the session so subsequent events are routed here. + + Failures are logged but do not affect the run (best-effort). + """ + from openhands.automation.schemas import EventTrigger + + automation = run.automation + if not automation: + return + + try: + trigger_data = automation.trigger + if trigger_data.get("type") != "event": + return + + trigger = EventTrigger.model_validate(trigger_data) + if not trigger.session: + return + + session_cfg = trigger.session + event_payload = run.event_payload or {} + session_key = extract_session_key(session_cfg.key_expr, event_payload) + + if session_key is None: + logger.warning( + "Could not extract session key for run %s " + "using expr=%r; session not created", + run.id, + session_cfg.key_expr, + ) + return + + async with session_factory() as db_session: + session_record = await create_session( + automation_id=automation.id, + session_key=session_key, + run_id=run.id, + session_timeout_seconds=session_cfg.session_timeout_seconds, + db_session=db_session, + ) + if sandbox_id: + session_record.sandbox_id = sandbox_id + await db_session.commit() + + logger.info( + "Created session key=%s for run %s automation %s", + session_key, + run.id, + automation.id, + ) + except Exception: + logger.exception( + "Failed to create session for run %s; continuing without session", + run.id, + ) + + async def _execute_run( run: AutomationRun, settings: ServiceSettings, @@ -268,6 +335,10 @@ async def _fail(error: str, disable: bool = False) -> None: if result.success: if ctx.sandbox_id: await update_sandbox_id(session_factory, run.id, ctx.sandbox_id) + # If this run was triggered by an event with session config, create a + # session record so subsequent events with the same session key are + # routed to this sandbox instead of starting a new run. + await _maybe_create_session(run, ctx.sandbox_id, session_factory) logger.info( "Automation dispatched successfully, waiting for callback", extra=_log_ctx(sandbox_id=ctx.sandbox_id), diff --git a/openhands/automation/event_router.py b/openhands/automation/event_router.py index d490c8a..ef82de7 100644 --- a/openhands/automation/event_router.py +++ b/openhands/automation/event_router.py @@ -41,6 +41,11 @@ from openhands.automation.event_schemas import WebhookEvent, parse_event from openhands.automation.schemas import EventResponse from openhands.automation.trigger_matcher import matches_trigger +from openhands.automation.utils.session import ( + extract_session_key, + get_active_session, + queue_pending_event, +) from openhands.automation.utils.webhook import ( create_automation_run, get_event_automations, @@ -151,25 +156,25 @@ async def receive_event( org_id, ) - # 6. Find matching automations + # 6. Find matching automations (preserve trigger alongside each automation) automations = await get_event_automations(org_id, source, session) - matched_automations = [] + matched: list[tuple] = [] # (Automation, EventTrigger) for automation, trigger in automations: # Match trigger against webhook payload using JMESPath filter if matches_trigger(trigger, source, event.event_key, webhook_payload): - matched_automations.append(automation) + matched.append((automation, trigger)) logger.info( "Event matched %d/%d automations for org=%s", - len(matched_automations), + len(matched), len(automations), org_id, ) - # 7. Create PENDING runs for matched automations - # For Pydantic-parsed events (GitHub), use model_dump() for typed fields - # For custom webhooks, use the webhook payload directly + # 7. Create PENDING runs or queue events for matched automations. + # For Pydantic-parsed events (GitHub), use model_dump() for typed fields. + # For custom webhooks, use the webhook payload directly. event_payload = ( event.model_dump(mode="json") if isinstance(event, BaseModel) @@ -177,7 +182,49 @@ async def receive_event( ) run_ids: list[str] = [] - for automation in matched_automations: + events_queued: int = 0 + + for automation, trigger in matched: + session_cfg = trigger.session + + if session_cfg: + # Session mode: route to existing session or start a new one + session_key = extract_session_key(session_cfg.key_expr, event_payload) + + if session_key is None: + logger.warning( + "Could not extract session key for automation %s " + "using expr=%r; falling back to new run", + automation.id, + session_cfg.key_expr, + ) + else: + active_session = await get_active_session( + automation.id, session_key, session + ) + + if active_session is not None: + # Route to existing session — queue event for the running sandbox + await queue_pending_event( + automation.id, session_key, event_payload, session + ) + events_queued += 1 + logger.info( + "Event queued to existing session key=%s " + "automation=%s session_id=%s", + session_key, + automation.id, + active_session.id, + ) + continue # Skip new run creation + + # No active session — fall through to create a new run below + logger.info( + "No active session for key=%s automation=%s; creating new run", + session_key, + automation.id, + ) + run = await create_automation_run( automation, session, event_payload=event_payload ) @@ -187,6 +234,7 @@ async def receive_event( return EventResponse( received=True, - matched=len(matched_automations), + matched=len(matched), runs_created=run_ids, + events_queued=events_queued, ) diff --git a/openhands/automation/models.py b/openhands/automation/models.py index fe45932..6b542dc 100644 --- a/openhands/automation/models.py +++ b/openhands/automation/models.py @@ -42,6 +42,14 @@ class AutomationRunStatus(enum.Enum): FAILED = "FAILED" +class SessionStatus(enum.Enum): + """Status of an automation session.""" + + ACTIVE = "ACTIVE" # Sandbox alive, accepting and processing events + EXPIRED = "EXPIRED" # Timed out (idle timeout or max session lifetime reached) + DEAD = "DEAD" # Sandbox died unexpectedly (watchdog detected) + + class Automation(Base): """An automation definition: what to run and when to trigger it.""" @@ -310,3 +318,100 @@ class CustomWebhook(Base): __table_args__ = ( Index("ix_custom_webhooks_org_source", "org_id", "source", unique=True), ) + + +class AutomationSession(Base): + """Tracks active sessions for event routing. + + A session ties together a series of related events under a single sandbox run. + The session key (extracted from event payloads via a JMESPath expression) is + used to route incoming events to an existing session instead of creating a new + run for each event. + + Lifecycle: + - ACTIVE: Sandbox is alive; new events are queued as PendingSessionEvent rows. + - EXPIRED: Idle timeout or max session lifetime reached (set by SDK/watchdog). + - DEAD: Sandbox died unexpectedly; watchdog transitions here. + """ + + __tablename__ = "automation_sessions" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + automation_id: Mapped[uuid.UUID] = mapped_column( + Uuid, + ForeignKey("automations.id", ondelete="CASCADE"), + nullable=False, + ) + session_key: Mapped[str] = mapped_column(String(255), nullable=False) + + # The run that owns this session (the first run that created it) + run_id: Mapped[uuid.UUID] = mapped_column( + Uuid, + ForeignKey("automation_runs.id", ondelete="CASCADE"), + nullable=False, + ) + + # Sandbox identifier (set by the dispatcher after sandbox creation) + sandbox_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + + # Session status + status: Mapped[SessionStatus] = mapped_column( + Enum(SessionStatus, native_enum=False, length=20), + nullable=False, + default=SessionStatus.ACTIVE, + index=True, + ) + + # Lifecycle timestamps + started_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + # Pre-computed absolute expiry: started_at + session_timeout_seconds + expires_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False + ) + # Updated each time a new event is queued into this session + last_event_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) + + __table_args__ = ( + # Primary lookup: find active session for (automation, session_key) + Index("ix_session_lookup", "automation_id", "session_key", "status"), + ) + + +class PendingSessionEvent(Base): + """Events queued for delivery to an active session's sandbox. + + When an event arrives for an existing ACTIVE session, it is stored here + instead of creating a new ``AutomationRun``. The SDK script running inside + the sandbox polls for these events and processes them. + + Events are also queued here when a sandbox dies and ``on_sandbox_death`` + is set to ``"queue"`` or ``"restart"`` — allowing the next sandbox to + pick them up. + """ + + __tablename__ = "pending_session_events" + + id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True, default=uuid.uuid4) + automation_id: Mapped[uuid.UUID] = mapped_column( + Uuid, + ForeignKey("automations.id", ondelete="CASCADE"), + nullable=False, + ) + session_key: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + + # The event payload (same format as AutomationRun.event_payload) + event_payload: Mapped[dict] = mapped_column(JSON, nullable=False) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + server_default=text("CURRENT_TIMESTAMP"), + nullable=False, + ) diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index cf7d6db..3ac3365 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -12,6 +12,84 @@ from openhands.automation.config import get_config +class SessionConfig(BaseModel): + """Configure session-based conversation reuse for event-triggered automations. + + Sessions allow related events (e.g., comments on the same GitHub PR, messages + in the same Slack thread) to be routed to the same running sandbox instead of + creating a new sandbox for each event. + + ## How It Works + + 1. A JMESPath expression (``key_expr``) extracts a unique identifier from each + event payload — the session key. + 2. The first event with a new session key creates a new ``AutomationRun`` and + sandbox (normal dispatch flow). + 3. Subsequent events with the same session key are queued as + ``PendingSessionEvent`` records and delivered to the running sandbox. + 4. The SDK script polls for pending events via the workspace API and processes + them in a loop until idle timeout is reached. + + ## Examples + + ```json + // GitHub PR: route all events on the same PR to one session + {"key_expr": "pull_request.number || issue.number"} + + // Slack thread: route thread replies to the originating session + {"key_expr": "thread_ts || ts"} + ``` + """ + + model_config = ConfigDict(extra="forbid") + + key_expr: str = Field( + ..., + description=( + "JMESPath expression to extract the session key from the event payload. " + "The extracted value (converted to string) uniquely identifies a session. " + "Examples: 'pull_request.number', 'thread_ts || ts'" + ), + ) + idle_timeout_seconds: int = Field( + default=300, + ge=30, + le=3600, + description=( + "Session expires after this many seconds without a new event. " + "The SDK script should exit after this idle period (30–3600s)." + ), + ) + session_timeout_seconds: int = Field( + default=3600, + ge=60, + le=86400, + description="Maximum total session lifetime in seconds (60–86400s).", + ) + on_sandbox_death: Literal["queue", "restart", "drop"] = Field( + default="queue", + description=( + "Behavior when the sandbox dies while events are pending: " + "'queue' stores events for the next sandbox launch, " + "'restart' immediately starts a new sandbox with the queued events, " + "'drop' discards pending events." + ), + ) + + @field_validator("key_expr") + @classmethod + def validate_key_expr(cls, v: str) -> str: + """Validate JMESPath expression syntax.""" + import jmespath + from jmespath import exceptions as jmespath_exceptions + + try: + jmespath.compile(v) + except jmespath_exceptions.JMESPathError as e: + raise ValueError(f"Invalid JMESPath expression: {e}") from e + return v + + # Allowed URI schemes for tarball_path (includes internal upload scheme) _TARBALL_SCHEME_RE = re.compile(r"^(s3|gs|https?|oh-internal)://") @@ -152,6 +230,15 @@ class EventTrigger(BaseModel): "icontains(comment.body, '@openhands-resolver')" ), ) + session: SessionConfig | None = Field( + default=None, + description=( + "Optional session configuration for routing related events to the same " + "running sandbox. When set, events with the same session key (extracted " + "from the payload via key_expr) are queued to the active session instead " + "of creating a new run." + ), + ) @field_validator("filter") @classmethod @@ -358,6 +445,7 @@ class EventResponse(BaseModel): received: bool matched: int runs_created: list[str] # List of run IDs created + events_queued: int = 0 # Events routed to existing sessions # Valid source name pattern: lowercase alphanumeric with hyphens, 1-50 chars diff --git a/openhands/automation/utils/session.py b/openhands/automation/utils/session.py new file mode 100644 index 0000000..3b1ef07 --- /dev/null +++ b/openhands/automation/utils/session.py @@ -0,0 +1,179 @@ +"""Session management utilities for event routing. + +Helpers for creating, looking up, and updating automation sessions used +in session-based event routing (see schemas.SessionConfig). +""" + +import logging +import uuid +from datetime import timedelta +from typing import Any + +import jmespath +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.models import ( + AutomationRun, + AutomationSession, + PendingSessionEvent, + SessionStatus, +) +from openhands.automation.utils.time import utcnow + + +logger = logging.getLogger("automation.utils.session") + + +def extract_session_key(key_expr: str, payload: dict[str, Any]) -> str | None: + """Extract a session key from an event payload using a JMESPath expression. + + Args: + key_expr: JMESPath expression, e.g. ``'pull_request.number'`` or + ``'thread_ts || ts'``. + payload: The event payload dict. + + Returns: + The extracted value as a string, or ``None`` if the expression + evaluates to ``None`` or raises an error. + """ + try: + result = jmespath.search(key_expr, payload) + except Exception as e: + logger.warning("Failed to extract session key using expr=%r: %s", key_expr, e) + return None + + if result is None: + return None + return str(result) + + +async def get_active_session( + automation_id: uuid.UUID, + session_key: str, + db_session: AsyncSession, +) -> AutomationSession | None: + """Look up an ACTIVE, non-expired session for an automation + session key. + + Returns the most-recently-started active session, or ``None`` if none exists. + """ + now = utcnow() + result = await db_session.execute( + select(AutomationSession) + .where( + AutomationSession.automation_id == automation_id, + AutomationSession.session_key == session_key, + AutomationSession.status == SessionStatus.ACTIVE, + AutomationSession.expires_at > now, + ) + .order_by(AutomationSession.started_at.desc()) + .limit(1) + ) + return result.scalar_one_or_none() + + +async def create_session( + automation_id: uuid.UUID, + session_key: str, + run_id: uuid.UUID, + session_timeout_seconds: int, + db_session: AsyncSession, +) -> AutomationSession: + """Create a new ACTIVE session record. + + Args: + automation_id: The parent automation. + session_key: The key extracted from the triggering event payload. + run_id: The ``AutomationRun`` that owns this session. + session_timeout_seconds: Maximum session lifetime. + db_session: Async SQLAlchemy session. + + Returns: + The newly created ``AutomationSession`` (not yet committed). + """ + now = utcnow() + session = AutomationSession( + id=uuid.uuid4(), + automation_id=automation_id, + session_key=session_key, + run_id=run_id, + status=SessionStatus.ACTIVE, + started_at=now, + expires_at=now + timedelta(seconds=session_timeout_seconds), + last_event_at=now, + ) + db_session.add(session) + return session + + +async def queue_pending_event( + automation_id: uuid.UUID, + session_key: str, + event_payload: dict[str, Any], + db_session: AsyncSession, +) -> PendingSessionEvent: + """Queue an event for delivery to the sandbox owning the given session. + + Updates the session's ``last_event_at`` if an active session exists. + + Args: + automation_id: The parent automation. + session_key: The session key. + event_payload: The event payload to deliver. + db_session: Async SQLAlchemy session. + + Returns: + The created ``PendingSessionEvent`` (not yet committed). + """ + pending_event = PendingSessionEvent( + id=uuid.uuid4(), + automation_id=automation_id, + session_key=session_key, + event_payload=event_payload, + ) + db_session.add(pending_event) + + # Bump last_event_at on the active session so idle-timeout tracking is accurate + now = utcnow() + await db_session.execute( + update(AutomationSession) + .where( + AutomationSession.automation_id == automation_id, + AutomationSession.session_key == session_key, + AutomationSession.status == SessionStatus.ACTIVE, + ) + .values(last_event_at=now) + ) + + return pending_event + + +async def mark_session_dead( + automation_run: AutomationRun, + db_session: AsyncSession, +) -> bool: + """Mark the session for a given run as DEAD. + + Called by the watchdog when a run's sandbox is found to be dead. Uses + optimistic locking (``WHERE status = 'ACTIVE'``) so concurrent callbacks + don't clobber a legitimate EXPIRED transition. + + Args: + automation_run: The run whose session should be marked dead. + db_session: Async SQLAlchemy session. + + Returns: + ``True`` if a session was found and updated, ``False`` otherwise. + """ + from sqlalchemy.engine import CursorResult + + stmt = ( + update(AutomationSession) + .where( + AutomationSession.run_id == automation_run.id, + AutomationSession.status == SessionStatus.ACTIVE, + ) + .values(status=SessionStatus.DEAD) + ) + result: CursorResult = await db_session.execute(stmt) # type: ignore[assignment] + return result.rowcount > 0 diff --git a/openhands/automation/watchdog.py b/openhands/automation/watchdog.py index e68337b..62d1c43 100644 --- a/openhands/automation/watchdog.py +++ b/openhands/automation/watchdog.py @@ -23,6 +23,7 @@ from openhands.automation.config import Settings from openhands.automation.models import AutomationRun, AutomationRunStatus from openhands.automation.utils import log_extra +from openhands.automation.utils.session import mark_session_dead from openhands.automation.utils.time import utcnow @@ -231,6 +232,16 @@ async def mark_stale_runs( try: if await _verify_and_mark_run(session, run, settings): marked += 1 + # Mark any active session for this run as DEAD so the event + # router stops routing new events to the dead sandbox. + try: + session_marked = await mark_session_dead(run, session) + if session_marked: + logger.info("Marked session dead for run", extra=extra) + except Exception: + logger.exception( + "Failed to mark session dead for run", extra=extra + ) else: logger.info("Run already completed, skipping", extra=extra) except Exception: diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..c7880a5 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,245 @@ +"""Tests for session-based event routing. + +Tests session schema validation (SessionConfig), session utility functions +(extract_session_key, get_active_session, etc.), and event router session routing. +These tests do NOT require Docker or PostgreSQL — they use in-process logic only +or mocks where needed. +""" + +import pytest + +from openhands.automation.schemas import EventTrigger, SessionConfig +from openhands.automation.utils.session import extract_session_key + + +# --------------------------------------------------------------------------- +# SessionConfig schema validation +# --------------------------------------------------------------------------- + + +class TestSessionConfigSchema: + def test_minimal_valid_config(self): + cfg = SessionConfig(key_expr="pull_request.number") + assert cfg.key_expr == "pull_request.number" + assert cfg.idle_timeout_seconds == 300 # default + assert cfg.session_timeout_seconds == 3600 # default + assert cfg.on_sandbox_death == "queue" # default + + def test_full_config(self): + cfg = SessionConfig( + key_expr="thread_ts || ts", + idle_timeout_seconds=120, + session_timeout_seconds=7200, + on_sandbox_death="restart", + ) + assert cfg.key_expr == "thread_ts || ts" + assert cfg.idle_timeout_seconds == 120 + assert cfg.session_timeout_seconds == 7200 + assert cfg.on_sandbox_death == "restart" + + def test_on_sandbox_death_drop(self): + cfg = SessionConfig(key_expr="issue.number", on_sandbox_death="drop") + assert cfg.on_sandbox_death == "drop" + + def test_on_sandbox_death_queue(self): + cfg = SessionConfig(key_expr="issue.number", on_sandbox_death="queue") + assert cfg.on_sandbox_death == "queue" + + def test_invalid_on_sandbox_death(self): + with pytest.raises(Exception): + SessionConfig.model_validate( + {"key_expr": "issue.number", "on_sandbox_death": "invalid"} + ) + + def test_idle_timeout_too_low(self): + with pytest.raises(Exception): + SessionConfig(key_expr="issue.number", idle_timeout_seconds=10) # min 30 + + def test_idle_timeout_too_high(self): + with pytest.raises(Exception): + SessionConfig( + key_expr="issue.number", idle_timeout_seconds=4000 + ) # max 3600 + + def test_session_timeout_too_low(self): + with pytest.raises(Exception): + SessionConfig(key_expr="issue.number", session_timeout_seconds=30) # min 60 + + def test_session_timeout_too_high(self): + with pytest.raises(Exception): + SessionConfig( + key_expr="issue.number", session_timeout_seconds=100000 + ) # max 86400 + + def test_invalid_jmespath_expression(self): + with pytest.raises(Exception) as exc_info: + SessionConfig(key_expr="[invalid(((") + assert "JMESPath" in str(exc_info.value) or "Invalid" in str(exc_info.value) + + def test_extra_fields_forbidden(self): + with pytest.raises(Exception): + SessionConfig.model_validate( + {"key_expr": "issue.number", "unknown_field": "value"} + ) + + def test_complex_jmespath(self): + # JMESPath alternatives expression is valid syntax for jmespath.compile + cfg = SessionConfig(key_expr="pull_request.number || issue.number") + assert cfg.key_expr == "pull_request.number || issue.number" + + def test_roundtrip_json(self): + cfg = SessionConfig( + key_expr="issue.number", + idle_timeout_seconds=60, + session_timeout_seconds=1800, + on_sandbox_death="restart", + ) + data = cfg.model_dump() + restored = SessionConfig.model_validate(data) + assert restored == cfg + + +# --------------------------------------------------------------------------- +# EventTrigger with session field +# --------------------------------------------------------------------------- + + +class TestEventTriggerWithSession: + def test_event_trigger_without_session(self): + trigger = EventTrigger(source="github", on="push") + assert trigger.session is None + + def test_event_trigger_with_session(self): + trigger = EventTrigger( + source="github", + on=["issue_comment.created", "pull_request.synchronize"], + filter="icontains(comment.body, '@openhands')", + session=SessionConfig( + key_expr="pull_request.number || issue.number", + idle_timeout_seconds=600, + ), + ) + assert trigger.session is not None + assert trigger.session.key_expr == "pull_request.number || issue.number" + assert trigger.session.idle_timeout_seconds == 600 + + def test_event_trigger_session_roundtrip(self): + """EventTrigger with session survives model_dump / model_validate.""" + trigger = EventTrigger( + source="slack", + on="message", + session=SessionConfig( + key_expr="thread_ts || ts", + on_sandbox_death="queue", + ), + ) + data = trigger.model_dump() + restored = EventTrigger.model_validate(data) + assert restored.session is not None + assert restored.session.key_expr == "thread_ts || ts" + assert restored.session.on_sandbox_death == "queue" + + def test_event_trigger_session_via_dict(self): + """EventTrigger can be constructed from a dict (as stored in DB JSON column).""" + trigger_dict = { + "type": "event", + "source": "github", + "on": "pull_request.opened", + "session": { + "key_expr": "pull_request.number", + "idle_timeout_seconds": 300, + "session_timeout_seconds": 3600, + "on_sandbox_death": "restart", + }, + } + trigger = EventTrigger.model_validate(trigger_dict) + assert trigger.session is not None + assert trigger.session.on_sandbox_death == "restart" + + +# --------------------------------------------------------------------------- +# extract_session_key utility +# --------------------------------------------------------------------------- + + +class TestExtractSessionKey: + def test_simple_path(self): + payload = {"pull_request": {"number": 42}} + key = extract_session_key("pull_request.number", payload) + assert key == "42" + + def test_nested_path(self): + payload = {"issue": {"id": 7, "number": 123}} + key = extract_session_key("issue.number", payload) + assert key == "123" + + def test_string_value(self): + payload = {"thread_ts": "1234567890.123456"} + key = extract_session_key("thread_ts", payload) + assert key == "1234567890.123456" + + def test_missing_path_returns_none(self): + payload = {"other": "data"} + key = extract_session_key("pull_request.number", payload) + assert key is None + + def test_null_value_returns_none(self): + payload = {"pull_request": None} + key = extract_session_key("pull_request.number", payload) + assert key is None + + def test_jmespath_or_expression(self): + # JMESPath || (alternatives) — first non-null wins + payload = {"pull_request": {"number": 99}} + key = extract_session_key("pull_request.number || issue.number", payload) + assert key == "99" + + def test_jmespath_or_fallback(self): + payload = {"issue": {"number": 55}} + key = extract_session_key("pull_request.number || issue.number", payload) + assert key == "55" + + def test_integer_value_converted_to_string(self): + payload = {"pr": {"id": 1001}} + key = extract_session_key("pr.id", payload) + assert key == "1001" + assert isinstance(key, str) + + def test_invalid_expression_returns_none(self): + """Invalid JMESPath expression — extract_session_key returns None, no raise.""" + # Note: jmespath.compile validation is done at schema creation time. + # At runtime, malformed expressions caught and return None. + payload = {"data": "value"} + key = extract_session_key("[invalid(((", payload) + assert key is None + + def test_empty_payload(self): + key = extract_session_key("pull_request.number", {}) + assert key is None + + def test_list_value_converted_to_string(self): + payload = {"labels": ["bug", "feature"]} + key = extract_session_key("labels", payload) + # Lists stringify as Python list repr — usable as session key + assert key is not None + + +# --------------------------------------------------------------------------- +# EventResponse schema +# --------------------------------------------------------------------------- + + +class TestEventResponseSchema: + def test_events_queued_defaults_to_zero(self): + from openhands.automation.schemas import EventResponse + + resp = EventResponse(received=True, matched=3, runs_created=["a", "b"]) + assert resp.events_queued == 0 + + def test_events_queued_set(self): + from openhands.automation.schemas import EventResponse + + resp = EventResponse( + received=True, matched=2, runs_created=["a"], events_queued=1 + ) + assert resp.events_queued == 1