diff --git a/pyproject.toml b/pyproject.toml index c2cde938..0488fb6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ dependencies = [ # Core functionality (always include blaxel.core dependencies) core = [] +# HTTP/3 (QUIC) transport support +h3 = ["aioquic>=1.2.0"] + # Telemetry module telemetry = [ "opentelemetry-exporter-otlp>=1.28.0", diff --git a/src/blaxel/core/common/autoload.py b/src/blaxel/core/common/autoload.py index 4f810146..1119bd86 100644 --- a/src/blaxel/core/common/autoload.py +++ b/src/blaxel/core/common/autoload.py @@ -22,9 +22,6 @@ def autoload() -> None: client.with_base_url(settings.base_url) client.with_auth(settings.auth) - # Register response interceptors for authentication error handling - # Access the underlying httpx clients and add event hooks - # Use sync interceptors for sync clients and async interceptors for async clients httpx_client = client.get_httpx_client() httpx_client.event_hooks["response"] = response_interceptors_sync diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py new file mode 100644 index 00000000..496c898e --- /dev/null +++ b/src/blaxel/core/common/h3transport.py @@ -0,0 +1,490 @@ +"""HTTP/3 transport for httpx via aioquic. + +Provides an async H3Transport (httpx.AsyncBaseTransport) and a sync +SyncH3Transport (httpx.BaseTransport) backed by a shared connection pool +keyed by (host, port). + +When UDP is blocked or an H3 connection fails mid-session, callers +automatically fall back to HTTP/2 (or HTTP/1.1 if h2 is unavailable). +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +from collections import deque +from typing import AsyncIterator, Deque +from urllib.parse import urlparse + +import httpx + +try: + from aioquic.asyncio.client import connect + from aioquic.asyncio.protocol import QuicConnectionProtocol + from aioquic.h3.connection import H3_ALPN, H3Connection + from aioquic.h3.events import DataReceived, H3Event, HeadersReceived + from aioquic.quic.configuration import QuicConfiguration + from aioquic.quic.events import QuicEvent + + AIOQUIC_AVAILABLE = True +except ImportError: + AIOQUIC_AVAILABLE = False + +logging.getLogger("quic").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + +_H3_CONNECT_TIMEOUT = 5.0 +_H3_FAIL_TTL = 300.0 # remember failures for 5 min before retrying + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# All aioquic-dependent classes are guarded so the module can be imported +# even when aioquic is not installed (optional dependency). +# --------------------------------------------------------------------------- + +if AIOQUIC_AVAILABLE: + + class _H3ByteStream(httpx.AsyncByteStream): + def __init__(self, aiterator: AsyncIterator[bytes]): + self._aiterator = aiterator + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._aiterator: + yield part + + class _H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): + """httpx async transport over a single QUIC/H3 connection.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._http = H3Connection(self._quic) + self._read_queue: dict[int, Deque[H3Event]] = {} + self._read_ready: dict[int, asyncio.Event] = {} + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + + stream_id = self._quic.get_next_available_stream_id() + self._read_queue[stream_id] = deque() + self._read_ready[stream_id] = asyncio.Event() + + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.raw_scheme), + (b":authority", request.url.netloc), + (b":path", request.url.raw_path), + ] + + [ + (k.lower(), v) + for (k, v) in request.headers.raw + if k.lower() not in (b"connection", b"host") + ], + ) + async for data in request.stream: + self._http.send_data(stream_id=stream_id, data=data, end_stream=False) + self._http.send_data(stream_id=stream_id, data=b"", end_stream=True) + self.transmit() + + status_code, headers, stream_ended = await self._receive_response(stream_id) + + return httpx.Response( + status_code=status_code, + headers=headers, + stream=_H3ByteStream( + self._receive_response_data(stream_id, stream_ended) + ), + extensions={"http_version": b"HTTP/3"}, + ) + + # -- aioquic protocol callbacks -------------------------------------- + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, (HeadersReceived, DataReceived)): + stream_id = event.stream_id + if stream_id in self._read_queue: + self._read_queue[stream_id].append(event) + self._read_ready[stream_id].set() + + def quic_event_received(self, event: QuicEvent) -> None: + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + # -- internal helpers ------------------------------------------------ + + async def _receive_response(self, stream_id: int): + stream_ended = False + while True: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + break + + headers = [] + status_code = 0 + for header, value in event.headers: + if header == b":status": + status_code = int(value.decode()) + else: + headers.append((header, value)) + return status_code, headers, stream_ended + + async def _receive_response_data( + self, stream_id: int, stream_ended: bool + ) -> AsyncIterator[bytes]: + while not stream_ended: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, DataReceived): + stream_ended = event.stream_ended + yield event.data + elif isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + + async def _wait_for_http_event(self, stream_id: int) -> H3Event: + if not self._read_queue[stream_id]: + await self._read_ready[stream_id].wait() + event = self._read_queue[stream_id].popleft() + if not self._read_queue[stream_id]: + self._read_ready[stream_id].clear() + return event + + # ----------------------------------------------------------------------- + # Sync H3 bridge (delegates to the async _H3Transport via a bg loop) + # ----------------------------------------------------------------------- + + class _SyncH3Transport(httpx.BaseTransport): + """Sync httpx transport that delegates to an async _H3Transport.""" + + def __init__( + self, async_transport: _H3Transport, loop: asyncio.AbstractEventLoop + ): + self._async_transport = async_transport + self._loop = loop + + def handle_request(self, request: httpx.Request) -> httpx.Response: + future = asyncio.run_coroutine_threadsafe( + self._async_transport.handle_async_request(request), + self._loop, + ) + return future.result(timeout=300) + + def close(self) -> None: + pass + + # ----------------------------------------------------------------------- + # Fallback transports: try H3, auto-downgrade to HTTP/2 on failure + # ----------------------------------------------------------------------- + + class AsyncH3FallbackTransport(httpx.AsyncBaseTransport): + """Async transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, h3: _H3Transport, host: str, port: int): + self._h3 = h3 + self._host = host + self._port = port + self._h2_fallback: httpx.AsyncHTTPTransport | None = None + self._use_fallback = False + + async def handle_async_request( + self, request: httpx.Request + ) -> httpx.Response: + if self._use_fallback: + return await self._ensure_h2().handle_async_request(request) + try: + return await self._h3.handle_async_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return await self._ensure_h2().handle_async_request(request) + + def _ensure_h2(self) -> httpx.AsyncHTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.AsyncHTTPTransport( + http2=HTTP2_AVAILABLE + ) + return self._h2_fallback + + async def aclose(self) -> None: + if self._h2_fallback is not None: + await self._h2_fallback.aclose() + + class SyncH3FallbackTransport(httpx.BaseTransport): + """Sync transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, sync_h3: _SyncH3Transport, host: str, port: int): + self._sync_h3 = sync_h3 + self._host = host + self._port = port + self._h2_fallback: httpx.HTTPTransport | None = None + self._use_fallback = False + + def handle_request(self, request: httpx.Request) -> httpx.Response: + if self._use_fallback: + return self._ensure_h2().handle_request(request) + try: + return self._sync_h3.handle_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return self._ensure_h2().handle_request(request) + + def _ensure_h2(self) -> httpx.HTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.HTTPTransport(http2=HTTP2_AVAILABLE) + return self._h2_fallback + + def close(self) -> None: + if self._h2_fallback is not None: + self._h2_fallback.close() + + # ----------------------------------------------------------------------- + # Connection pool + # ----------------------------------------------------------------------- + + class H3Pool: + """Global pool of H3 transports keyed by (host, port). + + Manages a background event loop thread for sync callers and QUIC + connection lifecycle. Failed hosts are remembered for + ``_H3_FAIL_TTL`` seconds so that repeated connection attempts don't + add latency. + """ + + def __init__(self) -> None: + self._async_transports: dict[tuple[str, int], _H3Transport] = {} + self._connect_contexts: dict[tuple[str, int], object] = {} + self._failed_hosts: dict[tuple[str, int], float] = {} + self._lock = threading.Lock() + self._async_lock: asyncio.Lock | None = None + self._bg_loop: asyncio.AbstractEventLoop | None = None + self._bg_thread: threading.Thread | None = None + + def _get_async_lock(self) -> asyncio.Lock: + if self._async_lock is None: + self._async_lock = asyncio.Lock() + return self._async_lock + + # -- negative cache -------------------------------------------------- + + def _is_failed(self, host: str, port: int) -> bool: + key = (host, port) + with self._lock: + ts = self._failed_hosts.get(key) + if ts is None: + return False + if time.monotonic() - ts > _H3_FAIL_TTL: + del self._failed_hosts[key] + return False + return True + + def _mark_failed(self, host: str, port: int) -> None: + key = (host, port) + with self._lock: + self._failed_hosts[key] = time.monotonic() + self._connect_contexts.pop(key, None) + self._async_transports.pop(key, None) + + # -- background event loop for sync callers -------------------------- + + def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop: + with self._lock: + if self._bg_loop is None or not self._bg_loop.is_running(): + ready = threading.Event() + + def _run_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._bg_loop = loop + ready.set() + loop.run_forever() + + self._bg_thread = threading.Thread( + target=_run_loop, daemon=True + ) + self._bg_thread.start() + ready.wait(timeout=5) + return self._bg_loop # type: ignore[return-value] + + # -- internal: raw H3 connection ------------------------------------- + + async def _get_or_connect( + self, host: str, port: int + ) -> _H3Transport | None: + """Get a cached _H3Transport or establish a new QUIC connection. + + Holds the async lock across the entire check+connect+store + sequence to prevent two concurrent callers from both creating + connections for the same (host, port), which would leak the + first QUIC connection context. + """ + key = (host, port) + async with self._get_async_lock(): + transport = self._async_transports.get(key) + if transport is not None: + return transport + try: + transport = await asyncio.wait_for( + self._connect(host, port), + timeout=_H3_CONNECT_TIMEOUT, + ) + self._async_transports[key] = transport + return transport + except Exception: + logger.debug( + "H3 connection to %s:%d failed", host, port + ) + self._mark_failed(host, port) + return None + + async def _connect(self, host: str, port: int) -> _H3Transport: + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=H3_ALPN, + server_name=host, + ) + ctx = connect( + host, + port, + configuration=configuration, + create_protocol=_H3Transport, + ) + transport = await ctx.__aenter__() + with self._lock: + self._connect_contexts[(host, port)] = ctx + return transport # type: ignore[return-value] + + # -- public async API ------------------------------------------------ + + async def get_async_transport( + self, host: str, port: int = 443 + ) -> AsyncH3FallbackTransport | None: + """Get an H3 transport with automatic HTTP/2 fallback. + + Returns None if the QUIC handshake fails (caller should fall + back to HTTP/2 or use ``get_async_transport_for_url`` which + does this automatically). + """ + if self._is_failed(host, port): + return None + raw = await self._get_or_connect(host, port) + if raw is None: + return None + return AsyncH3FallbackTransport(raw, host, port) + + # -- public sync API (dispatches to bg loop) ------------------------- + + def get_sync_transport( + self, host: str, port: int = 443 + ) -> SyncH3FallbackTransport | None: + """Get a sync H3 transport with automatic HTTP/2 fallback. + + Returns None on failure (caller should fall back to HTTP/2 or + use ``get_sync_transport_for_url``). + """ + if self._is_failed(host, port): + return None + loop = self._ensure_bg_loop() + future = asyncio.run_coroutine_threadsafe( + self._get_or_connect(host, port), loop + ) + try: + raw = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) + except Exception: + self._mark_failed(host, port) + return None + if raw is None: + return None + return SyncH3FallbackTransport( + _SyncH3Transport(raw, loop), host, port + ) + + # -- shutdown -------------------------------------------------------- + + async def close_all(self) -> None: + async with self._get_async_lock(): + for key, ctx in list(self._connect_contexts.items()): + try: + await ctx.__aexit__(None, None, None) + except Exception: + pass + self._async_transports.clear() + self._connect_contexts.clear() + + def close_all_sync(self) -> None: + loop = self._bg_loop + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe( + self.close_all(), loop + ) + try: + future.result(timeout=5) + except Exception: + pass + + # -- module-level singleton (only when aioquic is available) ------------- + + pool: H3Pool | None = H3Pool() + +else: + pool: H3Pool | None = None # type: ignore[no-redef] + + +# --------------------------------------------------------------------------- +# Helpers — return the best available transport (H3 → HTTP/2 → None) +# --------------------------------------------------------------------------- + + +def _parse_host_port(url: str) -> tuple[str, int]: + parsed = urlparse(url) + host = parsed.hostname or "" + port = parsed.port or (443 if parsed.scheme == "https" else 80) + return host, port + + +async def get_async_transport_for_url(url: str) -> httpx.AsyncBaseTransport | None: + """Best-effort transport for *url*: H3 with fallback, else HTTP/2, else None.""" + host, port = _parse_host_port(url) + if not host: + return None + if pool is not None: + transport = await pool.get_async_transport(host, port) + if transport is not None: + return transport + if HTTP2_AVAILABLE: + return httpx.AsyncHTTPTransport(http2=True) + return None + + +def get_sync_transport_for_url(url: str) -> httpx.BaseTransport | None: + """Best-effort transport for *url*: H3 with fallback, else HTTP/2, else None.""" + host, port = _parse_host_port(url) + if not host: + return None + if pool is not None: + transport = pool.get_sync_transport(host, port) + if transport is not None: + return transport + if HTTP2_AVAILABLE: + return httpx.HTTPTransport(http2=True) + return None diff --git a/src/blaxel/core/sandbox/default/action.py b/src/blaxel/core/sandbox/default/action.py index 5cad3312..0a0f0a2c 100644 --- a/src/blaxel/core/sandbox/default/action.py +++ b/src/blaxel/core/sandbox/default/action.py @@ -2,6 +2,13 @@ from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False from ..types import ResponseError, SandboxConfiguration @@ -54,17 +61,37 @@ def fallback_url(self) -> str | None: return None def get_client(self) -> httpx.AsyncClient: - """Get persistent HTTP client for this sandbox instance.""" + """Get persistent HTTP client for this sandbox instance. + + Headers are injected via an event hook so that token refreshes are + picked up automatically without recreating the client. + """ if self._client is None: base_url = self.sandbox_config.force_url or self.url + transport = getattr(self.sandbox_config, "h3_transport", None) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport + elif HTTP2_AVAILABLE: + kwargs["http2"] = True + + sandbox_config = self.sandbox_config + + async def _inject_headers(request: httpx.Request) -> None: + """Inject fresh headers before every request.""" + if sandbox_config.force_url: + fresh = sandbox_config.headers + else: + fresh = {**settings.headers, **sandbox_config.headers} + for key, value in fresh.items(): + request.headers[key] = value + self._client = httpx.AsyncClient( base_url=base_url, - headers=self.sandbox_config.headers - if self.sandbox_config.force_url - else {**settings.headers, **self.sandbox_config.headers}, - http2=False, + event_hooks={"request": [_inject_headers]}, limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), timeout=httpx.Timeout(300.0, connect=10.0), + **kwargs, ) return self._client diff --git a/src/blaxel/core/sandbox/default/filesystem.py b/src/blaxel/core/sandbox/default/filesystem.py index 49ecc044..12da6f22 100644 --- a/src/blaxel/core/sandbox/default/filesystem.py +++ b/src/blaxel/core/sandbox/default/filesystem.py @@ -394,7 +394,13 @@ async def start_watching(): url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - async with httpx.AsyncClient() as client_instance: + from ...common.h3transport import get_async_transport_for_url + + transport = await get_async_transport_for_url(url) + watch_kwargs: dict = {} + if transport is not None: + watch_kwargs["transport"] = transport + async with httpx.AsyncClient(**watch_kwargs) as client_instance: async with client_instance.stream( "GET", url, params=params, headers=headers ) as response: diff --git a/src/blaxel/core/sandbox/default/process.py b/src/blaxel/core/sandbox/default/process.py index 59e0b6b0..5ee38a4e 100644 --- a/src/blaxel/core/sandbox/default/process.py +++ b/src/blaxel/core/sandbox/default/process.py @@ -154,7 +154,13 @@ async def start_streaming(): headers = {**settings.headers, **self.sandbox_config.headers} try: - async with httpx.AsyncClient() as client_instance: + from ...common.h3transport import get_async_transport_for_url + + transport = await get_async_transport_for_url(url) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport + async with httpx.AsyncClient(**kwargs) as client_instance: async with client_instance.stream("GET", url, headers=headers) as response: if response.status_code != 200: raise Exception(f"Failed to stream logs: {await response.aread()}") @@ -296,7 +302,13 @@ async def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) - async with httpx.AsyncClient() as client_instance: + from ...common.h3transport import get_async_transport_for_url + + transport = await get_async_transport_for_url(self.url) + stream_kwargs: dict = {} + if transport is not None: + stream_kwargs["transport"] = transport + async with httpx.AsyncClient(**stream_kwargs) as client_instance: async with client_instance.stream( "POST", f"{self.url}/process", diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 69c68ab0..d09710fa 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -60,10 +60,8 @@ def __init__(self, delete_func: Callable): def __get__(self, instance, owner): if instance is None: - # Called on the class: SandboxInstance.delete("name") return self._delete_func else: - # Called on an instance: instance.delete() async def instance_delete() -> Sandbox: return await self._delete_func(instance.metadata.name) @@ -260,12 +258,16 @@ async def create( sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory + # Extract region from existing Sandbox spec and apply it + region = sandbox.spec.region or settings.region + if region: + sandbox.spec.region = region + response = await create_sandbox( client=client, body=sandbox, ) - # Check if response is an error if isinstance(response, SandboxError): status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None @@ -274,7 +276,7 @@ async def create( assert response is not None instance = cls(response) - # TODO remove this part once we have a better way to handle this + if safe: try: await instance.fs.ls("/") diff --git a/src/blaxel/core/sandbox/sync/action.py b/src/blaxel/core/sandbox/sync/action.py index 528a5aa7..63baef10 100644 --- a/src/blaxel/core/sandbox/sync/action.py +++ b/src/blaxel/core/sandbox/sync/action.py @@ -2,6 +2,13 @@ from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False from ..types import ResponseError, SandboxConfiguration @@ -49,14 +56,34 @@ def fallback_url(self) -> str | None: return None def get_client(self) -> httpx.Client: - if self.sandbox_config.force_url: - return httpx.Client( - base_url=self.sandbox_config.force_url, - headers=self.sandbox_config.headers, - ) + """Get an HTTP client for this sandbox instance. + + Headers are injected via an event hook so that token refreshes are + picked up automatically without recreating the client. + """ + transport = getattr(self.sandbox_config, "h3_transport", None) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport + elif HTTP2_AVAILABLE: + kwargs["http2"] = True + + sandbox_config = self.sandbox_config + + def _inject_headers(request: httpx.Request) -> None: + """Inject fresh headers before every request.""" + if sandbox_config.force_url: + fresh = sandbox_config.headers + else: + fresh = {**settings.headers, **sandbox_config.headers} + for key, value in fresh.items(): + request.headers[key] = value + + base_url = self.sandbox_config.force_url or self.url return httpx.Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, + base_url=base_url, + event_hooks={"request": [_inject_headers]}, + **kwargs, ) def handle_response_error(self, response: httpx.Response): diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index 1b8037d1..662b1be1 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -191,7 +191,13 @@ def run(): params["ignore"] = ",".join(options["ignore"]) url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - with httpx.Client() as client_instance: + from ...common.h3transport import get_sync_transport_for_url + + transport = get_sync_transport_for_url(url) + watch_kw: dict = {} + if transport is not None: + watch_kw["transport"] = transport + with httpx.Client(**watch_kw) as client_instance: with client_instance.stream("GET", url, params=params, headers=headers) as response: if not response.is_success: raise Exception(f"Failed to start watching: {response.status_code}") diff --git a/src/blaxel/core/sandbox/sync/process.py b/src/blaxel/core/sandbox/sync/process.py index 7df15394..d6e67bae 100644 --- a/src/blaxel/core/sandbox/sync/process.py +++ b/src/blaxel/core/sandbox/sync/process.py @@ -123,7 +123,13 @@ def run(): url = f"{self.url}/process/{identifier}/logs/stream" headers = {**settings.headers, **self.sandbox_config.headers} try: - with httpx.Client() as client_instance: + from ...common.h3transport import get_sync_transport_for_url + + transport = get_sync_transport_for_url(url) + stream_kw: dict = {} + if transport is not None: + stream_kw["transport"] = transport + with httpx.Client(**stream_kw) as client_instance: with client_instance.stream("GET", url, headers=headers) as response: if response.status_code != 200: raise Exception(f"Failed to stream logs: {response.text}") @@ -242,7 +248,13 @@ def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) - with httpx.Client() as client_instance: + from ...common.h3transport import get_sync_transport_for_url + + transport = get_sync_transport_for_url(self.url) + exec_kw: dict = {} + if transport is not None: + exec_kw["transport"] = transport + with httpx.Client(**exec_kw) as client_instance: with client_instance.stream( "POST", f"{self.url}/process", diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index e8cd9a15..7dce4900 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -9,16 +9,9 @@ from ...client.api.compute.list_sandboxes import sync as list_sandboxes from ...client.api.compute.update_sandbox import sync as update_sandbox from ...client.client import client -from ...client.models import ( - Metadata, - Sandbox, - SandboxLifecycle, - SandboxRuntime, - SandboxSpec, -) -from ...client.models import ( - SandboxNetwork as SandboxNetworkModel, -) +from ...client.models import Metadata, Sandbox, SandboxLifecycle +from ...client.models import SandboxNetwork as SandboxNetworkModel +from ...client.models import SandboxRuntime, SandboxSpec from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET @@ -50,10 +43,8 @@ def __init__(self, delete_func: Callable): def __get__(self, instance, owner): if instance is None: - # Called on the class: SyncSandboxInstance.delete("name") return self._delete_func else: - # Called on an instance: instance.delete() def instance_delete() -> Sandbox: return self._delete_func(instance.metadata.name) @@ -228,12 +219,17 @@ def create( sandbox.spec.runtime = SandboxRuntime(image=default_image, memory=default_memory) sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory + + # Extract region from existing Sandbox spec and apply it + region = sandbox.spec.region or settings.region + if region: + sandbox.spec.region = region + response = create_sandbox( client=client, body=sandbox, ) - # Check if response is an error if isinstance(response, SandboxError): status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None @@ -241,6 +237,7 @@ def create( raise SandboxAPIError(message, status_code=status_code, code=code) instance = cls(response) + if safe: try: instance.fs.ls("/") diff --git a/src/blaxel/core/sandbox/types.py b/src/blaxel/core/sandbox/types.py index c5ac0d9d..ea4bb9e9 100644 --- a/src/blaxel/core/sandbox/types.py +++ b/src/blaxel/core/sandbox/types.py @@ -95,6 +95,7 @@ def __init__( self.force_url = force_url self.headers = headers or {} self.params = params or {} + self.h3_transport: Any = None @property def metadata(self): diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/benchmarks/bench_cold_call.py b/tests/benchmarks/bench_cold_call.py new file mode 100644 index 00000000..9aa2e943 --- /dev/null +++ b/tests/benchmarks/bench_cold_call.py @@ -0,0 +1,214 @@ +"""Cold-call benchmark: create → call → delete. + +Measures end-to-end latency of sandbox lifecycle operations with +per-phase breakdown (create / call / delete). + +Usage: + python -m tests.benchmarks.bench_cold_call + python -m tests.benchmarks.bench_cold_call --iterations 5 --warmup 0 +""" + +import argparse +import asyncio +import statistics +import time +from dataclasses import dataclass + +from blaxel.core.sandbox import SandboxInstance +from tests.helpers import default_image, default_labels, default_region, unique_name + +ITERATIONS = 10 +WARMUP_ITERATIONS = 1 + + +@dataclass +class Timing: + total: float = 0.0 + create: float = 0.0 + call: float = 0.0 + delete: float = 0.0 + + +async def bench_cold_ls(iteration: int) -> Timing: + """create -> fs.ls('/') -> delete""" + name = unique_name("bench-cold-ls") + t = Timing() + + t0 = time.perf_counter() + sandbox = await SandboxInstance.create({ + "name": name, + "image": default_image, + "labels": default_labels, + "memory": 2048, + "region": default_region, + }) + t1 = time.perf_counter() + t.create = t1 - t0 + + await sandbox.fs.ls("/") + t2 = time.perf_counter() + t.call = t2 - t1 + + try: + await SandboxInstance.delete(name) + except Exception: + pass + t3 = time.perf_counter() + t.delete = t3 - t2 + + t.total = t3 - t0 + return t + + +async def bench_cold_exec(iteration: int) -> Timing: + """create -> process.exec('echo ok') -> delete""" + name = unique_name("bench-cold-exec") + t = Timing() + + t0 = time.perf_counter() + sandbox = await SandboxInstance.create({ + "name": name, + "image": default_image, + "labels": default_labels, + "memory": 2048, + "region": default_region, + }) + t1 = time.perf_counter() + t.create = t1 - t0 + + await sandbox.process.exec({ + "command": "echo ok", + "wait_for_completion": True, + }) + t2 = time.perf_counter() + t.call = t2 - t1 + + try: + await SandboxInstance.delete(name) + except Exception: + pass + t3 = time.perf_counter() + t.delete = t3 - t2 + + t.total = t3 - t0 + return t + + +def fmt(seconds: float) -> str: + return f"{seconds * 1000:.0f}ms" + + +def percentile(values: list[float], p: float) -> float: + sorted_v = sorted(values) + k = (len(sorted_v) - 1) * (p / 100) + f = int(k) + c = f + 1 + if c >= len(sorted_v): + return sorted_v[f] + return sorted_v[f] + (k - f) * (sorted_v[c] - sorted_v[f]) + + + +def format_stats(values: list[float], label: str) -> str: + if not values: + return f" {label}: no data" + mean = statistics.mean(values) + p50 = percentile(values, 50) + p90 = percentile(values, 90) + p99 = percentile(values, 99) + mn = min(values) + mx = max(values) + return ( + f" {label:16s} " + f"mean={fmt(mean):>7s} p50={fmt(p50):>7s} " + f"p90={fmt(p90):>7s} p99={fmt(p99):>7s} " + f"min={fmt(mn):>7s} max={fmt(mx):>7s}" + ) + + +async def run_benchmark( + name: str, + fn, + iterations: int, + warmup: int, +) -> list[Timing]: + print(f"\n{'='*80}") + print(f" {name}") + print(f"{'='*80}") + + if warmup > 0: + print(f" warming up ({warmup} iteration{'s' if warmup > 1 else ''})...") + for i in range(warmup): + try: + await fn(i) + except Exception as e: + print(f" warmup {i+1} failed: {e}") + + timings: list[Timing] = [] + for i in range(iterations): + try: + t = await fn(i) + timings.append(t) + print( + f" [{i+1:>2}/{iterations}] " + f"total={fmt(t.total):>7s} " + f"create={fmt(t.create):>7s} " + f"call={fmt(t.call):>7s} " + f"delete={fmt(t.delete):>7s}" + ) + except Exception as e: + print(f" [{i+1:>2}/{iterations}] FAILED: {e}") + + if timings: + print() + print(format_stats([t.total for t in timings], "total")) + print(format_stats([t.create for t in timings], "create")) + print(format_stats([t.call for t in timings], "call")) + print(format_stats([t.delete for t in timings], "delete")) + return timings + + +async def main(iterations: int, warmup: int) -> None: + print(f"\nCold-call benchmark (create -> call -> delete)") + print(f" iterations={iterations}, warmup={warmup}") + print(f" image={default_image}, region={default_region}") + + ls_timings = await run_benchmark( + "create -> fs.ls('/') -> delete", + bench_cold_ls, + iterations, + warmup, + ) + + exec_timings = await run_benchmark( + "create -> process.exec('echo ok') -> delete", + bench_cold_exec, + iterations, + warmup, + ) + + print(f"\n{'='*80}") + print(" SUMMARY") + print(f"{'='*80}") + if ls_timings: + print(" fs.ls:") + print(format_stats([t.total for t in ls_timings], "total")) + print(format_stats([t.create for t in ls_timings], "create")) + print(format_stats([t.call for t in ls_timings], "call")) + print(format_stats([t.delete for t in ls_timings], "delete")) + if exec_timings: + print(" process.exec:") + print(format_stats([t.total for t in exec_timings], "total")) + print(format_stats([t.create for t in exec_timings], "create")) + print(format_stats([t.call for t in exec_timings], "call")) + print(format_stats([t.delete for t in exec_timings], "delete")) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Cold-call sandbox benchmark") + parser.add_argument("--iterations", type=int, default=ITERATIONS) + parser.add_argument("--warmup", type=int, default=WARMUP_ITERATIONS) + args = parser.parse_args() + + asyncio.run(main(args.iterations, args.warmup)) diff --git a/tests/integration/openai/test_tools.py b/tests/integration/openai/test_tools.py index 1f0989b1..89459f77 100644 --- a/tests/integration/openai/test_tools.py +++ b/tests/integration/openai/test_tools.py @@ -10,7 +10,12 @@ from blaxel.core.sandbox import SandboxInstance # noqa: E402 from blaxel.openai import bl_model, bl_tools # noqa: E402 -from tests.helpers import default_image, default_labels, unique_name # noqa: E402 +from tests.helpers import ( # noqa: E402 + default_image, + default_labels, + unique_name, + wait_for_sandbox_deployed, +) @pytest.mark.asyncio(loop_scope="class") @@ -32,6 +37,7 @@ async def setup_sandbox(self, request): "labels": default_labels, } ) + await wait_for_sandbox_deployed(request.cls.sandbox_name) yield