diff --git a/openhands/automation/preset_router.py b/openhands/automation/preset_router.py index d63ab10..45b37cc 100644 --- a/openhands/automation/preset_router.py +++ b/openhands/automation/preset_router.py @@ -20,6 +20,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel, ConfigDict, Field, model_validator +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from openhands.automation.auth import AuthenticatedUser, authenticate_request @@ -28,8 +29,12 @@ from openhands.automation.models import Automation, TarballUpload, UploadStatus from openhands.automation.schemas import AutomationResponse, Trigger from openhands.automation.storage import FileStore, get_file_store +from openhands.automation.utils import utcnow from openhands.automation.utils.model_profiles import resolve_model_profile_for_user -from openhands.automation.utils.tarball_validation import build_internal_url +from openhands.automation.utils.tarball_validation import ( + build_internal_url, + parse_internal_upload_id, +) from openhands.sdk.plugin import PluginSource from openhands.workspace import RepoSource @@ -207,6 +212,150 @@ def _build_storage_path( return f"uploads/{org_id}/{user_id}/{upload_id}.tar" +def _replace_prompt_in_tarball(tarball_bytes: bytes, new_prompt: str) -> bytes | None: + """Return a copy of a preset tarball with ``prompt.txt`` swapped for ``new_prompt``. + + Every other member (``main.py``, ``setup.sh``, ``plugins_config.json``, + ``repos_config.json``, ...) is copied through unchanged, so plugin and repo + configuration are preserved and the working template is untouched. + + Returns ``None`` if the archive has no ``prompt.txt`` member — i.e. it is not a + regenerable preset tarball — so the caller can leave the tarball as-is. + """ + out_buffer = io.BytesIO() + found = False + with ( + tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as src, + tarfile.open(fileobj=out_buffer, mode="w:gz") as dst, + ): + for member in src.getmembers(): + if member.name == "prompt.txt": + found = True + _add_file_to_tar( + dst, "prompt.txt", new_prompt, mode=member.mode or 0o644 + ) + continue + if member.isfile(): + extracted = src.extractfile(member) + data = extracted.read() if extracted is not None else b"" + info = tarfile.TarInfo(name=member.name) + info.size = len(data) + info.mode = member.mode + info.mtime = member.mtime + dst.addfile(info, io.BytesIO(data)) + else: + dst.addfile(member) + + if not found: + return None + + out_buffer.seek(0) + return out_buffer.read() + + +async def regenerate_preset_prompt_tarball( + automation: Automation, + new_prompt: str, + session: AsyncSession, +) -> str | None: + """Rebuild a preset automation's tarball with an updated prompt. + + Preset automations bake the prompt into ``prompt.txt`` inside the tarball the + dispatcher executes; the stored ``prompt`` column is metadata only. When the + prompt is edited the tarball must be rewritten too, otherwise dispatching keeps + running the original prompt. + + Reads the automation's current internal-upload tarball, swaps in ``new_prompt`` + (leaving all other files untouched), uploads the result as a new internal upload, + and returns its ``oh-internal://`` URL for the caller to store on ``tarball_path``. + + Returns ``None`` — leaving the tarball unchanged — when the automation is not a + regenerable preset: its ``tarball_path`` is an external URL, the referenced upload + is missing, or the archive contains no ``prompt.txt``. The file store is resolved + lazily so that updates to non-preset automations never construct one. + """ + upload_id = parse_internal_upload_id(automation.tarball_path) + if upload_id is None: + return None + + file_store = get_file_store() + result = await session.execute( + select(TarballUpload).where(TarballUpload.id == upload_id) + ) + source_upload = result.scalars().first() + if source_upload is None: + return None + + try: + current_tarball = file_store.read(source_upload.storage_path) + except FileNotFoundError: + return None + + new_tarball = _replace_prompt_in_tarball(current_tarball, new_prompt) + if new_tarball is None: + return None + + new_upload_id = uuid.uuid4() + storage_path = _build_storage_path( + automation.org_id, automation.user_id, new_upload_id + ) + upload = TarballUpload( + id=new_upload_id, + user_id=automation.user_id, + org_id=automation.org_id, + name=f"prompt-automation-{_safe_truncate(automation.name, 50)}-edit", + description=f"Prompt updated for: {_safe_truncate(automation.name, 100)}", + status=UploadStatus.UPLOADING, + storage_path=storage_path, + ) + session.add(upload) + await session.flush() + + try: + size_bytes = await file_store.write_stream( + path=storage_path, + stream=_bytes_to_async_iter(new_tarball), + content_type="application/x-tar", + ) + upload.status = UploadStatus.COMPLETED + upload.size_bytes = size_bytes + except Exception as e: + # The session is rolled back when the HTTPException propagates (see + # get_session), so don't flush here — the in-memory status/error_message + # are only for log/debug context and won't be persisted. + logger.exception("Failed to upload regenerated tarball: %s", e) + upload.status = UploadStatus.FAILED + upload.error_message = f"Upload failed: {e!s}" + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to upload regenerated tarball: {e!s}", + ) + + # The old tarball is now superseded. Remove its file and soft-delete the + # upload record so repeated prompt edits don't accumulate orphaned storage. + # Only soft-delete once the file is confirmed gone: if the delete fails the + # record stays live so the still-present file remains discoverable for a + # later retry/cleanup instead of becoming a hidden orphan (file on disk, + # record marked deleted). + file_removed = False + try: + file_store.delete(source_upload.storage_path) + file_removed = True + except FileNotFoundError: + file_removed = True + except Exception as e: + logger.exception( + "Failed to delete superseded tarball at %s: %s", + source_upload.storage_path, + e, + ) + if file_removed: + source_upload.deleted_at = utcnow() + + await session.flush() + return build_internal_url(new_upload_id) + + @router.post("/prompt", status_code=status.HTTP_201_CREATED) async def create_automation_from_prompt( body: CreatePromptAutomationRequest, diff --git a/openhands/automation/router.py b/openhands/automation/router.py index 9cd9fbd..930ce9c 100644 --- a/openhands/automation/router.py +++ b/openhands/automation/router.py @@ -20,6 +20,7 @@ AutomationRunStatus, TarballUpload, ) +from openhands.automation.preset_router import regenerate_preset_prompt_tarball from openhands.automation.schemas import ( AutomationListResponse, AutomationResponse, @@ -152,9 +153,26 @@ async def update_automation( if "model" in update_data: update_data["model"] = resolve_model_profile_for_user(body.model, user) + original_prompt = auto.prompt for field, value in update_data.items(): setattr(auto, field, value) + # A preset automation bakes its prompt into the tarball the dispatcher + # executes; the `prompt` column is metadata only. When the prompt actually + # changes, rebuild the tarball so the next dispatch runs the new prompt + # instead of the original baked one. Skipped when the value is unchanged (a + # no-op edit), or for non-preset automations. + if ( + "prompt" in update_data + and isinstance(auto.prompt, str) + and auto.prompt != original_prompt + ): + new_tarball_path = await regenerate_preset_prompt_tarball( + auto, auto.prompt, session + ) + if new_tarball_path is not None: + auto.tarball_path = new_tarball_path + # Note: updated_at is handled automatically by the model's onupdate=utcnow await session.flush() await session.refresh(auto) diff --git a/tests/test_preset_router.py b/tests/test_preset_router.py index 4ccf750..9a121ea 100644 --- a/tests/test_preset_router.py +++ b/tests/test_preset_router.py @@ -14,6 +14,7 @@ from openhands.automation.preset_router import ( _generate_plugin_tarball, _generate_tarball, + _replace_prompt_in_tarball, ) from openhands.sdk.plugin import PluginSource from openhands.workspace import RepoSource @@ -210,6 +211,58 @@ def test_generate_tarball_with_repos(self): assert repos_config[1]["ref"] == "main" +class TestReplacePromptInTarball: + """Tests for swapping prompt.txt inside an existing preset tarball.""" + + def test_replaces_prompt_and_preserves_sibling_files(self): + """The prompt is swapped while every other file is left byte-for-byte intact.""" + # Arrange — a plugin preset tarball carries main.py, setup.sh, prompt.txt, + # plugins_config.json and repos_config.json; all but the prompt must survive. + original = _generate_plugin_tarball( + [PluginSource(source="github:owner/repo")], + "Original prompt", + repos=[RepoSource(url="owner/repo", provider="github")], + ) + + # Act + updated = _replace_prompt_in_tarball(original, "New prompt") + + # Assert + assert updated is not None + + def _read(tarball_bytes): + files = {} + with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar: + for member in tar.getmembers(): + if not member.isfile(): + continue + extracted = tar.extractfile(member) + assert extracted is not None + files[member.name] = extracted.read() + return files, tar.getmember("setup.sh").mode + + old_files, _ = _read(original) + new_files, new_setup_mode = _read(updated) + + assert new_files["prompt.txt"].decode() == "New prompt" + for name in ("main.py", "setup.sh", "plugins_config.json", "repos_config.json"): + assert new_files[name] == old_files[name] + assert new_setup_mode & 0o100 # setup.sh stays executable + + def test_returns_none_when_tarball_has_no_prompt(self): + """A tarball without prompt.txt is not regenerable, so None is returned.""" + # Arrange — an archive that has no prompt.txt member. + buffer = io.BytesIO() + with tarfile.open(fileobj=buffer, mode="w:gz") as tar: + data = b"print('hi')" + info = tarfile.TarInfo(name="main.py") + info.size = len(data) + tar.addfile(info, io.BytesIO(data)) + + # Act / Assert + assert _replace_prompt_in_tarball(buffer.getvalue(), "New prompt") is None + + class TestRepoSource: """Tests for RepoSource model.""" diff --git a/tests/test_router.py b/tests/test_router.py index 878c05b..e5dda36 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -1,11 +1,19 @@ """Tests for API router endpoints.""" +import io +import tarfile import uuid +from unittest.mock import AsyncMock, MagicMock import pytest -from openhands.automation.models import Automation +from openhands.automation.models import Automation, TarballUpload, UploadStatus +from openhands.automation.preset_router import _build_storage_path, _generate_tarball from openhands.automation.utils import utcnow +from openhands.automation.utils.tarball_validation import ( + build_internal_url, + parse_internal_upload_id, +) # Test UUIDs matching mock_authenticated_user fixture @@ -15,6 +23,71 @@ OTHER_ORG_ID = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") +@pytest.fixture +def preset_store(monkeypatch): + """Stateful in-memory file store wired into the preset tarball helpers. + + The PATCH handler resolves the file store lazily via + ``preset_router.get_file_store`` only when regenerating a preset tarball, so + tests patch that name rather than overriding a FastAPI dependency. + """ + from collections.abc import AsyncIterator + + from openhands.automation import preset_router + + store = MagicMock() + store._storage = {} + + async def _write_stream( + path: str, + stream: AsyncIterator[bytes], + max_size: int | None = None, + content_type: str = "application/octet-stream", + ) -> int: + content = b"" + async for chunk in stream: + content += chunk + store._storage[path] = content + return len(content) + + store.write_stream = AsyncMock(side_effect=_write_stream) + store.read = MagicMock(side_effect=lambda path: store._storage[path]) + store.delete = MagicMock(side_effect=lambda path: store._storage.pop(path, None)) + monkeypatch.setattr(preset_router, "get_file_store", lambda: store) + return store + + +async def _seed_prompt_preset_automation(async_session, store, prompt): + """Insert a prompt-preset automation whose internal tarball bakes ``prompt``.""" + upload_id = uuid.uuid4() + storage_path = _build_storage_path(TEST_ORG_ID, TEST_USER_ID, upload_id) + store._storage[storage_path] = _generate_tarball(prompt) + + async_session.add( + TarballUpload( + id=upload_id, + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="seed-upload", + status=UploadStatus.COMPLETED, + storage_path=storage_path, + ) + ) + automation = Automation( + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="Preset Automation", + prompt=prompt, + trigger={"type": "cron", "schedule": "0 9 * * *", "timezone": "UTC"}, + tarball_path=build_internal_url(upload_id), + setup_script_path="setup.sh", + entrypoint=".venv/bin/python main.py", + ) + async_session.add(automation) + await async_session.commit() + return automation + + @pytest.fixture def _automation_for_permission_tests(async_session): """Create an automation owned by the test user for permission tests.""" @@ -879,6 +952,113 @@ async def test_update_automation_prompt(self, async_client, async_session): assert response.status_code == 200 assert response.json()["prompt"] == "Updated prompt" + async def test_update_prompt_regenerates_preset_tarball( + self, async_client, async_session, preset_store + ): + """Editing the prompt rebuilds the baked tarball the dispatcher executes.""" + # Arrange — a prompt-preset automation whose tarball bakes "Original prompt". + automation = await _seed_prompt_preset_automation( + async_session, preset_store, "Original prompt" + ) + original_tarball_path = automation.tarball_path + old_upload_id = parse_internal_upload_id(original_tarball_path) + assert old_upload_id is not None + old_storage_path = _build_storage_path(TEST_ORG_ID, TEST_USER_ID, old_upload_id) + + # Act — edit the prompt. + response = await async_client.patch( + f"/api/automation/v1/{automation.id}", + json={"prompt": "Updated prompt"}, + ) + + # Assert — both the stored prompt and the executable tarball reflect the edit. + assert response.status_code == 200 + data = response.json() + assert data["prompt"] == "Updated prompt" + assert data["tarball_path"] != original_tarball_path + + new_upload_id = parse_internal_upload_id(data["tarball_path"]) + assert new_upload_id is not None + new_storage_path = _build_storage_path(TEST_ORG_ID, TEST_USER_ID, new_upload_id) + with tarfile.open( + fileobj=io.BytesIO(preset_store._storage[new_storage_path]), mode="r:gz" + ) as tar: + prompt_file = tar.extractfile("prompt.txt") + assert prompt_file is not None + assert prompt_file.read().decode() == "Updated prompt" + + # The superseded tarball file is removed so storage doesn't grow unbounded. + assert old_storage_path not in preset_store._storage + + async def test_update_name_does_not_regenerate_preset_tarball( + self, async_client, async_session, preset_store + ): + """Editing a non-prompt field leaves the baked tarball untouched.""" + # Arrange + automation = await _seed_prompt_preset_automation( + async_session, preset_store, "Original prompt" + ) + original_tarball_path = automation.tarball_path + + # Act — edit only the name. + response = await async_client.patch( + f"/api/automation/v1/{automation.id}", + json={"name": "Renamed"}, + ) + + # Assert — name changes, tarball reference is preserved. + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Renamed" + assert data["tarball_path"] == original_tarball_path + + async def test_update_unchanged_prompt_does_not_regenerate_tarball( + self, async_client, async_session, preset_store + ): + """Re-sending the same prompt is a no-op: no tarball rebuild, no new upload.""" + # Arrange + automation = await _seed_prompt_preset_automation( + async_session, preset_store, "Same prompt" + ) + original_tarball_path = automation.tarball_path + + # Act — PATCH the prompt with the value it already has. + response = await async_client.patch( + f"/api/automation/v1/{automation.id}", + json={"prompt": "Same prompt"}, + ) + + # Assert — tarball untouched and no new upload was written. + assert response.status_code == 200 + assert response.json()["tarball_path"] == original_tarball_path + preset_store.write_stream.assert_not_called() + + async def test_update_prompt_upload_failure_returns_500( + self, async_client, async_session, preset_store + ): + """If the regenerated tarball fails to upload, the edit fails cleanly. + + A 500 is returned and the automation still points at its original + tarball — no half-committed state leaks through. + """ + # Arrange — make the upload step fail. + automation = await _seed_prompt_preset_automation( + async_session, preset_store, "Original prompt" + ) + original_tarball_path = automation.tarball_path + preset_store.write_stream = AsyncMock(side_effect=RuntimeError("storage down")) + + # Act + response = await async_client.patch( + f"/api/automation/v1/{automation.id}", + json={"prompt": "Updated prompt"}, + ) + + # Assert + assert response.status_code == 500 + await async_session.refresh(automation) + assert automation.tarball_path == original_tarball_path + async def test_update_automation_timeout(self, async_client, async_session): """Can update automation timeout.""" automation = Automation(