Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions rock/sdk/envhub/regionless/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from rock.sdk.envhub.regionless.compose import compose_pull, resolve_compose
from rock.sdk.envhub.regionless.resolver import (
ROCK_REGISTRY_ENV,
RegionlessResolver,
RockRegistryResolver,
resolve_dockerfile,
resolve_image,
)

# Public API of the package: exactly the names imported above from the
# ``compose`` and ``resolver`` submodules.
__all__ = [
    "ROCK_REGISTRY_ENV",
    "RegionlessResolver",
    "RockRegistryResolver",
    "compose_pull",
    "resolve_compose",
    "resolve_dockerfile",
    "resolve_image",
]
120 changes: 120 additions & 0 deletions rock/sdk/envhub/regionless/compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Compose-level regionless support: resolve images in compose files and pull."""

from __future__ import annotations

import asyncio
import os
import subprocess
from pathlib import Path

import yaml

from rock.logger import init_logger
from rock.sdk.envhub.regionless.resolver import _DEFAULT_PROBE_TIMEOUT_SEC, RockRegistryResolver

# Module-level logger for compose-related messages.
logger = init_logger(__name__)

# Fallback resolver used by the functions in this module when the caller does
# not pass ``resolver=``.
# NOTE(review): this is a fresh RockRegistryResolver instance, so its probe
# cache is not shared with any resolver instance created elsewhere.
_default_resolver = RockRegistryResolver()


async def resolve_compose(
    compose_path: Path,
    *,
    timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    resolver: RockRegistryResolver | None = None,
) -> bool:
    """Point each service ``image:`` in a compose file at a ROCK mirror when one exists.

    The file is parsed as YAML; every service mapping that has a non-empty
    string ``image`` is passed through the resolver, and the file is written
    back (re-serialised with ``yaml.dump``) only if at least one image
    changed. ``build:`` sections are never touched.

    Args:
        compose_path: Path of the compose file to inspect and possibly rewrite.
        timeout_sec: Timeout forwarded to the resolver for each image.
        resolver: Resolver to use; the module-level default when omitted.

    Returns:
        True if at least one service image was rewritten, False otherwise
        (including when the file cannot be read/parsed or has no ``services``
        mapping).
    """
    active_resolver = resolver or _default_resolver

    try:
        document = yaml.safe_load(compose_path.read_text())
    except Exception:
        # Best effort: an unreadable or invalid compose file is left as-is.
        logger.warning("Failed to parse compose file %s, skipping regionless rewrite", compose_path, exc_info=True)
        return False

    service_map = document.get("services") if isinstance(document, dict) else None
    if not isinstance(service_map, dict):
        return False

    rewrites = 0
    for service in service_map.values():
        if not isinstance(service, dict):
            continue
        current = service.get("image")
        if not (isinstance(current, str) and current):
            continue
        mirrored = await active_resolver.resolve_image(current, timeout_sec=timeout_sec)
        if mirrored == current:
            continue
        service["image"] = mirrored
        rewrites += 1

    if not rewrites:
        return False

    compose_path.write_text(yaml.dump(document, default_flow_style=False, sort_keys=False))
    return True


async def compose_pull(
    compose_path: Path,
    *,
    services: list[str] | None = None,
    project_name: str | None = None,
    env: dict[str, str] | None = None,
    timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    extra_args: list[str] | None = None,
    resolver: RockRegistryResolver | None = None,
) -> subprocess.CompletedProcess:
    """Resolve images in *compose_path* to ROCK mirrors, then ``docker compose pull``.

    1. resolve_compose(compose_path) — rewrite service images to regional mirrors
    2. docker compose -f <compose_path> [-p <project>] pull [extra_args...] [services...]

    Args:
        compose_path: Compose file to rewrite (best effort) and pull from.
        services: Optional service names appended to the pull command; all
            services are pulled when omitted.
        project_name: Optional compose project name, passed via ``-p``.
        env: Extra environment variables layered on top of ``os.environ``
            for the docker subprocess.
        timeout_sec: Timeout forwarded to the image resolver; it does not
            bound the pull itself.
        extra_args: Extra arguments inserted after ``pull`` and before the
            service names.
        resolver: Resolver forwarded to ``resolve_compose``.

    Returns:
        A ``subprocess.CompletedProcess`` with decoded ``stdout``/``stderr``.

    Raises:
        RuntimeError: If ``docker compose pull`` exits with a non-zero status.
        Errors raised while launching the subprocess propagate unchanged.

    Resolve failures are non-blocking (fail-safe).
    """
    try:
        await resolve_compose(compose_path, timeout_sec=timeout_sec, resolver=resolver)
    except Exception:
        logger.warning("resolve_compose failed for %s, proceeding with original images", compose_path, exc_info=True)

    cmd = ["docker", "compose", "-f", str(compose_path)]
    if project_name:
        cmd.extend(["-p", project_name])
    cmd.append("pull")
    if extra_args:
        cmd.extend(extra_args)
    if services:
        cmd.extend(services)

    # Caller-supplied variables override the inherited environment.
    run_env = dict(os.environ)
    if env:
        run_env.update(env)

    proc = await asyncio.create_subprocess_exec(
        *cmd,
        env=run_env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    try:
        stdout_bytes, stderr_bytes = await proc.communicate()
    except BaseException:
        # Cancellation (or any other interruption) must not leave an orphaned
        # ``docker compose pull`` running in the background: stop and reap it.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                # The process exited between the check and the kill.
                pass
            await proc.wait()
        raise

    result = subprocess.CompletedProcess(
        args=cmd,
        returncode=proc.returncode,
        stdout=stdout_bytes.decode(errors="replace") if stdout_bytes else "",
        stderr=stderr_bytes.decode(errors="replace") if stderr_bytes else "",
    )

    if result.returncode != 0:
        raise RuntimeError(
            f"docker compose pull failed (exit {result.returncode}):\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )

    return result
238 changes: 238 additions & 0 deletions rock/sdk/envhub/regionless/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
"""Rewrite container image references to a ROCK mirror registry when available.

Resolution rule:

- Read ``INSTANCE_ROCK_REGISTRY`` (one or more ``host/namespace`` entries
separated by ``,`` or ``;``, e.g.
``reg-a.aliyuncs.com/mirror-1,reg-b.aliyuncs.com/mirror-2``). If unset/empty,
return the original image unchanged.
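- Drop the leading registry/namespace part of the image reference (at most the
  first two path segments), keep the remainder plus the original tag
  (``:latest`` when no tag is given), and prepend the mirror registry:
  ``swebench/sweb.eval.x86_64.foo:latest`` →
  ``<registry>/sweb.eval.x86_64.foo:latest``.
  Digest-pinned references (``name@sha256:…``) are returned unchanged.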
- For each configured registry in order, probe the candidate via the Docker
Registry v2 manifest API (``GET /v2/{repo}/manifests/{tag}``), with Bearer
token authentication support and a short timeout. Return the *first*
candidate that exists; if none exist (or the probes time out / fail),
fall back to the original image. Probe results are cached in-process so
concurrent trials do not hammer the registry.
"""

from __future__ import annotations

import asyncio
import os
import re
from pathlib import Path
from urllib.parse import urlencode

import httpx

from rock.logger import init_logger

# Module-level logger for resolver messages.
logger = init_logger(__name__)

# Name of the environment variable that lists the mirror registries.
ROCK_REGISTRY_ENV = "INSTANCE_ROCK_REGISTRY"
# Default timeout, in seconds, used for each manifest probe.
_DEFAULT_PROBE_TIMEOUT_SEC = 5.0
# Characters accepted as separators between registry entries in the env value.
_REGISTRY_SEPARATORS = (",", ";")

# Matches a Dockerfile ``FROM`` line (case-insensitive):
#   prefix - leading whitespace, the FROM keyword and any ``--flag`` options
#   image  - the image reference (first non-flag token)
#   suffix - everything after the image, e.g. `` AS builder``
_FROM_RE = re.compile(
    r"^(?P<prefix>\s*FROM\s+(?:--\S+\s+)*)(?P<image>\S+)(?P<suffix>.*)$",
    re.IGNORECASE,
)


class RockRegistryResolver:
    """Resolves container image references to ROCK mirror registries.

    The list of mirror registries is read from the ``INSTANCE_ROCK_REGISTRY``
    environment variable on every ``resolve_image`` call. Each instance keeps
    its own in-process cache of resolution results.
    """

    # Captures the stage name in the ``AS <name>`` tail of a ``FROM`` line.
    _STAGE_ALIAS_RE = re.compile(r"^\s+AS\s+(?P<alias>\S+)", re.IGNORECASE)

    def __init__(self) -> None:
        # Maps "<registries>||<image>" to the resolved image reference.
        self._resolve_cache: dict[str, str] = {}
        self._cache_lock = asyncio.Lock()

    @staticmethod
    def parse_registries(raw: str) -> list[str]:
        """Split the env value into an ordered list of non-empty registry entries.

        Entries are separated by any of ``_REGISTRY_SEPARATORS``; surrounding
        whitespace and trailing slashes are removed. Entries that are empty
        after this normalisation (e.g. a bare ``/``) are dropped.
        """
        if not raw:
            return []
        tokens = [raw]
        for sep in _REGISTRY_SEPARATORS:
            tokens = [piece for token in tokens for piece in token.split(sep)]
        normalized = (token.strip().rstrip("/") for token in tokens)
        return [entry for entry in normalized if entry]

    @staticmethod
    def split_tag_or_digest(image: str) -> tuple[str, str]:
        """Split image into (path, tag-or-digest-suffix).

        The suffix always carries its leading ``:`` or ``@``.

        Examples:
            "foo/bar:1.2"        -> ("foo/bar", ":1.2")
            "foo/bar@sha256:abc" -> ("foo/bar", "@sha256:abc")
            "foo/bar"            -> ("foo/bar", ":latest")
        """
        if "@" in image:
            path, _, digest = image.partition("@")
            return path, f"@{digest}"
        last_slash = image.rfind("/")
        last_colon = image.rfind(":")
        # A colon before the last slash belongs to a registry port, not a tag.
        if last_colon > last_slash:
            return image[:last_colon], f":{image[last_colon + 1 :]}"
        return image, ":latest"

    @staticmethod
    def build_candidate(image: str, registry: str) -> str:
        """Build the candidate image reference under the ROCK registry.

        Drops at most the first two ``/``-separated segments of the image
        path (typically the original registry and first-level namespace),
        preserving any remaining nested namespaces and the image name plus
        tag/digest.

        Examples (registry ``reg/ns``):
            ``ghcr.io/foo/bar/baz:v1`` → ``reg/ns/bar/baz:v1``
            ``swebench/img:latest``    → ``reg/ns/img:latest``
            ``ubuntu``                 → ``reg/ns/ubuntu:latest``
        """
        path, suffix = RockRegistryResolver.split_tag_or_digest(image)
        if "/" in path:
            _, path = path.split("/", 1)
        if "/" in path:
            _, path = path.split("/", 1)
        return f"{registry.rstrip('/')}/{path}{suffix}"

    @staticmethod
    def _parse_bearer_challenge(header: str) -> dict[str, str]:
        """Parse ``realm``, ``service``, ``scope`` from a Bearer WWW-Authenticate header.

        Returns every ``key="value"`` pair found in *header*.
        """
        return {m.group(1): m.group(2) for m in re.finditer(r'(\w+)="([^"]*)"', header)}

    @staticmethod
    def _parse_image_parts(image: str) -> tuple[str, str, str]:
        """Extract (registry_host, repo_path, reference) from a fully-qualified image reference.

        *reference* is the tag, or the digest (``sha256:…``) for digest-pinned
        images, without the leading ``:``/``@`` separator so it can be placed
        directly in a manifest URL. When the path has no ``/`` the whole path
        is returned as the registry host with an empty repo.
        """
        path, suffix = RockRegistryResolver.split_tag_or_digest(image)
        # The suffix always starts with exactly one separator (":" or "@").
        tag = suffix[1:]
        first_slash = path.find("/")
        if first_slash == -1:
            return (path, "", tag)
        registry = path[:first_slash]
        repo = path[first_slash + 1 :]
        return (registry, repo, tag)

    async def _http_probe_manifest(self, image: str, timeout_sec: float) -> bool:
        """Check whether *image* exists on its registry via the v2 manifest API.

        Issues ``GET https://<registry>/v2/<repo>/manifests/<reference>``. On a
        401 with a Bearer challenge, fetches a token from the advertised realm
        and retries once. Returns True only for a final 200 response; any
        error is logged and reported as False.
        """
        registry, repo, tag = self._parse_image_parts(image)
        if not registry or not repo:
            return False

        url = f"https://{registry}/v2/{repo}/manifests/{tag}"
        headers = {
            "Accept": ", ".join(
                [
                    "application/vnd.docker.distribution.manifest.v2+json",
                    "application/vnd.oci.image.manifest.v1+json",
                    "application/vnd.docker.distribution.manifest.list.v2+json",
                    "application/vnd.oci.image.index.v1+json",
                ]
            )
        }

        try:
            async with httpx.AsyncClient(timeout=timeout_sec) as client:
                resp = await client.get(url, headers=headers)

                if resp.status_code == 401 and "www-authenticate" in resp.headers:
                    www_auth = resp.headers["www-authenticate"]
                    # Authentication scheme names are case-insensitive.
                    if www_auth[:7].lower() == "bearer ":
                        params = self._parse_bearer_challenge(www_auth)
                        realm = params.get("realm", "")
                        # Without a realm there is no token endpoint to call.
                        if realm:
                            service = params.get("service", "")
                            scope = params.get("scope", "")
                            token_url = f"{realm}?{urlencode({'service': service, 'scope': scope})}"
                            token_resp = await client.get(token_url)
                            if token_resp.status_code == 200:
                                data = token_resp.json()
                                token = data.get("token") or data.get("access_token")
                                if token:
                                    headers["Authorization"] = f"Bearer {token}"
                                    resp = await client.get(url, headers=headers)

                return resp.status_code == 200
        except httpx.HTTPError:
            logger.debug("HTTP probe for %s failed (network/protocol)", image, exc_info=True)
            return False
        except (ValueError, KeyError):
            logger.debug("HTTP probe for %s failed (response parsing)", image, exc_info=True)
            return False
        except Exception:
            logger.warning("HTTP probe for %s failed (unexpected)", image, exc_info=True)
            return False

    async def resolve_image(
        self,
        image: str,
        *,
        timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> str:
        """Return a ROCK-mirrored image reference if available, else ``image``.

        Empty and digest-pinned (``@``) references are returned unchanged, as
        is any image when no registry is configured. Otherwise each configured
        registry is probed in order and the first existing candidate wins.
        The outcome (including "no mirror found") is cached per
        (registry list, image).
        """
        if not image:
            return image

        if "@" in image:
            return image

        registries = self.parse_registries(os.environ.get(ROCK_REGISTRY_ENV, ""))
        if not registries:
            return image

        cache_key = f"{'|'.join(registries)}||{image}"
        async with self._cache_lock:
            cached = self._resolve_cache.get(cache_key)
        if cached is not None:
            return cached

        # The lock is not held while probing, so slow probes do not block
        # lookups of other images.
        resolved = image
        for registry in registries:
            candidate = self.build_candidate(image, registry)
            if candidate == image:
                continue
            if await self._http_probe_manifest(candidate, timeout_sec=timeout_sec):
                logger.info("Rewriting image %s -> %s (ROCK mirror)", image, candidate)
                resolved = candidate
                break

        async with self._cache_lock:
            self._resolve_cache[cache_key] = resolved
        return resolved

    async def resolve_dockerfile(
        self,
        dockerfile: Path,
        *,
        timeout_sec: float = _DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> bool:
        """Rewrite ``FROM`` images in *dockerfile* to ROCK mirrors when available.

        ``FROM`` targets that are not pullable image references are left
        alone: the reserved ``scratch`` base, names of earlier build stages
        (``FROM base AS builder`` … ``FROM builder``), and references
        containing ``$`` (build-arg substitution). Line endings and the rest
        of each line are preserved.

        Returns True if any ``FROM`` line was rewritten.
        """
        text = dockerfile.read_text()
        lines = text.splitlines(keepends=True)
        changed = False
        stage_aliases: set[str] = set()

        for i, line in enumerate(lines):
            stripped = line.rstrip("\n\r")
            m = _FROM_RE.match(stripped)
            if not m:
                continue
            original = m.group("image")
            not_pullable = original == "scratch" or "$" in original or original.lower() in stage_aliases
            # Record this line's alias only after classifying its own image,
            # so a stage never refers to itself.
            alias_match = self._STAGE_ALIAS_RE.match(m.group("suffix"))
            if alias_match:
                stage_aliases.add(alias_match.group("alias").lower())
            if not_pullable:
                continue
            resolved = await self.resolve_image(original, timeout_sec=timeout_sec)
            if resolved != original:
                eol = line[len(stripped) :]
                lines[i] = f"{m.group('prefix')}{resolved}{m.group('suffix')}{eol}"
                changed = True

        if changed:
            dockerfile.write_text("".join(lines))
        return changed

    def reset_cache(self) -> None:
        """Clear the in-process resolve cache. Test helper."""
        self._resolve_cache.clear()


# Alias: ``RegionlessResolver`` is the same class object as ``RockRegistryResolver``.
RegionlessResolver = RockRegistryResolver

# Module-wide resolver instance backing the convenience functions below.
_default_resolver = RockRegistryResolver()

# Module-level helpers: bound methods of the default resolver, so calls made
# through them share that instance's cache.
resolve_image = _default_resolver.resolve_image
resolve_dockerfile = _default_resolver.resolve_dockerfile
Empty file added tests/unit/envhub/__init__.py
Empty file.
Empty file.
Loading
Loading