diff --git a/openhands/automation/preset_router.py b/openhands/automation/preset_router.py index 45b37cc..815541d 100644 --- a/openhands/automation/preset_router.py +++ b/openhands/automation/preset_router.py @@ -465,6 +465,22 @@ async def create_automation_from_prompt( # --- Plugin Preset --- +MAX_VARIANTS = 10 + + +class ExperimentVariant(BaseModel): + """A single variant in an A/B test experiment.""" + + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., min_length=1, max_length=100) + weight: int = Field(..., gt=0, description="Relative selection weight (> 0)") + plugins: list[PluginSource] = Field( + ..., + min_length=1, + description="Plugin(s) for this variant.", + ) + class CreatePluginAutomationRequest(BaseModel): """Request to create an automation using plugins.""" @@ -472,11 +488,22 @@ class CreatePluginAutomationRequest(BaseModel): model_config = ConfigDict(extra="forbid") name: str = Field(..., min_length=1, max_length=500) - plugins: list[PluginSource] = Field( - ..., - description="Plugin(s) to load. Can be a single plugin or a list of plugins. " - "Each plugin specifies a source (github:owner/repo, git URL, or local path), " - "optional ref (branch/tag/commit), and optional repo_path for monorepos.", + plugins: list[PluginSource] | None = Field( + default=None, + description="Plugin(s) to load. Mutually exclusive with 'variants'.", + ) + variants: list[ExperimentVariant] | None = Field( + default=None, + description=( + "A/B test variants. Each variant specifies its own plugin set and a " + "relative weight. Mutually exclusive with 'plugins'." + ), + ) + experiment_id: str | None = Field( + default=None, + min_length=1, + max_length=200, + description="Required when using variants. A human-readable experiment name.", ) prompt: str = Field( ..., @@ -524,7 +551,7 @@ def normalize_plugins_and_repos(cls, data: dict) -> dict: # type: ignore[type-a """Normalize plugins and repos to always be lists.""" if isinstance(data, dict): # Normalize plugins - if "plugins" in data: + if "plugins" in data and data["plugins"] is not None: plugins = data["plugins"] if isinstance(plugins, dict): data["plugins"] = [plugins] @@ -537,46 +564,77 @@ def normalize_plugins_and_repos(cls, data: dict) -> dict: # type: ignore[type-a data["repos"] = [repos] return data + @model_validator(mode="after") + def validate_plugins_or_variants(self) -> "CreatePluginAutomationRequest": + """Enforce mutual exclusivity between plugins and variants.""" + if (self.plugins is None) == (self.variants is None): + raise ValueError("Exactly one of 'plugins' or 'variants' must be provided.") + + if self.variants is not None: + if self.experiment_id is None: + raise ValueError("'experiment_id' is required when using 'variants'.") + if len(self.variants) < 2: + raise ValueError("At least two variants are required for an A/B test.") + if len(self.variants) > MAX_VARIANTS: + raise ValueError(f"At most {MAX_VARIANTS} variants are allowed.") + names = [v.name for v in self.variants] + if len(names) != len(set(names)): + raise ValueError("Variant names must be unique.") + else: + if self.experiment_id is not None: + raise ValueError( + "'experiment_id' can only be used with 'variants', not 'plugins'." + ) + + return self + def _generate_plugin_tarball( - plugins: list[PluginSource], prompt: str, repos: list[RepoSource] | None = None + plugins: list[PluginSource] | None, + prompt: str, + repos: list[RepoSource] | None = None, + *, + experiment_id: str | None = None, + variants: list[ExperimentVariant] | None = None, ) -> bytes: """Generate a tarball containing SDK code, plugin config, and prompt. - The tarball contains: - - main.py: SDK boilerplate that loads plugins and runs conversation - - plugins_config.json: List of plugin sources (serialized PluginSource models) - - prompt.txt: The prompt to send - - setup.sh: Script to install the SDK - - repos_config.json: (optional) Repository configuration for cloning - - Note: Clone and skill loading functionality is now provided by the SDK's - OpenHandsCloudWorkspace.clone_repos() and load_skills_from_agent_server() - methods, so separate scripts are no longer needed. - - Args: - plugins: List of plugins to load - prompt: The user's prompt text - repos: Optional list of repositories to clone - - Returns: - bytes: The tarball content as bytes + When *variants* is provided the tarball contains ``experiment_config.json`` + instead of ``plugins_config.json``. The two are mutually exclusive. """ preset_files = _load_plugin_preset_files() - # Serialize plugins using Pydantic (exclude None values for cleaner JSON) - plugins_config = [p.model_dump(exclude_none=True) for p in plugins] - plugins_config_json = json.dumps(plugins_config, indent=2) - tarball_buffer = io.BytesIO() with tarfile.open(fileobj=tarball_buffer, mode="w:gz") as tar: _add_file_to_tar(tar, "main.py", preset_files["main.py"]) - _add_file_to_tar(tar, "plugins_config.json", plugins_config_json) _add_file_to_tar(tar, "prompt.txt", prompt) _add_file_to_tar(tar, "setup.sh", preset_files["setup.sh"], mode=0o755) - # Add repos config if repos specified (SDK workspace handles cloning) + if variants is not None: + experiment_config = { + "experiment_id": experiment_id, + "variants": [ + { + "name": v.name, + "weight": v.weight, + "plugins": [p.model_dump(exclude_none=True) for p in v.plugins], + } + for v in variants + ], + } + _add_file_to_tar( + tar, + "experiment_config.json", + json.dumps(experiment_config, indent=2), + ) + else: + assert plugins is not None # guaranteed by caller + plugins_config = [p.model_dump(exclude_none=True) for p in plugins] + _add_file_to_tar( + tar, "plugins_config.json", json.dumps(plugins_config, indent=2) + ) + if repos: repos_config = [r.model_dump(exclude_none=True) for r in repos] _add_file_to_tar( @@ -622,9 +680,13 @@ async def create_automation_from_plugin( """ model = resolve_model_profile_for_user(body.model, user) - # 1. Generate tarball with SDK code, plugin config, prompt, and repos config + # 1. Generate tarball with SDK code, plugin/experiment config, and prompt tarball_content = _generate_plugin_tarball( - body.plugins, body.prompt, repos=body.repos + body.plugins, + body.prompt, + repos=body.repos, + experiment_id=body.experiment_id, + variants=body.variants, ) # 2. Upload tarball to storage @@ -632,14 +694,23 @@ async def create_automation_from_plugin( storage_path = _build_storage_path(user.org_id, user.user_id, upload_id) # Create upload record - plugin_sources_str = _format_plugin_sources_for_description(body.plugins) - truncated_sources = _safe_truncate(plugin_sources_str, 100) + if body.variants is not None: + variant_names = ", ".join(v.name for v in body.variants) + description = _safe_truncate( + f"A/B experiment {body.experiment_id}: {variant_names}", 200 + ) + else: + assert body.plugins is not None # guaranteed by validator + plugin_sources_str = _format_plugin_sources_for_description(body.plugins) + truncated = _safe_truncate(plugin_sources_str, 100) + description = f"Auto-generated with plugins: {truncated}" + upload = TarballUpload( id=upload_id, user_id=user.user_id, org_id=user.org_id, name=f"plugin-automation-{_safe_truncate(body.name, 50)}", - description=f"Auto-generated with plugins: {truncated_sources}", + description=description, status=UploadStatus.UPLOADING, storage_path=storage_path, ) @@ -697,14 +768,17 @@ async def create_automation_from_plugin( detail=f"Failed to create automation: {e!s}", ) - logger.info( - "Created automation from plugin", - extra={ - "automation_id": str(automation.id), - "upload_id": str(upload_id), - "plugin_count": len(body.plugins), - "prompt_length": len(body.prompt), - }, - ) + log_extra: dict[str, Any] = { + "automation_id": str(automation.id), + "upload_id": str(upload_id), + "prompt_length": len(body.prompt), + } + if body.variants is not None: + log_extra["experiment_id"] = body.experiment_id + log_extra["variant_count"] = len(body.variants) + elif body.plugins is not None: + log_extra["plugin_count"] = len(body.plugins) + + logger.info("Created automation from plugin", extra=log_extra) return AutomationResponse.model_validate(automation) diff --git a/openhands/automation/presets/plugin/sdk_main.py b/openhands/automation/presets/plugin/sdk_main.py index 69c30de..068238f 100644 --- a/openhands/automation/presets/plugin/sdk_main.py +++ b/openhands/automation/presets/plugin/sdk_main.py @@ -64,6 +64,7 @@ import json import os +import random import sys import time @@ -215,11 +216,32 @@ repos_context = workspace.get_repos_context(clone_result.repo_mappings) # Load configuration files + EXPERIMENT_CONFIG_FILE = os.path.join(SCRIPT_DIR, "experiment_config.json") PLUGINS_CONFIG_FILE = os.path.join(SCRIPT_DIR, "plugins_config.json") PROMPT_FILE = os.path.join(SCRIPT_DIR, "prompt.txt") - with open(PLUGINS_CONFIG_FILE) as f: - plugins_config = json.load(f) + # Experiment-aware variant selection + experiment_id: str | None = None + selected_variant: str | None = None + + if os.path.exists(EXPERIMENT_CONFIG_FILE): + with open(EXPERIMENT_CONFIG_FILE) as f: + experiment_config = json.load(f) + + experiment_id = experiment_config["experiment_id"] + variants = experiment_config["variants"] + weights = [v["weight"] for v in variants] + selected = random.choices(variants, weights=weights, k=1)[0] + + selected_variant = selected["name"] + plugins_config = selected["plugins"] + print("\n=== EXPERIMENT ===") + print(f" id: {experiment_id}") + print(f" variant: {selected_variant}") + print(f" weights: {dict(zip([v['name'] for v in variants], weights))}") + else: + with open(PLUGINS_CONFIG_FILE) as f: + plugins_config = json.load(f) with open(PROMPT_FILE) as f: USER_PROMPT = f.read() @@ -330,16 +352,30 @@ def event_callback(event) -> None: received_events.append(event) last_event_time["ts"] = time.time() + # Build experiment tags (if running an A/B test) + experiment_tags: dict[str, str] = {} + if experiment_id: + experiment_tags["experiment_id"] = experiment_id + if selected_variant is None: + raise RuntimeError( + "BUG: experiment_id is set but selected_variant is None — " + "experiment config may be malformed." + ) + experiment_tags["variant"] = selected_variant + conversation = Conversation( agent=agent, workspace=workspace, plugins=plugin_sources, # All plugins loaded here callbacks=[event_callback], delete_on_close=False, # Keep conversation history after completion + tags=experiment_tags or None, ) assert isinstance(conversation, RemoteConversation) print(f" conversation created: {type(conversation).__name__}") print(f" plugins loaded: {len(plugin_sources)}") + if experiment_tags: + print(f" experiment tags: {experiment_tags}") # Inject secrets into the conversation (auto-exported as env vars in bash) if secrets: diff --git a/tests/test_ab_testing_integration.py b/tests/test_ab_testing_integration.py new file mode 100644 index 0000000..c86bc69 --- /dev/null +++ b/tests/test_ab_testing_integration.py @@ -0,0 +1,483 @@ +"""Integration tests for A/B testing plugin variants. + +Uses SQLite (no Docker needed) to test the full flow: + 1. Create an A/B test automation via POST /v1/preset/plugin + 2. Verify the API response + 3. Inspect the generated tarball (experiment_config.json) + 4. Simulate the sdk_main.py variant selection logic + 5. Verify backward compatibility with standard plugin automations +""" + +import io +import json +import os +import random +import tarfile +import uuid +from collections.abc import AsyncGenerator, AsyncIterator +from typing import ClassVar +from unittest.mock import AsyncMock, MagicMock + +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + + +os.environ["LOG_JSON"] = "0" + +from openhands.automation.app import app # noqa: E402 +from openhands.automation.auth import ( # noqa: E402 + AuthenticatedUser, + AuthMethod, + authenticate_request, + create_http_client, +) +from openhands.automation.db import get_session # noqa: E402 +from openhands.automation.models import Base # noqa: E402 +from openhands.automation.storage import get_file_store # noqa: E402 + + +# --- Fixtures (SQLite-based, no Docker) --- + +TEST_USER_ID = uuid.UUID("12345678-1234-5678-1234-567812345678") +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") + + +@pytest.fixture +async def async_engine(): + """Create an in-memory SQLite engine.""" + engine = create_async_engine( + "sqlite+aiosqlite://", + connect_args={"check_same_thread": False}, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await engine.dispose() + + +@pytest.fixture +async def async_session_factory(async_engine): + return async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + +@pytest.fixture +async def async_session(async_session_factory) -> AsyncGenerator[AsyncSession, None]: + async with async_session_factory() as session: + yield session + + +@pytest.fixture +def mock_authenticated_user(): + return AuthenticatedUser( + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + email="test@example.com", + role="owner", + permissions=["view_org_settings", "manage_automations"], + auth_method=AuthMethod.API_KEY, + api_key="test-api-key", + ) + + +@pytest.fixture +def mock_file_store(): + """Mock file store that captures uploaded content.""" + store = MagicMock() + store._captured_content = None + + async def mock_write_stream( + path: str, + stream: AsyncIterator[bytes], + max_size: int | None = None, + content_type: str = "application/octet-stream", + ) -> int: + content = b"" + async for chunk in stream: + content += chunk + store._captured_content = content + return len(content) + + store.write_stream = AsyncMock(side_effect=mock_write_stream) + store.delete = MagicMock() + return store + + +@pytest.fixture +async def client( + async_engine, + async_session_factory, + async_session, + mock_authenticated_user, + mock_file_store, +) -> AsyncGenerator[AsyncClient, None]: + """Async test client with SQLite DB, mock auth, and mock file store.""" + + async def override_get_session(): + yield async_session + + app.dependency_overrides[get_session] = override_get_session + app.dependency_overrides[authenticate_request] = lambda: mock_authenticated_user + app.dependency_overrides[get_file_store] = lambda: mock_file_store + + app.state.engine = async_engine + app.state.session_factory = async_session_factory + app.state.http_client = create_http_client() + + async with AsyncClient( + transport=ASGITransport(app=app), + base_url="http://test", + ) as c: + yield c + + await app.state.http_client.aclose() + app.dependency_overrides.clear() + + +# --- Helpers --- + + +def _extract_tarball(mock_store) -> dict[str, bytes]: + """Extract all files from the captured tarball.""" + assert mock_store._captured_content is not None + files = {} + with tarfile.open( + fileobj=io.BytesIO(mock_store._captured_content), mode="r:gz" + ) as tar: + for member in tar.getmembers(): + f = tar.extractfile(member) + if f: + files[member.name] = f.read() + return files + + +def _simulate_variant_selection(experiment_config: dict, seed: int = 42) -> str: + """Select a variant using the same weighted-random logic as sdk_main.py. + + Uses a seeded RNG for deterministic test assertions; production code + uses the global (unseeded) random module. + """ + rng = random.Random(seed) + variants = experiment_config["variants"] + weights = [v["weight"] for v in variants] + selected = rng.choices(variants, weights=weights, k=1)[0] + return selected["name"] + + +# --- Tests --- + + +class TestABTestAutomationCreation: + """End-to-end: create an A/B test automation via the API.""" + + AB_PAYLOAD: ClassVar[dict] = { + "name": "PR Review A/B Test", + "experiment_id": "pr-review-v2-test", + "variants": [ + { + "name": "control", + "weight": 70, + "plugins": [ + { + "source": "github:OpenHands/extensions", + "repo_path": "plugins/pr-review", + "ref": "v1.0.0", + }, + ], + }, + { + "name": "treatment", + "weight": 30, + "plugins": [ + { + "source": "github:OpenHands/extensions", + "repo_path": "plugins/pr-review", + "ref": "v2.0.0", + }, + ], + }, + ], + "prompt": "Review this PR for code quality and potential bugs.", + "trigger": { + "type": "event", + "source": "github", + "on": "pull_request.opened", + }, + } + + async def test_create_ab_automation_returns_201(self, client, mock_file_store): + """POST with variants returns 201 and correct metadata.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.AB_PAYLOAD, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + + assert data["name"] == "PR Review A/B Test" + assert data["prompt"] == self.AB_PAYLOAD["prompt"] + assert data["trigger"]["type"] == "event" + assert data["tarball_path"].startswith("oh-internal://uploads/") + assert data["enabled"] is True + + async def test_tarball_contains_experiment_config(self, client, mock_file_store): + """Generated tarball has experiment_config.json, not plugins_config.json.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.AB_PAYLOAD, + ) + assert resp.status_code == 201 + + files = _extract_tarball(mock_file_store) + assert "experiment_config.json" in files + assert "plugins_config.json" not in files + assert "main.py" in files + assert "prompt.txt" in files + assert "setup.sh" in files + + async def test_experiment_config_matches_request(self, client, mock_file_store): + """experiment_config.json faithfully represents the request.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.AB_PAYLOAD, + ) + assert resp.status_code == 201 + + files = _extract_tarball(mock_file_store) + config = json.loads(files["experiment_config.json"]) + + assert config["experiment_id"] == "pr-review-v2-test" + assert len(config["variants"]) == 2 + + control = config["variants"][0] + assert control["name"] == "control" + assert control["weight"] == 70 + assert control["plugins"][0]["ref"] == "v1.0.0" + + treatment = config["variants"][1] + assert treatment["name"] == "treatment" + assert treatment["weight"] == 30 + assert treatment["plugins"][0]["ref"] == "v2.0.0" + + async def test_main_py_has_experiment_support(self, client, mock_file_store): + """main.py template includes experiment detection and tagging code.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.AB_PAYLOAD, + ) + assert resp.status_code == 201 + + files = _extract_tarball(mock_file_store) + main_py = files["main.py"].decode("utf-8") + + assert "experiment_config.json" in main_py + assert "experiment_id" in main_py + assert "selected_variant" in main_py + assert "random.choices" in main_py + assert "experiment_tags" in main_py + + +class TestVariantSelectionLogic: + """Test the runtime variant selection as it would run in sdk_main.py.""" + + EXPERIMENT_CONFIG: ClassVar[dict] = { + "experiment_id": "test-experiment", + "variants": [ + { + "name": "control", + "weight": 80, + "plugins": [{"source": "github:o/r", "ref": "v1"}], + }, + { + "name": "treatment", + "weight": 20, + "plugins": [{"source": "github:o/r", "ref": "v2"}], + }, + ], + } + + def test_selection_respects_weights_distribution(self): + """Over many runs, variant selection roughly follows weight ratios.""" + counts = {"control": 0, "treatment": 0} + for seed in range(1000): + name = _simulate_variant_selection(self.EXPERIMENT_CONFIG, seed=seed) + counts[name] += 1 + + # 80/20 weights → expect ~800/200 with some variance + assert counts["control"] > 600, ( + f"control selected only {counts['control']}/1000 times" + ) + assert counts["treatment"] > 100, ( + f"treatment selected only {counts['treatment']}/1000 times" + ) + + def test_deterministic_with_same_seed(self): + """Same seed always picks the same variant.""" + v1 = _simulate_variant_selection(self.EXPERIMENT_CONFIG, seed=12345) + v2 = _simulate_variant_selection(self.EXPERIMENT_CONFIG, seed=12345) + assert v1 == v2 + + def test_selected_variant_has_plugins(self): + """The selected variant carries its plugin config.""" + variants = self.EXPERIMENT_CONFIG["variants"] + weights = [v["weight"] for v in variants] + rng = random.Random(42) + selected = rng.choices(variants, weights=weights, k=1)[0] + + assert "plugins" in selected + assert len(selected["plugins"]) == 1 + assert "source" in selected["plugins"][0] + + def test_equal_weights_both_variants_appear(self): + """With 50/50 weights, both variants should appear over many runs.""" + config = { + "experiment_id": "equal", + "variants": [ + {"name": "a", "weight": 50, "plugins": [{"source": "github:o/r"}]}, + {"name": "b", "weight": 50, "plugins": [{"source": "github:o/r"}]}, + ], + } + seen = {_simulate_variant_selection(config, seed=s) for s in range(100)} + assert seen == {"a", "b"} + + +class TestBackwardCompatibility: + """Standard plugin automations still work exactly as before.""" + + STANDARD_PAYLOAD: ClassVar[dict] = { + "name": "Standard Plugin Automation", + "plugins": [ + {"source": "github:owner/plugin", "ref": "v1.0.0"}, + ], + "prompt": "Do something with the plugin.", + "trigger": {"type": "cron", "schedule": "0 9 * * *"}, + } + + async def test_standard_plugin_automation_returns_201( + self, client, mock_file_store + ): + """Standard (non-experiment) request still works.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.STANDARD_PAYLOAD, + ) + assert resp.status_code == 201, resp.text + data = resp.json() + assert data["name"] == "Standard Plugin Automation" + + async def test_standard_tarball_has_plugins_config(self, client, mock_file_store): + """Standard tarball uses plugins_config.json, no experiment_config.json.""" + resp = await client.post( + "/api/automation/v1/preset/plugin", + json=self.STANDARD_PAYLOAD, + ) + assert resp.status_code == 201 + + files = _extract_tarball(mock_file_store) + assert "plugins_config.json" in files + assert "experiment_config.json" not in files + + config = json.loads(files["plugins_config.json"]) + assert len(config) == 1 + assert config[0]["source"] == "github:owner/plugin" + assert config[0]["ref"] == "v1.0.0" + + +class TestABTestValidationViaAPI: + """Validation errors return proper 422 responses through the API.""" + + async def test_both_plugins_and_variants_rejected(self, client): + resp = await client.post( + "/api/automation/v1/preset/plugin", + json={ + "name": "Bad", + "plugins": [{"source": "github:o/r"}], + "variants": [ + {"name": "a", "weight": 1, "plugins": [{"source": "github:o/r"}]}, + {"name": "b", "weight": 1, "plugins": [{"source": "github:o/r"}]}, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + }, + ) + assert resp.status_code == 422 + + async def test_missing_experiment_id_rejected(self, client): + resp = await client.post( + "/api/automation/v1/preset/plugin", + json={ + "name": "Bad", + "variants": [ + {"name": "a", "weight": 1, "plugins": [{"source": "github:o/r"}]}, + {"name": "b", "weight": 1, "plugins": [{"source": "github:o/r"}]}, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + }, + ) + assert resp.status_code == 422 + + async def test_single_variant_rejected(self, client): + resp = await client.post( + "/api/automation/v1/preset/plugin", + json={ + "name": "Bad", + "experiment_id": "test", + "variants": [ + { + "name": "only", + "weight": 1, + "plugins": [{"source": "github:o/r"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + }, + ) + assert resp.status_code == 422 + + async def test_duplicate_names_rejected(self, client): + resp = await client.post( + "/api/automation/v1/preset/plugin", + json={ + "name": "Bad", + "experiment_id": "test", + "variants": [ + { + "name": "same", + "weight": 1, + "plugins": [{"source": "github:o/r"}], + }, + { + "name": "same", + "weight": 1, + "plugins": [{"source": "github:o/r"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + }, + ) + assert resp.status_code == 422 + + async def test_zero_weight_rejected(self, client): + resp = await client.post( + "/api/automation/v1/preset/plugin", + json={ + "name": "Bad", + "experiment_id": "test", + "variants": [ + {"name": "a", "weight": 0, "plugins": [{"source": "github:o/r"}]}, + {"name": "b", "weight": 1, "plugins": [{"source": "github:o/r"}]}, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + }, + ) + assert resp.status_code == 422 diff --git a/tests/test_preset_router.py b/tests/test_preset_router.py index 9a121ea..3bc88fc 100644 --- a/tests/test_preset_router.py +++ b/tests/test_preset_router.py @@ -907,6 +907,317 @@ def test_generate_plugin_tarball_with_repos(self): assert repos_config[1]["ref"] == "develop" +class TestExperimentVariantValidation: + """Tests for A/B test variant validation on CreatePluginAutomationRequest.""" + + def test_variants_accepted(self): + """Valid variants request is accepted.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + request = CreatePluginAutomationRequest.model_validate( + { + "name": "AB Test", + "experiment_id": "my-experiment", + "variants": [ + { + "name": "control", + "weight": 50, + "plugins": [{"source": "github:owner/repo", "ref": "v1"}], + }, + { + "name": "treatment", + "weight": 50, + "plugins": [{"source": "github:owner/repo", "ref": "v2"}], + }, + ], + "prompt": "Test prompt", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + assert request.variants is not None + assert len(request.variants) == 2 + assert request.plugins is None + assert request.experiment_id == "my-experiment" + + def test_plugins_and_variants_mutually_exclusive(self): + """Providing both plugins and variants raises error.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="Exactly one of"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Bad", + "plugins": [{"source": "github:owner/repo"}], + "variants": [ + { + "name": "a", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + { + "name": "b", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_neither_plugins_nor_variants_rejected(self): + """Providing neither plugins nor variants raises error.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="Exactly one of"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Bad", + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_experiment_id_required_with_variants(self): + """experiment_id is required when variants is used.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="experiment_id.*required"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "variants": [ + { + "name": "a", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + { + "name": "b", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_experiment_id_rejected_with_plugins(self): + """experiment_id cannot be used with plugins (only variants).""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="experiment_id.*can only be used with"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "plugins": [{"source": "github:owner/repo"}], + "experiment_id": "oops", + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_single_variant_rejected(self): + """At least two variants are required.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="At least two variants"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "experiment_id": "test", + "variants": [ + { + "name": "only", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_duplicate_variant_names_rejected(self): + """Variant names must be unique.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError, match="unique"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "experiment_id": "test", + "variants": [ + { + "name": "same", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + { + "name": "same", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_zero_weight_rejected(self): + """Variant weight must be positive.""" + from openhands.automation.preset_router import CreatePluginAutomationRequest + + with pytest.raises(ValueError): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "experiment_id": "test", + "variants": [ + { + "name": "a", + "weight": 0, + "plugins": [{"source": "github:owner/repo"}], + }, + { + "name": "b", + "weight": 1, + "plugins": [{"source": "github:owner/repo"}], + }, + ], + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + def test_max_variants_accepted(self): + """Exactly MAX_VARIANTS variants is accepted.""" + from openhands.automation.preset_router import ( + MAX_VARIANTS, + CreatePluginAutomationRequest, + ) + + variants = [ + {"name": f"v{i}", "weight": 1, "plugins": [{"source": "github:owner/repo"}]} + for i in range(MAX_VARIANTS) + ] + req = CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "experiment_id": "boundary-test", + "variants": variants, + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + assert req.variants is not None and len(req.variants) == MAX_VARIANTS + + def test_too_many_variants_rejected(self): + """More than MAX_VARIANTS is rejected.""" + from openhands.automation.preset_router import ( + MAX_VARIANTS, + CreatePluginAutomationRequest, + ) + + variants = [ + {"name": f"v{i}", "weight": 1, "plugins": [{"source": "github:owner/repo"}]} + for i in range(MAX_VARIANTS + 1) + ] + with pytest.raises(ValueError, match="At most"): + CreatePluginAutomationRequest.model_validate( + { + "name": "Test", + "experiment_id": "test", + "variants": variants, + "prompt": "Test", + "trigger": {"type": "cron", "schedule": "0 0 * * *"}, + } + ) + + +class TestExperimentTarball: + """Tests for experiment (A/B) tarball generation.""" + + def test_experiment_tarball_contains_experiment_config(self): + """Experiment tarball has experiment_config.json, not plugins_config.json.""" + from openhands.automation.preset_router import ExperimentVariant + + variants = [ + ExperimentVariant( + name="control", + weight=50, + plugins=[PluginSource(source="github:owner/repo", ref="v1")], + ), + ExperimentVariant( + name="treatment", + weight=50, + plugins=[PluginSource(source="github:owner/repo", ref="v2")], + ), + ] + tarball_bytes = _generate_plugin_tarball( + None, + "Test prompt", + experiment_id="test-exp", + variants=variants, + ) + + with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar: + names = tar.getnames() + assert "experiment_config.json" in names + assert "plugins_config.json" not in names + assert "main.py" in names + assert "prompt.txt" in names + assert "setup.sh" in names + + def test_experiment_config_content(self): + """experiment_config.json has correct structure.""" + from openhands.automation.preset_router import ExperimentVariant + + variants = [ + ExperimentVariant( + name="control", + weight=70, + plugins=[PluginSource(source="github:owner/repo", ref="v1")], + ), + ExperimentVariant( + name="treatment", + weight=30, + plugins=[PluginSource(source="github:owner/repo", ref="v2")], + ), + ] + tarball_bytes = _generate_plugin_tarball( + None, + "Test prompt", + experiment_id="my-exp", + variants=variants, + ) + + with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar: + config_file = tar.extractfile("experiment_config.json") + assert config_file is not None + config = json.loads(config_file.read().decode("utf-8")) + + assert config["experiment_id"] == "my-exp" + assert len(config["variants"]) == 2 + assert config["variants"][0]["name"] == "control" + assert config["variants"][0]["weight"] == 70 + assert config["variants"][0]["plugins"][0]["source"] == "github:owner/repo" + assert config["variants"][0]["plugins"][0]["ref"] == "v1" + assert config["variants"][1]["name"] == "treatment" + assert config["variants"][1]["weight"] == 30 + + def test_standard_tarball_unchanged(self): + """Non-experiment tarball still produces plugins_config.json.""" + plugins = [PluginSource(source="github:owner/repo", ref="main")] + tarball_bytes = _generate_plugin_tarball(plugins, "Test prompt") + + with tarfile.open(fileobj=io.BytesIO(tarball_bytes), mode="r:gz") as tar: + names = tar.getnames() + assert "plugins_config.json" in names + assert "experiment_config.json" not in names + + @requires_docker class TestCreateAutomationFromPlugin: """Tests for POST /v1/preset/plugin endpoint."""