Add HF sandbox provider#841
Conversation
| return url | ||
|
|
||
|
|
||
| class _LocalAuthProxy: |
There was a problem hiding this comment.
@Wauplin This class is a by product of hf auth and the url construction. I'm kinda ok with it here, but just wondered if there was a simpler way to deal with it.
HF sandbox benchmarkRun id: Throughput by workload
Stage timing
All benchmark jobs completed successfully. |
sergiopaniego
left a comment
There was a problem hiding this comment.
Thanks for the first iteration!! tested locally and works nicely. looking forward to this integration :) Some AI-assisted review below with some ideas about when this is scaled.
I left inline comments (with suggested changes) on the specific spots. The framing I'd suggest: the correctness items and repo-convention items should land in this PR, while the scale-hardening items below are fine as documented limitations plus follow-up issues, as long as the design doesn't close the door on them.
Why the correctness items matter here specifically: this provider will be used to run many environments in parallel for RL training, where a single env that fails silently is worse than one that fails loudly. It either wastes wall-clock (the whole batch waits on a straggler) or feeds garbage observations and rewards into the policy.
Scale hardening (follow-up OK, but worth flagging)
- Configurable startup timeouts. The 120s budgets are hardcoded. Heavy images and GPU flavors can take much longer to schedule, so these should be
__init__params. - Retry/backoff on
run_job. At high concurrency the Jobs API will rate-limit or return transient init errors. A bounded retry would avoid losing envs to a single blip. - Reuse a shared
httpx.AsyncClientin the proxy. A new client (and connection pool) is created per request. With many tool-call steps across many envs this is measurable hot-path overhead. - Per-env proxy footprint. Each provider instance starts its own uvicorn server, thread, and event loop. At high env concurrency that's a lot of local servers on the trainer host. Worth documenting as a known limit, or considering a single shared proxy that multiplexes upstreams.
- Observability. The provider logs nothing. With many concurrent envs, structured logging of job id, stage, and failure cause is what makes a failing env debuggable.
Tests & docs
- No tests yet. The pure helpers (
_job_port_url,_to_ws_url,_find_available_port, hop-by-hop filtering) pluswait_for_readyare all testable with no network, matching the existing provider test suite undertests/test_core/. - Docstrings. Public methods need the HF doc-builder docstring format used by the other providers.
| def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: | ||
| deadline = time.time() + timeout_s | ||
| health_url = f"{base_url}/health" | ||
| while time.time() < deadline: | ||
| response = requests.get(health_url, timeout=5.0) | ||
| if response.status_code == 200: | ||
| return | ||
| time.sleep(1.0) | ||
| raise TimeoutError( | ||
| f"HF sandbox job at {base_url} did not become ready within {timeout_s}s" | ||
| ) |
There was a problem hiding this comment.
wait_for_ready aborts on the first transient error instead of honoring timeout_s.
requests.get(..., timeout=5.0) raises on a slow or refused response, and with no try/except the exception escapes the loop and kills the call long before the timeout_s budget is reached. I reproduced it by pointing wait_for_ready at a closed port: current code dies in 0.00s with ConnectionError, whereas with the try/except it retries and dies in ~8.0s with a clean TimeoutError.
| def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: | |
| deadline = time.time() + timeout_s | |
| health_url = f"{base_url}/health" | |
| while time.time() < deadline: | |
| response = requests.get(health_url, timeout=5.0) | |
| if response.status_code == 200: | |
| return | |
| time.sleep(1.0) | |
| raise TimeoutError( | |
| f"HF sandbox job at {base_url} did not become ready within {timeout_s}s" | |
| ) | |
| def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None: | |
| deadline = time.time() + timeout_s | |
| health_url = f"{base_url}/health" | |
| while time.time() < deadline: | |
| try: | |
| response = requests.get(health_url, timeout=5.0) | |
| if response.status_code == 200: | |
| return | |
| except requests.exceptions.RequestException: | |
| pass | |
| time.sleep(1.0) | |
| raise TimeoutError( | |
| f"HF sandbox job at {base_url} did not become ready within {timeout_s}s" | |
| ) |
Good candidate for the first unit test (mock requests.get to raise, assert it still respects timeout_s).
| async with httpx.AsyncClient(follow_redirects=True) as client: | ||
| upstream = await client.request( | ||
| request.method, | ||
| target, | ||
| content=await request.body(), | ||
| headers=headers, | ||
| timeout=60.0, | ||
| ) |
There was a problem hiding this comment.
Distinguish "env died" from "env returned an error".
When the upstream job dies mid-rollout, the httpx transport error propagates and FastAPI returns a generic 500, indistinguishable from a legitimate application-level error from the env server itself. For a training loop these need different handling (restart the env vs. record a step error). Returning a 502 makes infra failures explicit and restartable:
| async with httpx.AsyncClient(follow_redirects=True) as client: | |
| upstream = await client.request( | |
| request.method, | |
| target, | |
| content=await request.body(), | |
| headers=headers, | |
| timeout=60.0, | |
| ) | |
| try: | |
| async with httpx.AsyncClient(follow_redirects=True) as client: | |
| upstream = await client.request( | |
| request.method, | |
| target, | |
| content=await request.body(), | |
| headers=headers, | |
| timeout=60.0, | |
| ) | |
| except httpx.HTTPError: | |
| return Response( | |
| content=b"upstream HF job unreachable", | |
| status_code=502, | |
| ) |
| while target_url is None and time.time() < deadline: | ||
| time.sleep(0.5) | ||
| self._job = self._api.inspect_job( | ||
| job_id=self._job.id, | ||
| namespace=self._job.owner.name, | ||
| token=self._token, | ||
| ) | ||
| target_url = _job_port_url(self._job, _DEFAULT_PORT) |
There was a problem hiding this comment.
Fail fast when the job terminates before exposing its port.
This loop only watches expose_urls, so a job that goes to a terminal state (bad image, server command crashes) waits the full timeout and then raises a generic message with no signal about what actually failed. Detecting a terminal stage lets a broken env fail fast and identifiably:
| while target_url is None and time.time() < deadline: | |
| time.sleep(0.5) | |
| self._job = self._api.inspect_job( | |
| job_id=self._job.id, | |
| namespace=self._job.owner.name, | |
| token=self._token, | |
| ) | |
| target_url = _job_port_url(self._job, _DEFAULT_PORT) | |
| while target_url is None and time.time() < deadline: | |
| time.sleep(0.5) | |
| self._job = self._api.inspect_job( | |
| job_id=self._job.id, | |
| namespace=self._job.owner.name, | |
| token=self._token, | |
| ) | |
| stage = getattr(self._job.status, "stage", None) | |
| if stage in ("ERROR", "COMPLETED", "DELETED"): | |
| raise RuntimeError( | |
| f"HF job {self._job.id} terminated early (stage={stage}) " | |
| f"before exposing port {_DEFAULT_PORT}" | |
| ) | |
| target_url = _job_port_url(self._job, _DEFAULT_PORT) |
(Worth confirming the exact terminal stage names against the huggingface_hub version you target.)
| if self._job is not None: | ||
| self._api.cancel_job( | ||
| job_id=self._job.id, | ||
| namespace=self._job.owner.name, | ||
| token=self._token, | ||
| ) | ||
| self._job = None | ||
| self._token = None |
There was a problem hiding this comment.
Make teardown idempotent and non-throwing.
self._job and self._token are reset after cancel_job. If cancel_job raises (network blip), the state is left dirty, and since close() runs on __exit__, that exception masks the original error from the with block. This matters at scale: training loops oversample and cancel unfinished rollouts aggressively, so teardown runs often and on the hot path. A teardown that can leak a (billable) job or throw will accumulate orphans over a long run.
| if self._job is not None: | |
| self._api.cancel_job( | |
| job_id=self._job.id, | |
| namespace=self._job.owner.name, | |
| token=self._token, | |
| ) | |
| self._job = None | |
| self._token = None | |
| if self._job is not None: | |
| try: | |
| self._api.cancel_job( | |
| job_id=self._job.id, | |
| namespace=self._job.owner.name, | |
| token=self._token, | |
| ) | |
| finally: | |
| self._job = None | |
| self._token = None |
| def start_container( | ||
| self, | ||
| image: str, | ||
| port: int | None = None, | ||
| env_vars: dict[str, str] | None = None, | ||
| ) -> str: |
There was a problem hiding this comment.
start_container drops **kwargs.
The abstract ContainerProvider.start_container declares **kwargs: Any, and EnvClient calls provider.start_container(image, **kwargs) (see env_client.py:296). Without it, the generic path raises TypeError for any forwarded kwarg. Any is already imported:
| def start_container( | |
| self, | |
| image: str, | |
| port: int | None = None, | |
| env_vars: dict[str, str] | None = None, | |
| ) -> str: | |
| def start_container( | |
| self, | |
| image: str, | |
| port: int | None = None, | |
| env_vars: dict[str, str] | None = None, | |
| **kwargs: Any, | |
| ) -> str: |
| @@ -0,0 +1,269 @@ | |||
| """Hugging Face-backed provider for OpenEnv environment servers.""" | |||
There was a problem hiding this comment.
Missing license header. Every other module in the repo carries the BSD header. Part of the repo convention, should be added.
| """Hugging Face-backed provider for OpenEnv environment servers.""" | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Hugging Face-backed provider for OpenEnv environment servers.""" |
| #!/usr/bin/env python3 | ||
| """Smoke-check a real OpenEnv server through the HF sandbox provider.""" |
There was a problem hiding this comment.
Missing license header. Same as the provider module (header goes after the shebang):
| #!/usr/bin/env python3 | |
| """Smoke-check a real OpenEnv server through the HF sandbox provider.""" | |
| #!/usr/bin/env python3 | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Smoke-check a real OpenEnv server through the HF sandbox provider.""" |
| def _to_ws_url(url: str) -> str: | ||
| if url.startswith("https://"): | ||
| return "wss://" + url[len("https://") :] | ||
| if url.startswith("http://"): | ||
| return "ws://" + url[len("http://") :] | ||
| return url |
There was a problem hiding this comment.
Enforce secure transport on the upstream URL. The sibling cloud provider rejects non-HTTPS URLs (the token is a bearer secret, and EnvClient derives its WebSocket URL from this base). _to_ws_url silently downgrades to ws://, and the proxy injects the Bearer token on it. Worth asserting target_url starts with https:// in _wait_for_job_url so the invariant is explicit. (Left as a comment rather than a suggestion since the guard belongs in _wait_for_job_url, not here.)
| async def _client_to_upstream(self, websocket: WebSocket, upstream: Any) -> None: | ||
| async for message in websocket.iter_text(): | ||
| await upstream.send(message) |
There was a problem hiding this comment.
Noted, not a blocker: the relay handles bytes and text upstream-to-client but only text client-to-upstream (iter_text()). Fine in practice since EnvClient only ever sends JSON text, but a short comment noting the assumption would help future readers.
This PR adds a minimal Hugging Face sandbox-backed provider for OpenEnv environment servers.