async def resolve_compose(
    compose_path: Path,
    *,
    timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    resolver: RockRegistryResolver | None = None,
) -> bool:
    """Rewrite ``image:`` of every service in a compose file to ROCK mirrors when available.

    Only ``image:`` fields are rewritten; ``build:`` sections are left untouched.
    Images that use compose variable interpolation (contain ``$``) are skipped,
    because the value compose will substitute is not known here.

    When at least one image changes, the whole document is re-serialised with
    ``yaml.dump`` and written back to *compose_path*; YAML comments and the
    original formatting are not preserved in that case.

    Args:
        compose_path: Compose file to inspect and possibly rewrite in place.
        timeout_sec: Timeout forwarded to ``resolver.resolve_image``.
        resolver: Resolver to use; the module-level default when ``None``.

    Returns:
        True if any service image was rewritten, False otherwise (including
        when the file cannot be read or parsed, which is logged and ignored).
    """
    active = resolver if resolver is not None else _default_resolver

    try:
        data = yaml.safe_load(compose_path.read_text())
    except Exception:
        # Best-effort by design: an unreadable compose file must not block the caller.
        logger.warning("Failed to parse compose file %s, skipping regionless rewrite", compose_path, exc_info=True)
        return False

    services = data.get("services") if isinstance(data, dict) else None
    if not isinstance(services, dict):
        return False

    changed = False
    for svc_config in services.values():
        if not isinstance(svc_config, dict):
            continue
        image = svc_config.get("image")
        if not isinstance(image, str) or not image:
            continue
        if "$" in image:
            # ``${VAR}`` / ``$VAR`` is substituted by compose at run time; probing or
            # rewriting the literal text would be meaningless.
            continue
        resolved = await active.resolve_image(image, timeout_sec=timeout_sec)
        if resolved != image:
            svc_config["image"] = resolved
            changed = True

    if changed:
        compose_path.write_text(yaml.dump(data, default_flow_style=False, sort_keys=False))

    return changed


async def compose_pull(
    compose_path: Path,
    *,
    services: list[str] | None = None,
    project_name: str | None = None,
    env: dict[str, str] | None = None,
    timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    extra_args: list[str] | None = None,
    resolver: RockRegistryResolver | None = None,
) -> subprocess.CompletedProcess:
    """Resolve images in *compose_path* to ROCK mirrors, then ``docker compose pull``.

    1. ``resolve_compose(compose_path)`` — rewrite service images to regional mirrors
    2. ``docker compose -f <compose_path> [-p <project_name>] pull [extra_args...] [services...]``

    Resolve failures are non-blocking (fail-safe). Pull failures raise RuntimeError.

    Args:
        compose_path: Compose file to resolve and pull.
        services: Optional service names appended to the pull command.
        project_name: Optional compose project name (``-p``).
        env: Extra environment variables layered over ``os.environ`` for the child.
        timeout_sec: Timeout forwarded to ``resolve_compose``; it does not bound the pull.
        extra_args: Extra arguments inserted after ``pull`` and before the services.
        resolver: Resolver forwarded to ``resolve_compose``.

    Returns:
        A ``subprocess.CompletedProcess`` with decoded ``stdout``/``stderr``.

    Raises:
        RuntimeError: If ``docker compose pull`` exits with a non-zero status.
    """
    try:
        await resolve_compose(compose_path, timeout_sec=timeout_sec, resolver=resolver)
    except Exception:
        logger.warning("resolve_compose failed for %s, proceeding with original images", compose_path, exc_info=True)

    cmd = ["docker", "compose", "-f", str(compose_path)]
    if project_name:
        cmd.extend(["-p", project_name])
    cmd.append("pull")
    if extra_args:
        cmd.extend(extra_args)
    if services:
        cmd.extend(services)

    run_env = dict(os.environ)
    if env:
        run_env.update(env)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        env=run_env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout_bytes, stderr_bytes = await proc.communicate()
    except asyncio.CancelledError:
        # Cancelling the await does not stop the child; kill it so a cancelled
        # caller does not leave an orphaned ``docker compose pull`` behind.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # the child exited between the check and the kill
            await proc.wait()
        raise

    result = subprocess.CompletedProcess(
        args=cmd,
        returncode=proc.returncode,
        stdout=stdout_bytes.decode(errors="replace") if stdout_bytes else "",
        stderr=stderr_bytes.decode(errors="replace") if stderr_bytes else "",
    )

    if result.returncode != 0:
        raise RuntimeError(
            f"docker compose pull failed (exit {result.returncode}):\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )

    return result
ROCK_REGISTRY_ENV = "INSTANCE_ROCK_REGISTRY"
_DEFAULT_PROBE_TIMEOUT_SEC = 5.0
_REGISTRY_SEPARATORS = (",", ";")

# ``FROM [--flag=value ...] <image>[ <rest of line>]``; the three groups are
# re-joined verbatim when a line is rewritten.
_FROM_RE = re.compile(
    r"^(?P<prefix>\s*FROM\s+(?:--\S+\s+)*)(?P<image>\S+)(?P<suffix>.*)$",
    re.IGNORECASE,
)
# `` AS <stage>`` at the start of a FROM line's suffix.
_STAGE_ALIAS_RE = re.compile(r"^\s+AS\s+(?P<name>\S+)", re.IGNORECASE)

_MANIFEST_ACCEPT = ", ".join(
    [
        "application/vnd.docker.distribution.manifest.v2+json",
        "application/vnd.oci.image.manifest.v1+json",
        "application/vnd.docker.distribution.manifest.list.v2+json",
        "application/vnd.oci.image.index.v1+json",
    ]
)


class RockRegistryResolver:
    """Resolves container image references to ROCK mirror registries.

    The candidate registries are read from the ``INSTANCE_ROCK_REGISTRY``
    environment variable on every call. Results (hits and misses alike) are
    cached per instance, keyed by the registry list and the image.
    """

    def __init__(self) -> None:
        # cache key -> resolved reference (the original image when no mirror matched)
        self._resolve_cache: dict[str, str] = {}
        self._cache_lock = asyncio.Lock()

    @staticmethod
    def parse_registries(raw: str) -> list[str]:
        """Split the env value into an ordered list of non-empty registry entries.

        Entries are separated by ``,`` or ``;``; surrounding whitespace and
        trailing slashes are removed, and entries that end up empty are dropped.
        """
        if not raw:
            return []
        tokens = [raw]
        for sep in _REGISTRY_SEPARATORS:
            tokens = [piece for token in tokens for piece in token.split(sep)]
        cleaned = (token.strip().rstrip("/") for token in tokens)
        # Filter after normalising so a token such as "/" cannot yield "".
        return [entry for entry in cleaned if entry]

    @staticmethod
    def split_tag_or_digest(image: str) -> tuple[str, str]:
        """Split image into (path, tag-or-digest-suffix).

        Examples:
            "foo/bar:1.2"        -> ("foo/bar", ":1.2")
            "foo/bar@sha256:abc" -> ("foo/bar", "@sha256:abc")
            "foo/bar"            -> ("foo/bar", ":latest")
        """
        if "@" in image:
            path, _, digest = image.partition("@")
            return path, f"@{digest}"
        last_slash = image.rfind("/")
        last_colon = image.rfind(":")
        # A colon before the last slash belongs to a ``host:port`` prefix, not a tag.
        if last_colon > last_slash:
            return image[:last_colon], f":{image[last_colon + 1 :]}"
        return image, ":latest"

    @staticmethod
    def build_candidate(image: str, registry: str) -> str:
        """Build the candidate image reference under the ROCK registry.

        Drops up to the first two ``/``-separated components of the image path
        (for a fully-qualified reference: the registry host and the first-level
        namespace) and keeps the rest plus the tag/digest.
        Example: ``ghcr.io/foo/bar/baz:v1`` with registry ``reg/ns`` →
        ``reg/ns/bar/baz:v1``.
        """
        path, suffix = RockRegistryResolver.split_tag_or_digest(image)
        # maxsplit=2 leaves everything after the second slash intact in the last item.
        remainder = path.split("/", 2)[-1]
        return f"{registry.rstrip('/')}/{remainder}{suffix}"

    @staticmethod
    def _parse_bearer_challenge(header: str) -> dict[str, str]:
        """Parse ``key="value"`` pairs (``realm``, ``service``, ``scope``…) from a WWW-Authenticate header."""
        return {m.group(1): m.group(2) for m in re.finditer(r'(\w+)="([^"]*)"', header)}

    @staticmethod
    def _parse_image_parts(image: str) -> tuple[str, str, str]:
        """Extract (registry_host, repo_path, reference) from a fully-qualified image reference.

        ``reference`` is the tag or digest without its leading ``:`` / ``@``.
        When the path has no ``/`` the whole path is returned as the host and
        the repo is empty.
        """
        path, suffix = RockRegistryResolver.split_tag_or_digest(image)
        # The suffix always starts with exactly one separator (":" or "@").
        reference = suffix[1:]
        host, slash, repo = path.partition("/")
        if not slash:
            return (path, "", reference)
        return (host, repo, reference)

    async def _fetch_bearer_token(self, client, challenge: str) -> str | None:
        """Obtain a bearer token for *challenge* using *client*, or None when not possible."""
        scheme, _, _ = challenge.partition(" ")
        # Auth scheme names are case-insensitive.
        if scheme.lower() != "bearer":
            return None
        params = self._parse_bearer_challenge(challenge)
        realm = params.get("realm", "")
        if not realm:
            # Without a realm there is no token endpoint to call.
            return None
        service = params.get("service", "")
        scope = params.get("scope", "")
        token_url = f"{realm}?{urlencode({'service': service, 'scope': scope})}"
        token_resp = await client.get(token_url)
        if token_resp.status_code != 200:
            return None
        data = token_resp.json()
        return data.get("token") or data.get("access_token") or None

    async def _http_probe_manifest(self, image: str, timeout_sec: float) -> bool:
        """Check whether *image* exists on its registry via the v2 manifest API.

        Issues ``GET https://{registry}/v2/{repo}/manifests/{reference}`` and,
        on a 401 with a Bearer challenge, retries once with a fetched token.
        Returns True only for a final 200 response; every failure is logged
        and reported as False.
        """
        registry, repo, reference = self._parse_image_parts(image)
        if not registry or not repo:
            return False

        url = f"https://{registry}/v2/{repo}/manifests/{reference}"
        headers = {"Accept": _MANIFEST_ACCEPT}

        try:
            async with httpx.AsyncClient(timeout=timeout_sec) as client:
                resp = await client.get(url, headers=headers)

                if resp.status_code == 401:
                    token = await self._fetch_bearer_token(client, resp.headers.get("www-authenticate", ""))
                    if token:
                        headers["Authorization"] = f"Bearer {token}"
                        resp = await client.get(url, headers=headers)

                return resp.status_code == 200
        except httpx.HTTPError:
            logger.debug("HTTP probe for %s failed (network/protocol)", image, exc_info=True)
            return False
        except (ValueError, KeyError):
            logger.debug("HTTP probe for %s failed (response parsing)", image, exc_info=True)
            return False
        except Exception:
            logger.warning("HTTP probe for %s failed (unexpected)", image, exc_info=True)
            return False

    async def resolve_image(
        self,
        image: str,
        *,
        timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> str:
        """Return a ROCK-mirrored image reference if available, else ``image``.

        Empty and digest-pinned references are returned unchanged without
        probing. Registries are tried in configured order and the first
        candidate whose probe succeeds wins. The outcome is cached.
        """
        if not image or "@" in image:
            return image

        registries = self.parse_registries(os.environ.get(ROCK_REGISTRY_ENV, ""))
        if not registries:
            return image

        cache_key = f"{'|'.join(registries)}||{image}"
        async with self._cache_lock:
            cached = self._resolve_cache.get(cache_key)
        if cached is not None:
            return cached

        resolved = image
        for registry in registries:
            candidate = self.build_candidate(image, registry)
            if candidate == image:
                continue
            if await self._http_probe_manifest(candidate, timeout_sec=timeout_sec):
                logger.info("Rewriting image %s -> %s (ROCK mirror)", image, candidate)
                resolved = candidate
                break

        async with self._cache_lock:
            self._resolve_cache[cache_key] = resolved
        return resolved

    async def resolve_dockerfile(
        self,
        dockerfile: Path,
        *,
        timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> bool:
        """Rewrite ``FROM`` images in *dockerfile* to ROCK mirrors when available.

        ``FROM`` targets that are not registry images are left alone: the
        reserved ``scratch`` base, references containing ``$`` (build-arg
        substitution), and names of stages declared by an earlier
        ``FROM ... AS <name>`` line.

        Returns True if any ``FROM`` line was rewritten (the file is written
        back only in that case).
        """
        lines = dockerfile.read_text().splitlines(keepends=True)
        changed = False
        stage_names: set[str] = set()

        for i, line in enumerate(lines):
            content = line.rstrip("\n\r")
            m = _FROM_RE.match(content)
            if not m:
                continue
            original = m.group("image")
            lowered = original.lower()
            # Decide before registering this line's own alias: a stage cannot
            # be based on itself.
            is_not_an_image = lowered == "scratch" or "$" in original or lowered in stage_names
            alias = _STAGE_ALIAS_RE.match(m.group("suffix"))
            if alias:
                stage_names.add(alias.group("name").lower())
            if is_not_an_image:
                continue
            resolved = await self.resolve_image(original, timeout_sec=timeout_sec)
            if resolved != original:
                eol = line[len(content) :]
                lines[i] = f"{m.group('prefix')}{resolved}{m.group('suffix')}{eol}"
                changed = True

        if changed:
            dockerfile.write_text("".join(lines))
        return changed

    def reset_cache(self) -> None:
        """Clear the in-process resolve cache. Test helper."""
        self._resolve_cache.clear()


RegionlessResolver = RockRegistryResolver

_default_resolver = RockRegistryResolver()

resolve_image = _default_resolver.resolve_image
resolve_dockerfile = _default_resolver.resolve_dockerfile
monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns") + cf = _write_compose( + tmp_path / "docker-compose.yml", + { + "web": {"image": "ghcr.io/org/app:v1"}, + }, + ) + original = cf.read_text() + with patch.object( + RockRegistryResolver, + "_http_probe_manifest", + new=AsyncMock(return_value=False), + ): + changed = await resolve_compose(cf, resolver=resolver) + assert not changed + assert cf.read_text() == original + + async def test_no_env_noop(self, tmp_path, resolver, monkeypatch): + monkeypatch.delenv(ROCK_REGISTRY_ENV, raising=False) + cf = _write_compose( + tmp_path / "docker-compose.yml", + { + "web": {"image": "ghcr.io/org/app:v1"}, + }, + ) + original = cf.read_text() + changed = await resolve_compose(cf, resolver=resolver) + assert not changed + assert cf.read_text() == original + + async def test_dedupes_probe(self, tmp_path, resolver, monkeypatch): + monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns") + cf = _write_compose( + tmp_path / "docker-compose.yml", + { + "svc1": {"image": "ghcr.io/org/app:v1"}, + "svc2": {"image": "ghcr.io/org/app:v1"}, + }, + ) + probe = AsyncMock(return_value=True) + with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe): + changed = await resolve_compose(cf, resolver=resolver) + assert changed + assert probe.await_count == 1 + + async def test_ignores_build_section(self, tmp_path, resolver, monkeypatch): + monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns") + cf = _write_compose( + tmp_path / "docker-compose.yml", + { + "builder": {"build": {"context": ".", "dockerfile": "Dockerfile"}}, + "web": {"image": "ghcr.io/org/app:v1"}, + }, + ) + probe = AsyncMock(return_value=True) + with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe): + changed = await resolve_compose(cf, resolver=resolver) + assert changed + data = yaml.safe_load(cf.read_text()) + assert "build" in data["services"]["builder"] + assert "image" not in data["services"]["builder"] + assert 
class TestComposePull:
    """``compose_pull`` resolves images first, then shells out to ``docker compose pull``."""

    async def test_calls_resolve_then_pull(self, tmp_path, resolver, monkeypatch):
        """A mirror hit rewrites the compose file before the pull command runs."""
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        cf = _write_compose(
            tmp_path / "docker-compose.yml",
            {
                "web": {"image": "ghcr.io/org/app:v1"},
            },
        )

        with (
            patch.object(
                RockRegistryResolver,
                "_http_probe_manifest",
                new=AsyncMock(return_value=True),
            ),
            patch("rock.sdk.envhub.regionless.compose.asyncio.create_subprocess_exec") as mock_exec,
        ):
            proc_mock = AsyncMock()
            proc_mock.returncode = 0
            proc_mock.communicate.return_value = (b"Done\n", b"")
            mock_exec.return_value = proc_mock

            result = await compose_pull(cf, resolver=resolver)

        assert result.returncode == 0
        assert result.stdout == "Done\n"
        # The "resolve" half: the file on disk now points at the mirror.
        data = yaml.safe_load(cf.read_text())
        assert data["services"]["web"]["image"] == "reg.example.com/ns/app:v1"
        # The "pull" half: exactly one invocation with the expected argv.
        assert mock_exec.call_count == 1
        cmd = mock_exec.call_args[0]
        assert cmd == ("docker", "compose", "-f", str(cf), "pull")

    async def test_propagates_pull_failure(self, tmp_path, resolver, monkeypatch):
        """A non-zero exit raises RuntimeError carrying the child's stderr."""
        monkeypatch.delenv(ROCK_REGISTRY_ENV, raising=False)
        cf = _write_compose(
            tmp_path / "docker-compose.yml",
            {
                "web": {"image": "ghcr.io/org/app:v1"},
            },
        )

        with patch("rock.sdk.envhub.regionless.compose.asyncio.create_subprocess_exec") as mock_exec:
            proc_mock = AsyncMock()
            proc_mock.returncode = 1
            proc_mock.communicate.return_value = (b"", b"Error: pull access denied\n")
            mock_exec.return_value = proc_mock

            with pytest.raises(RuntimeError, match="docker compose pull failed") as excinfo:
                await compose_pull(cf, resolver=resolver)

        assert "pull access denied" in str(excinfo.value)

    async def test_resolve_failure_is_non_blocking(self, tmp_path, resolver, monkeypatch):
        """An exception from ``resolve_compose`` is swallowed and the pull still runs."""
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        cf = _write_compose(
            tmp_path / "docker-compose.yml",
            {
                "web": {"image": "ghcr.io/org/app:v1"},
            },
        )
        original = cf.read_text()

        with (
            patch(
                "rock.sdk.envhub.regionless.compose.resolve_compose",
                new=AsyncMock(side_effect=Exception("resolve boom")),
            ),
            patch("rock.sdk.envhub.regionless.compose.asyncio.create_subprocess_exec") as mock_exec,
        ):
            proc_mock = AsyncMock()
            proc_mock.returncode = 0
            proc_mock.communicate.return_value = (b"Done\n", b"")
            mock_exec.return_value = proc_mock

            result = await compose_pull(cf, resolver=resolver)

        assert result.returncode == 0
        assert mock_exec.call_count == 1
        # Resolution failed, so the compose file must be untouched.
        assert cf.read_text() == original
class TestResolveImage:
    """``RockRegistryResolver.resolve_image`` with the manifest probe patched out."""

    async def test_no_env_returns_original(self, resolver, monkeypatch):
        monkeypatch.delenv(ROCK_REGISTRY_ENV, raising=False)
        result = await resolver.resolve_image("foo/bar:1")
        assert result == "foo/bar:1"

    async def test_empty_env_returns_original(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "  ")
        result = await resolver.resolve_image("foo/bar:1")
        assert result == "foo/bar:1"

    async def test_empty_image_returns_unchanged(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        assert await resolver.resolve_image("") == ""

    async def test_digest_pinned_skips_rewrite_and_probe(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        probe = AsyncMock(return_value=True)
        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe):
            result = await resolver.resolve_image("swebench/foo@sha256:abc123")
        assert result == "swebench/foo@sha256:abc123"
        probe.assert_not_awaited()

    async def test_hit_returns_candidate(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        with patch.object(
            RockRegistryResolver,
            "_http_probe_manifest",
            new=AsyncMock(return_value=True),
        ) as probe:
            result = await resolver.resolve_image("swebench/foo:latest")
        assert result == "reg.example.com/ns/foo:latest"
        probe.assert_awaited_once()

    async def test_miss_returns_original(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        with patch.object(
            RockRegistryResolver,
            "_http_probe_manifest",
            new=AsyncMock(return_value=False),
        ):
            result = await resolver.resolve_image("swebench/foo:latest")
        assert result == "swebench/foo:latest"

    async def test_cache_skips_second_probe(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        probe = AsyncMock(return_value=True)
        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe):
            await resolver.resolve_image("swebench/foo:latest")
            await resolver.resolve_image("swebench/foo:latest")
        assert probe.await_count == 1

    async def test_cache_keyed_by_registry(self, resolver, monkeypatch):
        probe = AsyncMock(return_value=True)
        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe):
            monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg-a.example.com/ns")
            await resolver.resolve_image("swebench/foo:latest")
            monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg-b.example.com/ns")
            await resolver.resolve_image("swebench/foo:latest")
        assert probe.await_count == 2

    async def test_probe_timeout_falls_back(self, resolver, monkeypatch):
        """A failed probe falls back to the original, and the caller's timeout reaches the probe."""
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns")
        probe = AsyncMock(return_value=False)
        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=probe):
            result = await resolver.resolve_image("swebench/foo:latest", timeout_sec=0.05)
        assert result == "swebench/foo:latest"
        # Without this check the test is indistinguishable from test_miss_returns_original.
        probe.assert_awaited_once_with("reg.example.com/ns/foo:latest", timeout_sec=0.05)

    async def test_multi_registry_first_hit_wins(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg-a.example.com/ns,reg-b.example.com/ns")
        calls: list[str] = []

        async def _probe(self, candidate, *, timeout_sec):
            calls.append(candidate)
            return candidate.startswith("reg-a.example.com/")

        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=_probe):
            result = await resolver.resolve_image("swebench/foo:latest")
        assert result == "reg-a.example.com/ns/foo:latest"
        assert calls == ["reg-a.example.com/ns/foo:latest"]

    async def test_multi_registry_falls_through_to_second(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg-a.example.com/ns,reg-b.example.com/ns")
        calls: list[str] = []

        async def _probe(self, candidate, *, timeout_sec):
            calls.append(candidate)
            return candidate.startswith("reg-b.example.com/")

        with patch.object(RockRegistryResolver, "_http_probe_manifest", new=_probe):
            result = await resolver.resolve_image("swebench/foo:latest")
        assert result == "reg-b.example.com/ns/foo:latest"
        assert calls == [
            "reg-a.example.com/ns/foo:latest",
            "reg-b.example.com/ns/foo:latest",
        ]

    async def test_multi_registry_all_miss_returns_original(self, resolver, monkeypatch):
        monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg-a.example.com/ns;reg-b.example.com/ns")
        with patch.object(
            RockRegistryResolver,
            "_http_probe_manifest",
            new=AsyncMock(return_value=False),
        ) as probe:
            result = await resolver.resolve_image("swebench/foo:latest")
        assert result == "swebench/foo:latest"
        assert probe.await_count == 2
new=AsyncMock(return_value=True), + ): + changed = await resolver.resolve_dockerfile(df) + assert changed + assert "FROM reg.example.com/ns/base:v1 AS builder\n" in df.read_text() + + async def test_handles_platform_flag(self, tmp_path, resolver, monkeypatch): + monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns") + df = tmp_path / "Dockerfile" + df.write_text("FROM --platform=linux/amd64 ghcr.io/laude-institute/t-bench/deveval:latest\nRUN echo hello\n") + with patch.object( + RockRegistryResolver, + "_http_probe_manifest", + new=AsyncMock(return_value=True), + ): + changed = await resolver.resolve_dockerfile(df) + assert changed + assert "FROM --platform=linux/amd64 reg.example.com/ns/t-bench/deveval:latest\n" in df.read_text() + + async def test_skips_comment_lines(self, tmp_path, resolver, monkeypatch): + monkeypatch.setenv(ROCK_REGISTRY_ENV, "reg.example.com/ns") + df = tmp_path / "Dockerfile" + original = "# FROM old-registry.com/repo/base:v1\nFROM python:3.11\n" + df.write_text(original) + with patch.object( + RockRegistryResolver, + "_http_probe_manifest", + new=AsyncMock(return_value=False), + ): + changed = await resolver.resolve_dockerfile(df) + assert not changed