From a919db2aaed05688d8180804dad44ef7c1773e32 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Thu, 5 Mar 2026 06:05:13 +0000 Subject: [PATCH 01/16] feat: add background HTTP/3 (QUIC) connection warming - Add h3warm.py utility with aioquic-based QUIC connection warming - Warm H3 connections to regional edge domains in SandboxInstance.create() and SyncSandboxInstance.create() in parallel with sandbox creation API call - Warm H3 connection to api.blaxel.ai during SDK autoload - Add proper H3 session cleanup in delete() for both async and sync sandboxes - Export close_api_h3_session() for manual API session cleanup - Add aioquic>=1.2.0 as core dependency Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- pyproject.toml | 1 + src/blaxel/core/common/autoload.py | 40 +++++++++++++ src/blaxel/core/common/h3warm.py | 65 ++++++++++++++++++++++ src/blaxel/core/sandbox/default/sandbox.py | 27 +++++++++ src/blaxel/core/sandbox/sync/sandbox.py | 36 ++++++++++++ 5 files changed, 169 insertions(+) create mode 100644 src/blaxel/core/common/h3warm.py diff --git a/pyproject.toml b/pyproject.toml index c2cde938..0da7158d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "websockets<16.0.0", "attrs>=21.3.0", "httpx>=0.27.0", + "aioquic>=1.2.0", "mcp>=1.9.4", "dockerfile-parse>=2.0.0", ] diff --git a/src/blaxel/core/common/autoload.py b/src/blaxel/core/common/autoload.py index 4f810146..57ad794b 100644 --- a/src/blaxel/core/common/autoload.py +++ b/src/blaxel/core/common/autoload.py @@ -1,4 +1,7 @@ +import asyncio import logging +import threading +from urllib.parse import urlparse from ..client import client from ..client.response_interceptor import ( @@ -6,11 +9,15 @@ response_interceptors_sync, ) from ..sandbox.client import client as client_sandbox +from .h3warm import H3WarmSession, establish_h3_best_effort from .sentry import init_sentry from .settings import settings logger = logging.getLogger(__name__) +# Module-level H3 session for API endpoint warming +_api_h3_session: H3WarmSession | None = None + def telemetry() -> None: from blaxel.telemetry import telemetry_manager @@ -47,3 +54,36 @@ def autoload() -> None: telemetry() except Exception: pass + + # Warm H3 connection to API endpoint in background + try: + api_hostname = urlparse(settings.base_url).hostname + if api_hostname: + _warm_api_h3(api_hostname) + except Exception: + pass + + +def _warm_api_h3(hostname: str) -> None: + """Start background H3 connection warming for the API endpoint.""" + global _api_h3_session + + def _do_warm() -> None: + global _api_h3_session + try: + loop = asyncio.new_event_loop() + _api_h3_session = loop.run_until_complete(establish_h3_best_effort(hostname)) + loop.close() + except Exception: + pass + + thread = threading.Thread(target=_do_warm, daemon=True) + thread.start() + + +def close_api_h3_session() -> None: + """Close the API H3 warming session. Call this for clean shutdown.""" + global _api_h3_session + if _api_h3_session is not None: + _api_h3_session.close() + _api_h3_session = None diff --git a/src/blaxel/core/common/h3warm.py b/src/blaxel/core/common/h3warm.py new file mode 100644 index 00000000..dd2733ea --- /dev/null +++ b/src/blaxel/core/common/h3warm.py @@ -0,0 +1,65 @@ +"""HTTP/3 (QUIC) connection warming utility. + +Establishes a QUIC connection to a given hostname to pre-warm +DNS resolution and TLS 1.3/QUIC handshake, so that the first real +request benefits from a pre-warmed connection. +""" + +import asyncio +import logging + +from aioquic.asyncio import connect +from aioquic.quic.configuration import QuicConfiguration + +logger = logging.getLogger(__name__) + + +class H3WarmSession: + """Holds a warmed QUIC/HTTP3 connection.""" + + def __init__(self, ctx: object, client: object): + self._ctx = ctx + self._client = client + + def close(self) -> None: + """Close the QUIC connection.""" + try: + if self._client is not None: + self._client.close() # type: ignore[union-attr] + self._client = None + except Exception: + pass + + +async def establish_h3(hostname: str, port: int = 443) -> H3WarmSession: + """Establish an HTTP/3 (QUIC) connection to the given hostname. + + Performs DNS resolution + QUIC handshake (including TLS 1.3) to + fully warm the connection path. + + Args: + hostname: The SNI hostname to connect to. + port: The port to connect to (default 443). + + Returns: + An H3WarmSession wrapping the QUIC connection. + """ + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=["h3"], + server_name=hostname, + ) + configuration.verify_mode = False + + ctx = connect(hostname, port, configuration=configuration) + client = await ctx.__aenter__() + + return H3WarmSession(ctx, client) + + +async def establish_h3_best_effort(hostname: str, port: int = 443) -> H3WarmSession | None: + """Best-effort HTTP/3 warming. Returns None on any failure.""" + try: + return await asyncio.wait_for(establish_h3(hostname, port), timeout=5.0) + except Exception: + return None diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 69c68ab0..00ba4718 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -1,3 +1,4 @@ +import asyncio import logging import uuid import warnings @@ -23,6 +24,7 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET +from ...common.h3warm import H3WarmSession, establish_h3_best_effort from ...common.settings import settings from ..types import ( SandboxConfiguration, @@ -65,6 +67,10 @@ def __get__(self, instance, owner): else: # Called on an instance: instance.delete() async def instance_delete() -> Sandbox: + # Close H3 session if present + if hasattr(instance, "h3_session") and instance.h3_session is not None: + instance.h3_session.close() + instance.h3_session = None return await self._delete_func(instance.metadata.name) return instance_delete @@ -72,6 +78,7 @@ async def instance_delete() -> Sandbox: class SandboxInstance: delete: "_AsyncDeleteDescriptor" + h3_session: H3WarmSession | None def __init__( self, @@ -102,6 +109,7 @@ def __init__( self.codegen = SandboxCodegen(self.config) self.system = SandboxSystem(self.config) self.drives = SandboxDrive(self.config) + self.h3_session = None @property def metadata(self): @@ -260,6 +268,15 @@ async def create( sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory + # Extract region from existing Sandbox spec + region = getattr(sandbox.spec, "region", None) or settings.region + + # Start H3 connection warming in parallel with sandbox creation + h3_warm_task: asyncio.Task[H3WarmSession | None] | None = None + if region: + edge_domain = f"any.{region}.bl.run" + h3_warm_task = asyncio.create_task(establish_h3_best_effort(edge_domain)) + response = await create_sandbox( client=client, body=sandbox, @@ -270,10 +287,20 @@ async def create( status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None message = response.message if response.message else str(response) + # Clean up H3 warming task if it was started + if h3_warm_task is not None: + h3_warm_task.cancel() raise SandboxAPIError(message, status_code=status_code, code=code) assert response is not None instance = cls(response) + + # Retrieve H3 session from warming task + if h3_warm_task is not None: + try: + instance.h3_session = await h3_warm_task + except Exception: + instance.h3_session = None # TODO remove this part once we have a better way to handle this if safe: try: diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index e8cd9a15..791a96fa 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -1,4 +1,6 @@ +import asyncio import logging +import threading import uuid import warnings from typing import Any, Callable, Dict, List, Union @@ -22,6 +24,7 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET +from ...common.h3warm import H3WarmSession, establish_h3_best_effort from ...common.settings import settings from ..default.sandbox import SandboxAPIError from ..types import ( @@ -55,6 +58,10 @@ def __get__(self, instance, owner): else: # Called on an instance: instance.delete() def instance_delete() -> Sandbox: + # Close H3 session if present + if hasattr(instance, "h3_session") and instance.h3_session is not None: + instance.h3_session.close() + instance.h3_session = None return self._delete_func(instance.metadata.name) return instance_delete @@ -87,6 +94,7 @@ def __init__( self.codegen = SyncSandboxCodegen(self.config) self.system = SyncSandboxSystem(self.config) self.drives = SyncSandboxDrive(self.config) + self.h3_session: H3WarmSession | None = None @property def metadata(self): @@ -228,6 +236,29 @@ def create( sandbox.spec.runtime = SandboxRuntime(image=default_image, memory=default_memory) sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory + + # Extract region from existing Sandbox spec + region = getattr(sandbox.spec, "region", None) or settings.region + + # Start H3 connection warming in a background thread + h3_result: dict[str, H3WarmSession | None] = {"session": None} + h3_thread: threading.Thread | None = None + if region: + edge_domain = f"any.{region}.bl.run" + + def _warm_h3() -> None: + try: + loop = asyncio.new_event_loop() + h3_result["session"] = loop.run_until_complete( + establish_h3_best_effort(edge_domain) + ) + loop.close() + except Exception: + pass + + h3_thread = threading.Thread(target=_warm_h3, daemon=True) + h3_thread.start() + response = create_sandbox( client=client, body=sandbox, @@ -241,6 +272,11 @@ def create( raise SandboxAPIError(message, status_code=status_code, code=code) instance = cls(response) + + # Retrieve H3 session from warming thread + if h3_thread is not None: + h3_thread.join(timeout=5.0) + instance.h3_session = h3_result["session"] if safe: try: instance.fs.ls("/") From 835a8ec7cab9dc92c82025c88198b7508e453eea Mon Sep 17 00:00:00 2001 From: mjoffre Date: Thu, 5 Mar 2026 06:17:28 +0000 Subject: [PATCH 02/16] fix: address review - resource leak, blocking join, TLS verification Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/h3warm.py | 13 ++++++++----- src/blaxel/core/sandbox/sync/sandbox.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/blaxel/core/common/h3warm.py b/src/blaxel/core/common/h3warm.py index dd2733ea..d81842b8 100644 --- a/src/blaxel/core/common/h3warm.py +++ b/src/blaxel/core/common/h3warm.py @@ -22,10 +22,15 @@ def __init__(self, ctx: object, client: object): self._client = client def close(self) -> None: - """Close the QUIC connection.""" + """Close the QUIC connection by properly exiting the async context manager.""" try: - if self._client is not None: - self._client.close() # type: ignore[union-attr] + if self._ctx is not None: + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(self._ctx.__aexit__(None, None, None)) + finally: + loop.close() + self._ctx = None self._client = None except Exception: pass @@ -49,8 +54,6 @@ async def establish_h3(hostname: str, port: int = 443) -> H3WarmSession: alpn_protocols=["h3"], server_name=hostname, ) - configuration.verify_mode = False - ctx = connect(hostname, port, configuration=configuration) client = await ctx.__aenter__() diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 791a96fa..9c20e51e 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -273,9 +273,9 @@ def _warm_h3() -> None: instance = cls(response) - # Retrieve H3 session from warming thread + # Retrieve H3 session from warming thread (non-blocking) if h3_thread is not None: - h3_thread.join(timeout=5.0) + h3_thread.join(timeout=0) instance.h3_session = h3_result["session"] if safe: try: From fe9f65a54d2087e4fe0d4d2b216778d2d494d31a Mon Sep 17 00:00:00 2001 From: Mathis Joffre Date: Wed, 4 Mar 2026 22:28:46 -0800 Subject: [PATCH 03/16] Update src/blaxel/core/sandbox/sync/sandbox.py Co-authored-by: devin-ai-integration[bot] <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/sandbox/sync/sandbox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 9c20e51e..9688be20 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -275,7 +275,7 @@ def _warm_h3() -> None: # Retrieve H3 session from warming thread (non-blocking) if h3_thread is not None: - h3_thread.join(timeout=0) + h3_thread.join(timeout=5) instance.h3_session = h3_result["session"] if safe: try: From 555e84aae571cd2a577c796854e94d4e6655718d Mon Sep 17 00:00:00 2001 From: Joffref Date: Thu, 5 Mar 2026 00:54:25 -0800 Subject: [PATCH 04/16] WIP --- src/blaxel/core/common/autoload.py | 27 +- src/blaxel/core/common/h3transport.py | 322 ++++++++++++++++++ src/blaxel/core/common/h3warm.py | 68 ---- src/blaxel/core/sandbox/default/action.py | 6 +- src/blaxel/core/sandbox/default/filesystem.py | 7 +- src/blaxel/core/sandbox/default/process.py | 13 +- src/blaxel/core/sandbox/default/sandbox.py | 32 +- src/blaxel/core/sandbox/sync/action.py | 6 + src/blaxel/core/sandbox/sync/filesystem.py | 7 +- src/blaxel/core/sandbox/sync/process.py | 13 +- src/blaxel/core/sandbox/sync/sandbox.py | 47 +-- src/blaxel/core/sandbox/types.py | 1 + tests/benchmarks/__init__.py | 0 tests/benchmarks/bench_cold_call.py | 201 +++++++++++ 14 files changed, 602 insertions(+), 148 deletions(-) create mode 100644 src/blaxel/core/common/h3transport.py delete mode 100644 src/blaxel/core/common/h3warm.py create mode 100644 tests/benchmarks/__init__.py create mode 100644 tests/benchmarks/bench_cold_call.py diff --git a/src/blaxel/core/common/autoload.py b/src/blaxel/core/common/autoload.py index 57ad794b..f0d43e14 100644 --- a/src/blaxel/core/common/autoload.py +++ b/src/blaxel/core/common/autoload.py @@ -1,4 +1,3 @@ -import asyncio import logging import threading from urllib.parse import urlparse @@ -9,15 +8,12 @@ response_interceptors_sync, ) from ..sandbox.client import client as client_sandbox -from .h3warm import H3WarmSession, establish_h3_best_effort +from .h3transport import pool as h3_pool from .sentry import init_sentry from .settings import settings logger = logging.getLogger(__name__) -# Module-level H3 session for API endpoint warming -_api_h3_session: H3WarmSession | None = None - def telemetry() -> None: from blaxel.telemetry import telemetry_manager @@ -29,9 +25,6 @@ def autoload() -> None: client.with_base_url(settings.base_url) client.with_auth(settings.auth) - # Register response interceptors for authentication error handling - # Access the underlying httpx clients and add event hooks - # Use sync interceptors for sync clients and async interceptors for async clients httpx_client = client.get_httpx_client() httpx_client.event_hooks["response"] = response_interceptors_sync @@ -55,7 +48,7 @@ def autoload() -> None: except Exception: pass - # Warm H3 connection to API endpoint in background + # Pre-warm H3 connection to API endpoint in background try: api_hostname = urlparse(settings.base_url).hostname if api_hostname: @@ -65,25 +58,13 @@ def autoload() -> None: def _warm_api_h3(hostname: str) -> None: - """Start background H3 connection warming for the API endpoint.""" - global _api_h3_session + """Pre-warm the H3 pool for the API endpoint in a background thread.""" def _do_warm() -> None: - global _api_h3_session try: - loop = asyncio.new_event_loop() - _api_h3_session = loop.run_until_complete(establish_h3_best_effort(hostname)) - loop.close() + h3_pool.get_sync_transport(hostname, 443) except Exception: pass thread = threading.Thread(target=_do_warm, daemon=True) thread.start() - - -def close_api_h3_session() -> None: - """Close the API H3 warming session. Call this for clean shutdown.""" - global _api_h3_session - if _api_h3_session is not None: - _api_h3_session.close() - _api_h3_session = None diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py new file mode 100644 index 00000000..4ad41f63 --- /dev/null +++ b/src/blaxel/core/common/h3transport.py @@ -0,0 +1,322 @@ +"""HTTP/3 transport for httpx via aioquic. + +Provides an async H3Transport (httpx.AsyncBaseTransport) and a sync +SyncH3Transport (httpx.BaseTransport) backed by a shared connection pool +keyed by (host, port). +""" + +from __future__ import annotations + +import asyncio +import logging +import socket +import threading +from collections import deque +from typing import AsyncIterator, Deque +from urllib.parse import urlparse + +import httpx +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import DataReceived, H3Event, HeadersReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import QuicEvent + +logging.getLogger("quic").setLevel(logging.WARNING) +logger = logging.getLogger(__name__) + +_H3_CONNECT_TIMEOUT = 5.0 + + +# --------------------------------------------------------------------------- +# Async H3 transport (one QUIC connection per instance) +# --------------------------------------------------------------------------- + +class _H3ByteStream(httpx.AsyncByteStream): + def __init__(self, aiterator: AsyncIterator[bytes]): + self._aiterator = aiterator + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._aiterator: + yield part + + +class H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): + """httpx async transport over a single QUIC/H3 connection.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._http = H3Connection(self._quic) + self._read_queue: dict[int, Deque[H3Event]] = {} + self._read_ready: dict[int, asyncio.Event] = {} + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + + stream_id = self._quic.get_next_available_stream_id() + self._read_queue[stream_id] = deque() + self._read_ready[stream_id] = asyncio.Event() + + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.raw_scheme), + (b":authority", request.url.netloc), + (b":path", request.url.raw_path), + ] + + [ + (k.lower(), v) + for (k, v) in request.headers.raw + if k.lower() not in (b"connection", b"host") + ], + ) + async for data in request.stream: + self._http.send_data(stream_id=stream_id, data=data, end_stream=False) + self._http.send_data(stream_id=stream_id, data=b"", end_stream=True) + self.transmit() + + status_code, headers, stream_ended = await self._receive_response(stream_id) + + return httpx.Response( + status_code=status_code, + headers=headers, + stream=_H3ByteStream(self._receive_response_data(stream_id, stream_ended)), + extensions={"http_version": b"HTTP/3"}, + ) + + # -- aioquic protocol callbacks ------------------------------------------ + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, (HeadersReceived, DataReceived)): + stream_id = event.stream_id + if stream_id in self._read_queue: + self._read_queue[stream_id].append(event) + self._read_ready[stream_id].set() + + def quic_event_received(self, event: QuicEvent) -> None: + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + # -- internal helpers ---------------------------------------------------- + + async def _receive_response(self, stream_id: int): + stream_ended = False + while True: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + break + + headers = [] + status_code = 0 + for header, value in event.headers: + if header == b":status": + status_code = int(value.decode()) + else: + headers.append((header, value)) + return status_code, headers, stream_ended + + async def _receive_response_data( + self, stream_id: int, stream_ended: bool + ) -> AsyncIterator[bytes]: + while not stream_ended: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, DataReceived): + stream_ended = event.stream_ended + yield event.data + elif isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + + async def _wait_for_http_event(self, stream_id: int) -> H3Event: + if not self._read_queue[stream_id]: + await self._read_ready[stream_id].wait() + event = self._read_queue[stream_id].popleft() + if not self._read_queue[stream_id]: + self._read_ready[stream_id].clear() + return event + + +# --------------------------------------------------------------------------- +# Sync H3 transport (bridges async transport via a background event loop) +# --------------------------------------------------------------------------- + +class SyncH3Transport(httpx.BaseTransport): + """Sync httpx transport that delegates to an async H3Transport.""" + + def __init__(self, async_transport: H3Transport, loop: asyncio.AbstractEventLoop): + self._async_transport = async_transport + self._loop = loop + + def handle_request(self, request: httpx.Request) -> httpx.Response: + future = asyncio.run_coroutine_threadsafe( + self._async_transport.handle_async_request(request), + self._loop, + ) + return future.result(timeout=300) + + def close(self) -> None: + pass + + +# --------------------------------------------------------------------------- +# Connection pool +# --------------------------------------------------------------------------- + +class H3Pool: + """Global pool of H3 transports keyed by (host, port). + + Manages a background event loop thread for sync callers and QUIC + connection lifecycle. + """ + + def __init__(self) -> None: + self._async_transports: dict[tuple[str, int], H3Transport] = {} + self._connect_contexts: dict[tuple[str, int], object] = {} + self._lock = threading.Lock() + self._async_lock: asyncio.Lock | None = None + self._bg_loop: asyncio.AbstractEventLoop | None = None + self._bg_thread: threading.Thread | None = None + + def _get_async_lock(self) -> asyncio.Lock: + if self._async_lock is None: + self._async_lock = asyncio.Lock() + return self._async_lock + + # -- background event loop for sync callers ------------------------------ + + def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop: + with self._lock: + if self._bg_loop is None or not self._bg_loop.is_running(): + ready = threading.Event() + + def _run_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._bg_loop = loop + ready.set() + loop.run_forever() + + self._bg_thread = threading.Thread(target=_run_loop, daemon=True) + self._bg_thread.start() + ready.wait(timeout=5) + return self._bg_loop # type: ignore[return-value] + + # -- async API ----------------------------------------------------------- + + async def get_async_transport( + self, host: str, port: int = 443 + ) -> H3Transport | None: + """Get or create an H3Transport for the given host. + + Returns None if the QUIC handshake fails (caller should fall back + to TCP). + """ + key = (host, port) + async with self._get_async_lock(): + transport = self._async_transports.get(key) + if transport is not None: + return transport + try: + transport = await asyncio.wait_for( + self._connect(host, port), timeout=_H3_CONNECT_TIMEOUT + ) + async with self._get_async_lock(): + self._async_transports[key] = transport + return transport + except Exception: + logger.debug("H3 connection to %s:%d failed, falling back to TCP", host, port) + return None + + async def _connect(self, host: str, port: int) -> H3Transport: + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=H3_ALPN, + server_name=host, + ) + ctx = connect( + host, + port, + configuration=configuration, + create_protocol=H3Transport, + ) + transport = await ctx.__aenter__() + with self._lock: + self._connect_contexts[(host, port)] = ctx + return transport # type: ignore[return-value] + + # -- sync API (dispatches to bg loop) ------------------------------------ + + def get_sync_transport( + self, host: str, port: int = 443 + ) -> SyncH3Transport | None: + """Get or create a SyncH3Transport for the given host. + + Returns None on failure (caller should fall back to TCP). + """ + loop = self._ensure_bg_loop() + future = asyncio.run_coroutine_threadsafe( + self.get_async_transport(host, port), loop + ) + try: + async_transport = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) + except Exception: + return None + if async_transport is None: + return None + return SyncH3Transport(async_transport, loop) + + # -- shutdown ------------------------------------------------------------ + + async def close_all(self) -> None: + async with self._get_async_lock(): + for key, ctx in list(self._connect_contexts.items()): + try: + await ctx.__aexit__(None, None, None) + except Exception: + pass + self._async_transports.clear() + self._connect_contexts.clear() + + def close_all_sync(self) -> None: + loop = self._bg_loop + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe(self.close_all(), loop) + try: + future.result(timeout=5) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + +pool = H3Pool() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _parse_host_port(url: str) -> tuple[str, int]: + parsed = urlparse(url) + host = parsed.hostname or "" + port = parsed.port or (443 if parsed.scheme == "https" else 80) + return host, port + + +async def get_async_transport_for_url(url: str) -> H3Transport | None: + host, port = _parse_host_port(url) + if not host: + return None + return await pool.get_async_transport(host, port) + + +def get_sync_transport_for_url(url: str) -> SyncH3Transport | None: + host, port = _parse_host_port(url) + if not host: + return None + return pool.get_sync_transport(host, port) diff --git a/src/blaxel/core/common/h3warm.py b/src/blaxel/core/common/h3warm.py deleted file mode 100644 index d81842b8..00000000 --- a/src/blaxel/core/common/h3warm.py +++ /dev/null @@ -1,68 +0,0 @@ -"""HTTP/3 (QUIC) connection warming utility. - -Establishes a QUIC connection to a given hostname to pre-warm -DNS resolution and TLS 1.3/QUIC handshake, so that the first real -request benefits from a pre-warmed connection. -""" - -import asyncio -import logging - -from aioquic.asyncio import connect -from aioquic.quic.configuration import QuicConfiguration - -logger = logging.getLogger(__name__) - - -class H3WarmSession: - """Holds a warmed QUIC/HTTP3 connection.""" - - def __init__(self, ctx: object, client: object): - self._ctx = ctx - self._client = client - - def close(self) -> None: - """Close the QUIC connection by properly exiting the async context manager.""" - try: - if self._ctx is not None: - loop = asyncio.new_event_loop() - try: - loop.run_until_complete(self._ctx.__aexit__(None, None, None)) - finally: - loop.close() - self._ctx = None - self._client = None - except Exception: - pass - - -async def establish_h3(hostname: str, port: int = 443) -> H3WarmSession: - """Establish an HTTP/3 (QUIC) connection to the given hostname. - - Performs DNS resolution + QUIC handshake (including TLS 1.3) to - fully warm the connection path. - - Args: - hostname: The SNI hostname to connect to. - port: The port to connect to (default 443). - - Returns: - An H3WarmSession wrapping the QUIC connection. - """ - configuration = QuicConfiguration( - is_client=True, - alpn_protocols=["h3"], - server_name=hostname, - ) - ctx = connect(hostname, port, configuration=configuration) - client = await ctx.__aenter__() - - return H3WarmSession(ctx, client) - - -async def establish_h3_best_effort(hostname: str, port: int = 443) -> H3WarmSession | None: - """Best-effort HTTP/3 warming. Returns None on any failure.""" - try: - return await asyncio.wait_for(establish_h3(hostname, port), timeout=5.0) - except Exception: - return None diff --git a/src/blaxel/core/sandbox/default/action.py b/src/blaxel/core/sandbox/default/action.py index 5cad3312..69b7d95b 100644 --- a/src/blaxel/core/sandbox/default/action.py +++ b/src/blaxel/core/sandbox/default/action.py @@ -57,14 +57,18 @@ def get_client(self) -> httpx.AsyncClient: """Get persistent HTTP client for this sandbox instance.""" if self._client is None: base_url = self.sandbox_config.force_url or self.url + transport = getattr(self.sandbox_config, "h3_transport", None) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport self._client = httpx.AsyncClient( base_url=base_url, headers=self.sandbox_config.headers if self.sandbox_config.force_url else {**settings.headers, **self.sandbox_config.headers}, - http2=False, limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), timeout=httpx.Timeout(300.0, connect=10.0), + **kwargs, ) return self._client diff --git a/src/blaxel/core/sandbox/default/filesystem.py b/src/blaxel/core/sandbox/default/filesystem.py index 49ecc044..aca4f4d0 100644 --- a/src/blaxel/core/sandbox/default/filesystem.py +++ b/src/blaxel/core/sandbox/default/filesystem.py @@ -7,6 +7,7 @@ import httpx +from ...common.h3transport import get_async_transport_for_url from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse from ..types import ( @@ -394,7 +395,11 @@ async def start_watching(): url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - async with httpx.AsyncClient() as client_instance: + transport = await get_async_transport_for_url(url) + watch_kwargs: dict = {} + if transport is not None: + watch_kwargs["transport"] = transport + async with httpx.AsyncClient(**watch_kwargs) as client_instance: async with client_instance.stream( "GET", url, params=params, headers=headers ) as response: diff --git a/src/blaxel/core/sandbox/default/process.py b/src/blaxel/core/sandbox/default/process.py index 59e0b6b0..7f5e949e 100644 --- a/src/blaxel/core/sandbox/default/process.py +++ b/src/blaxel/core/sandbox/default/process.py @@ -3,6 +3,7 @@ import httpx +from ...common.h3transport import get_async_transport_for_url from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest @@ -154,7 +155,11 @@ async def start_streaming(): headers = {**settings.headers, **self.sandbox_config.headers} try: - async with httpx.AsyncClient() as client_instance: + transport = await get_async_transport_for_url(url) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport + async with httpx.AsyncClient(**kwargs) as client_instance: async with client_instance.stream("GET", url, headers=headers) as response: if response.status_code != 200: raise Exception(f"Failed to stream logs: {await response.aread()}") @@ -296,7 +301,11 @@ async def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) - async with httpx.AsyncClient() as client_instance: + transport = await get_async_transport_for_url(self.url) + stream_kwargs: dict = {} + if transport is not None: + stream_kwargs["transport"] = transport + async with httpx.AsyncClient(**stream_kwargs) as client_instance: async with client_instance.stream( "POST", f"{self.url}/process", diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 00ba4718..dad59843 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -24,7 +24,7 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3warm import H3WarmSession, establish_h3_best_effort +from ...common.h3transport import H3Transport, pool as h3_pool from ...common.settings import settings from ..types import ( SandboxConfiguration, @@ -62,15 +62,9 @@ def __init__(self, delete_func: Callable): def __get__(self, instance, owner): if instance is None: - # Called on the class: SandboxInstance.delete("name") return self._delete_func else: - # Called on an instance: instance.delete() async def instance_delete() -> Sandbox: - # Close H3 session if present - if hasattr(instance, "h3_session") and instance.h3_session is not None: - instance.h3_session.close() - instance.h3_session = None return await self._delete_func(instance.metadata.name) return instance_delete @@ -78,7 +72,6 @@ async def instance_delete() -> Sandbox: class SandboxInstance: delete: "_AsyncDeleteDescriptor" - h3_session: H3WarmSession | None def __init__( self, @@ -109,7 +102,6 @@ def __init__( self.codegen = SandboxCodegen(self.config) self.system = SandboxSystem(self.config) self.drives = SandboxDrive(self.config) - self.h3_session = None @property def metadata(self): @@ -271,37 +263,39 @@ async def create( # Extract region from existing Sandbox spec region = getattr(sandbox.spec, "region", None) or settings.region - # Start H3 connection warming in parallel with sandbox creation - h3_warm_task: asyncio.Task[H3WarmSession | None] | None = None + # Pre-warm H3 transport in parallel with sandbox creation + h3_warm_task: asyncio.Task[H3Transport | None] | None = None if region: edge_domain = f"any.{region}.bl.run" - h3_warm_task = asyncio.create_task(establish_h3_best_effort(edge_domain)) + h3_warm_task = asyncio.create_task(h3_pool.get_async_transport(edge_domain)) response = await create_sandbox( client=client, body=sandbox, ) - # Check if response is an error if isinstance(response, SandboxError): status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None message = response.message if response.message else str(response) - # Clean up H3 warming task if it was started if h3_warm_task is not None: h3_warm_task.cancel() raise SandboxAPIError(message, status_code=status_code, code=code) assert response is not None - instance = cls(response) - # Retrieve H3 session from warming task + # Await H3 transport so the pool is warm for subsequent data-plane calls + h3_transport: H3Transport | None = None if h3_warm_task is not None: try: - instance.h3_session = await h3_warm_task + h3_transport = await h3_warm_task except Exception: - instance.h3_session = None - # TODO remove this part once we have a better way to handle this + h3_transport = None + + config = SandboxConfiguration(sandbox=response) + config.h3_transport = h3_transport + instance = cls(config) + if safe: try: await instance.fs.ls("/") diff --git a/src/blaxel/core/sandbox/sync/action.py b/src/blaxel/core/sandbox/sync/action.py index 528a5aa7..19894520 100644 --- a/src/blaxel/core/sandbox/sync/action.py +++ b/src/blaxel/core/sandbox/sync/action.py @@ -49,14 +49,20 @@ def fallback_url(self) -> str | None: return None def get_client(self) -> httpx.Client: + transport = getattr(self.sandbox_config, "h3_transport", None) + kwargs: dict = {} + if transport is not None: + kwargs["transport"] = transport if self.sandbox_config.force_url: return httpx.Client( base_url=self.sandbox_config.force_url, headers=self.sandbox_config.headers, + **kwargs, ) return httpx.Client( base_url=self.url, headers={**settings.headers, **self.sandbox_config.headers}, + **kwargs, ) def handle_response_error(self, response: httpx.Response): diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index 1b8037d1..6704b1d1 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -7,6 +7,7 @@ import httpx +from ...common.h3transport import get_sync_transport_for_url from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse from ..types import ( @@ -191,7 +192,11 @@ def run(): params["ignore"] = ",".join(options["ignore"]) url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - with httpx.Client() as client_instance: + transport = get_sync_transport_for_url(url) + watch_kw: dict = {} + if transport is not None: + watch_kw["transport"] = transport + with httpx.Client(**watch_kw) as client_instance: with client_instance.stream("GET", url, params=params, headers=headers) as response: if not response.is_success: raise Exception(f"Failed to start watching: {response.status_code}") diff --git a/src/blaxel/core/sandbox/sync/process.py b/src/blaxel/core/sandbox/sync/process.py index 7df15394..9e772138 100644 --- a/src/blaxel/core/sandbox/sync/process.py +++ b/src/blaxel/core/sandbox/sync/process.py @@ -4,6 +4,7 @@ import httpx +from ...common.h3transport import get_sync_transport_for_url from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest @@ -123,7 +124,11 @@ def run(): url = f"{self.url}/process/{identifier}/logs/stream" headers = {**settings.headers, **self.sandbox_config.headers} try: - with httpx.Client() as client_instance: + transport = get_sync_transport_for_url(url) + stream_kw: dict = {} + if transport is not None: + stream_kw["transport"] = transport + with httpx.Client(**stream_kw) as client_instance: with client_instance.stream("GET", url, headers=headers) as response: if response.status_code != 200: raise Exception(f"Failed to stream logs: {response.text}") @@ -242,7 +247,11 @@ def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) - with httpx.Client() as client_instance: + transport = get_sync_transport_for_url(self.url) + exec_kw: dict = {} + if transport is not None: + exec_kw["transport"] = transport + with httpx.Client(**exec_kw) as client_instance: with client_instance.stream( "POST", f"{self.url}/process", diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 9688be20..b05ed87a 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -1,4 +1,3 @@ -import asyncio import logging import threading import uuid @@ -11,20 +10,14 @@ from ...client.api.compute.list_sandboxes import sync as list_sandboxes from ...client.api.compute.update_sandbox import sync as update_sandbox from ...client.client import client -from ...client.models import ( - Metadata, - Sandbox, - SandboxLifecycle, - SandboxRuntime, - SandboxSpec, -) -from ...client.models import ( - SandboxNetwork as SandboxNetworkModel, -) +from ...client.models import Metadata, Sandbox, SandboxLifecycle +from ...client.models import SandboxNetwork as SandboxNetworkModel +from ...client.models import SandboxRuntime, SandboxSpec from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3warm import H3WarmSession, establish_h3_best_effort +from ...common.h3transport import SyncH3Transport +from ...common.h3transport import pool as h3_pool from ...common.settings import settings from ..default.sandbox import SandboxAPIError from ..types import ( @@ -53,15 +46,9 @@ def __init__(self, delete_func: Callable): def __get__(self, instance, owner): if instance is None: - # Called on the class: SyncSandboxInstance.delete("name") return self._delete_func else: - # Called on an instance: instance.delete() def instance_delete() -> Sandbox: - # Close H3 session if present - if hasattr(instance, "h3_session") and instance.h3_session is not None: - instance.h3_session.close() - instance.h3_session = None return self._delete_func(instance.metadata.name) return instance_delete @@ -94,7 +81,6 @@ def __init__( self.codegen = SyncSandboxCodegen(self.config) self.system = SyncSandboxSystem(self.config) self.drives = SyncSandboxDrive(self.config) - self.h3_session: H3WarmSession | None = None @property def metadata(self): @@ -240,19 +226,15 @@ def create( # Extract region from existing Sandbox spec region = getattr(sandbox.spec, "region", None) or settings.region - # Start H3 connection warming in a background thread - h3_result: dict[str, H3WarmSession | None] = {"session": None} + # Pre-warm H3 transport in a background thread + h3_result: dict[str, SyncH3Transport | None] = {"transport": None} h3_thread: threading.Thread | None = None if region: edge_domain = f"any.{region}.bl.run" def _warm_h3() -> None: try: - loop = asyncio.new_event_loop() - h3_result["session"] = loop.run_until_complete( - establish_h3_best_effort(edge_domain) - ) - loop.close() + h3_result["transport"] = h3_pool.get_sync_transport(edge_domain) except Exception: pass @@ -264,19 +246,22 @@ def _warm_h3() -> None: body=sandbox, ) - # Check if response is an error if isinstance(response, SandboxError): status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None message = response.message if response.message else str(response) raise SandboxAPIError(message, status_code=status_code, code=code) - instance = cls(response) - - # Retrieve H3 session from warming thread (non-blocking) + # Wait for H3 warmup to finish + h3_transport: SyncH3Transport | None = None if h3_thread is not None: h3_thread.join(timeout=5) - instance.h3_session = h3_result["session"] + h3_transport = h3_result["transport"] + + config = SandboxConfiguration(sandbox=response) + config.h3_transport = h3_transport + instance = cls(config) + if safe: try: instance.fs.ls("/") diff --git a/src/blaxel/core/sandbox/types.py b/src/blaxel/core/sandbox/types.py index c5ac0d9d..ea4bb9e9 100644 --- a/src/blaxel/core/sandbox/types.py +++ b/src/blaxel/core/sandbox/types.py @@ -95,6 +95,7 @@ def __init__( self.force_url = force_url self.headers = headers or {} self.params = params or {} + self.h3_transport: Any = None @property def metadata(self): diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/benchmarks/bench_cold_call.py b/tests/benchmarks/bench_cold_call.py new file mode 100644 index 00000000..a8b5cb6d --- /dev/null +++ b/tests/benchmarks/bench_cold_call.py @@ -0,0 +1,201 @@ +"""Cold-call benchmark: create → call → delete. + +Measures end-to-end latency of sandbox lifecycle operations with +per-phase breakdown (create / call / delete). + +Usage: + python -m tests.benchmarks.bench_cold_call + python -m tests.benchmarks.bench_cold_call --iterations 5 --warmup 0 +""" + +import argparse +import asyncio +import statistics +import time +from dataclasses import dataclass, field + +from blaxel.core.sandbox import SandboxInstance +from tests.helpers import default_image, default_labels, default_region, unique_name + +ITERATIONS = 10 +WARMUP_ITERATIONS = 1 + + +@dataclass +class Timing: + total: float = 0.0 + create: float = 0.0 + call: float = 0.0 + delete: float = 0.0 + + +async def bench_cold_ls(iteration: int) -> Timing: + """create -> fs.ls('/') -> delete""" + name = unique_name("bench-cold-ls") + t = Timing() + + t0 = time.perf_counter() + sandbox = await SandboxInstance.create({ + "name": name, + "image": default_image, + "labels": default_labels, + "memory": 2048, + "region": default_region, + }) + t1 = time.perf_counter() + t.create = t1 - t0 + + await sandbox.fs.ls("/") + t2 = time.perf_counter() + t.call = t2 - t1 + + try: + await SandboxInstance.delete(name) + except Exception: + pass + t3 = time.perf_counter() + t.delete = t3 - t2 + + t.total = t3 - t0 + return t + + +async def bench_cold_exec(iteration: int) -> Timing: + """create -> process.exec('echo ok') -> delete""" + name = unique_name("bench-cold-exec") + t = Timing() + + t0 = time.perf_counter() + sandbox = await SandboxInstance.create({ + "name": name, + "image": default_image, + "labels": default_labels, + "memory": 2048, + "region": default_region, + }) + t1 = time.perf_counter() + t.create = t1 - t0 + + await sandbox.process.exec({ + "command": "echo ok", + "wait_for_completion": True, + }) + t2 = time.perf_counter() + t.call = t2 - t1 + + try: + await SandboxInstance.delete(name) + except Exception: + pass + t3 = time.perf_counter() + t.delete = t3 - t2 + + t.total = t3 - t0 + return t + + +def fmt(seconds: float) -> str: + return f"{seconds * 1000:.0f}ms" + + +def format_stats(values: list[float], label: str) -> str: + if not values: + return f" {label}: no data" + mean = statistics.mean(values) + med = statistics.median(values) + mn = min(values) + mx = max(values) + std = statistics.stdev(values) if len(values) > 1 else 0.0 + return ( + f" {label:16s} " + f"mean={fmt(mean):>7s} med={fmt(med):>7s} " + f"min={fmt(mn):>7s} max={fmt(mx):>7s} std={fmt(std):>7s}" + ) + + +async def run_benchmark( + name: str, + fn, + iterations: int, + warmup: int, +) -> list[Timing]: + print(f"\n{'='*80}") + print(f" {name}") + print(f"{'='*80}") + + if warmup > 0: + print(f" warming up ({warmup} iteration{'s' if warmup > 1 else ''})...") + for i in range(warmup): + try: + await fn(i) + except Exception as e: + print(f" warmup {i+1} failed: {e}") + + timings: list[Timing] = [] + for i in range(iterations): + try: + t = await fn(i) + timings.append(t) + print( + f" [{i+1:>2}/{iterations}] " + f"total={fmt(t.total):>7s} " + f"create={fmt(t.create):>7s} " + f"call={fmt(t.call):>7s} " + f"delete={fmt(t.delete):>7s}" + ) + except Exception as e: + print(f" [{i+1:>2}/{iterations}] FAILED: {e}") + + if timings: + print() + print(format_stats([t.total for t in timings], "total")) + print(format_stats([t.create for t in timings], "create")) + print(format_stats([t.call for t in timings], "call")) + print(format_stats([t.delete for t in timings], "delete")) + return timings + + +async def main(iterations: int, warmup: int) -> None: + print(f"\nCold-call benchmark (create -> call -> delete)") + print(f" iterations={iterations}, warmup={warmup}") + print(f" image={default_image}, region={default_region}") + + ls_timings = await run_benchmark( + "create -> fs.ls('/') -> delete", + bench_cold_ls, + iterations, + warmup, + ) + + exec_timings = await run_benchmark( + "create -> process.exec('echo ok') -> delete", + bench_cold_exec, + iterations, + warmup, + ) + + print(f"\n{'='*80}") + print(" SUMMARY") + print(f"{'='*80}") + if ls_timings: + print(" fs.ls:") + print(format_stats([t.total for t in ls_timings], "total")) + print(format_stats([t.create for t in ls_timings], "create")) + print(format_stats([t.call for t in ls_timings], "call")) + print(format_stats([t.delete for t in ls_timings], "delete")) + if exec_timings: + print(" process.exec:") + print(format_stats([t.total for t in exec_timings], "total")) + print(format_stats([t.create for t in exec_timings], "create")) + print(format_stats([t.call for t in exec_timings], "call")) + print(format_stats([t.delete for t in exec_timings], "delete")) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Cold-call sandbox benchmark") + parser.add_argument("--iterations", type=int, default=ITERATIONS) + parser.add_argument("--warmup", type=int, default=WARMUP_ITERATIONS) + args = parser.parse_args() + + asyncio.run(main(args.iterations, args.warmup)) From c46c500ac06cb375eeb7a3e702c4d07300e9f22a Mon Sep 17 00:00:00 2001 From: Joffref Date: Thu, 5 Mar 2026 01:13:34 -0800 Subject: [PATCH 05/16] Test benchmark: fix format_stats function --- tests/benchmarks/bench_cold_call.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/benchmarks/bench_cold_call.py b/tests/benchmarks/bench_cold_call.py index a8b5cb6d..9aa2e943 100644 --- a/tests/benchmarks/bench_cold_call.py +++ b/tests/benchmarks/bench_cold_call.py @@ -12,7 +12,7 @@ import asyncio import statistics import time -from dataclasses import dataclass, field +from dataclasses import dataclass from blaxel.core.sandbox import SandboxInstance from tests.helpers import default_image, default_labels, default_region, unique_name @@ -98,18 +98,31 @@ def fmt(seconds: float) -> str: return f"{seconds * 1000:.0f}ms" +def percentile(values: list[float], p: float) -> float: + sorted_v = sorted(values) + k = (len(sorted_v) - 1) * (p / 100) + f = int(k) + c = f + 1 + if c >= len(sorted_v): + return sorted_v[f] + return sorted_v[f] + (k - f) * (sorted_v[c] - sorted_v[f]) + + + def format_stats(values: list[float], label: str) -> str: if not values: return f" {label}: no data" mean = statistics.mean(values) - med = statistics.median(values) + p50 = percentile(values, 50) + p90 = percentile(values, 90) + p99 = percentile(values, 99) mn = min(values) mx = max(values) - std = statistics.stdev(values) if len(values) > 1 else 0.0 return ( f" {label:16s} " - f"mean={fmt(mean):>7s} med={fmt(med):>7s} " - f"min={fmt(mn):>7s} max={fmt(mx):>7s} std={fmt(std):>7s}" + f"mean={fmt(mean):>7s} p50={fmt(p50):>7s} " + f"p90={fmt(p90):>7s} p99={fmt(p99):>7s} " + f"min={fmt(mn):>7s} max={fmt(mx):>7s}" ) From 1bc9d9d81c91facb31e788c1aec574523e615eed Mon Sep 17 00:00:00 2001 From: Joffref Date: Thu, 5 Mar 2026 12:00:23 -0800 Subject: [PATCH 06/16] feat: implement fallback transport for H3 connections - Introduced AsyncH3FallbackTransport and SyncH3FallbackTransport classes to handle automatic downgrading from H3 to HTTP/2 on connection failures. - Updated H3Pool to remember failed hosts for a specified TTL to optimize connection attempts. - Enhanced sandbox actions to utilize HTTP/2 when H3 transport is unavailable. This change improves resilience and performance in scenarios where H3 connections may fail. --- src/blaxel/core/common/h3transport.py | 207 +++++++++++++++++---- src/blaxel/core/sandbox/default/action.py | 3 + src/blaxel/core/sandbox/default/sandbox.py | 6 +- src/blaxel/core/sandbox/sync/action.py | 3 + src/blaxel/core/sandbox/sync/sandbox.py | 5 +- 5 files changed, 185 insertions(+), 39 deletions(-) diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py index 4ad41f63..af80411e 100644 --- a/src/blaxel/core/common/h3transport.py +++ b/src/blaxel/core/common/h3transport.py @@ -3,14 +3,17 @@ Provides an async H3Transport (httpx.AsyncBaseTransport) and a sync SyncH3Transport (httpx.BaseTransport) backed by a shared connection pool keyed by (host, port). + +When UDP is blocked or an H3 connection fails mid-session, callers +automatically fall back to HTTP/2 (or HTTP/1.1 if h2 is unavailable). """ from __future__ import annotations import asyncio import logging -import socket import threading +import time from collections import deque from typing import AsyncIterator, Deque from urllib.parse import urlparse @@ -27,12 +30,21 @@ logger = logging.getLogger(__name__) _H3_CONNECT_TIMEOUT = 5.0 +_H3_FAIL_TTL = 300.0 # remember failures for 5 min before retrying + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False # --------------------------------------------------------------------------- # Async H3 transport (one QUIC connection per instance) # --------------------------------------------------------------------------- + class _H3ByteStream(httpx.AsyncByteStream): def __init__(self, aiterator: AsyncIterator[bytes]): self._aiterator = aiterator @@ -42,7 +54,7 @@ async def __aiter__(self) -> AsyncIterator[bytes]: yield part -class H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): +class _H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): """httpx async transport over a single QUIC/H3 connection.""" def __init__(self, *args, **kwargs) -> None: @@ -140,13 +152,14 @@ async def _wait_for_http_event(self, stream_id: int) -> H3Event: # --------------------------------------------------------------------------- -# Sync H3 transport (bridges async transport via a background event loop) +# Sync H3 bridge (delegates to the async _H3Transport via a bg loop) # --------------------------------------------------------------------------- -class SyncH3Transport(httpx.BaseTransport): - """Sync httpx transport that delegates to an async H3Transport.""" - def __init__(self, async_transport: H3Transport, loop: asyncio.AbstractEventLoop): +class _SyncH3Transport(httpx.BaseTransport): + """Sync httpx transport that delegates to an async _H3Transport.""" + + def __init__(self, async_transport: _H3Transport, loop: asyncio.AbstractEventLoop): self._async_transport = async_transport self._loop = loop @@ -161,20 +174,98 @@ def close(self) -> None: pass +# --------------------------------------------------------------------------- +# Fallback transports: try H3, auto-downgrade to HTTP/2 on failure +# --------------------------------------------------------------------------- + + +class AsyncH3FallbackTransport(httpx.AsyncBaseTransport): + """Async transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, h3: _H3Transport, host: str, port: int): + self._h3 = h3 + self._host = host + self._port = port + self._h2_fallback: httpx.AsyncHTTPTransport | None = None + self._use_fallback = False + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + if self._use_fallback: + return await self._ensure_h2().handle_async_request(request) + try: + return await self._h3.handle_async_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return await self._ensure_h2().handle_async_request(request) + + def _ensure_h2(self) -> httpx.AsyncHTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.AsyncHTTPTransport(http2=HTTP2_AVAILABLE) + return self._h2_fallback + + async def aclose(self) -> None: + if self._h2_fallback is not None: + await self._h2_fallback.aclose() + + +class SyncH3FallbackTransport(httpx.BaseTransport): + """Sync transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, sync_h3: _SyncH3Transport, host: str, port: int): + self._sync_h3 = sync_h3 + self._host = host + self._port = port + self._h2_fallback: httpx.HTTPTransport | None = None + self._use_fallback = False + + def handle_request(self, request: httpx.Request) -> httpx.Response: + if self._use_fallback: + return self._ensure_h2().handle_request(request) + try: + return self._sync_h3.handle_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return self._ensure_h2().handle_request(request) + + def _ensure_h2(self) -> httpx.HTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.HTTPTransport(http2=HTTP2_AVAILABLE) + return self._h2_fallback + + def close(self) -> None: + if self._h2_fallback is not None: + self._h2_fallback.close() + + # --------------------------------------------------------------------------- # Connection pool # --------------------------------------------------------------------------- + class H3Pool: """Global pool of H3 transports keyed by (host, port). Manages a background event loop thread for sync callers and QUIC - connection lifecycle. + connection lifecycle. Failed hosts are remembered for ``_H3_FAIL_TTL`` + seconds so that repeated connection attempts don't add latency. """ def __init__(self) -> None: - self._async_transports: dict[tuple[str, int], H3Transport] = {} + self._async_transports: dict[tuple[str, int], _H3Transport] = {} self._connect_contexts: dict[tuple[str, int], object] = {} + self._failed_hosts: dict[tuple[str, int], float] = {} self._lock = threading.Lock() self._async_lock: asyncio.Lock | None = None self._bg_loop: asyncio.AbstractEventLoop | None = None @@ -185,6 +276,26 @@ def _get_async_lock(self) -> asyncio.Lock: self._async_lock = asyncio.Lock() return self._async_lock + # -- negative cache ------------------------------------------------------ + + def _is_failed(self, host: str, port: int) -> bool: + key = (host, port) + with self._lock: + ts = self._failed_hosts.get(key) + if ts is None: + return False + if time.monotonic() - ts > _H3_FAIL_TTL: + del self._failed_hosts[key] + return False + return True + + def _mark_failed(self, host: str, port: int) -> None: + key = (host, port) + with self._lock: + self._failed_hosts[key] = time.monotonic() + self._connect_contexts.pop(key, None) + self._async_transports.pop(key, None) + # -- background event loop for sync callers ------------------------------ def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop: @@ -204,16 +315,10 @@ def _run_loop(): ready.wait(timeout=5) return self._bg_loop # type: ignore[return-value] - # -- async API ----------------------------------------------------------- - - async def get_async_transport( - self, host: str, port: int = 443 - ) -> H3Transport | None: - """Get or create an H3Transport for the given host. + # -- internal: raw H3 connection ----------------------------------------- - Returns None if the QUIC handshake fails (caller should fall back - to TCP). - """ + async def _get_or_connect(self, host: str, port: int) -> _H3Transport | None: + """Get a cached _H3Transport or establish a new QUIC connection.""" key = (host, port) async with self._get_async_lock(): transport = self._async_transports.get(key) @@ -227,10 +332,11 @@ async def get_async_transport( self._async_transports[key] = transport return transport except Exception: - logger.debug("H3 connection to %s:%d failed, falling back to TCP", host, port) + logger.debug("H3 connection to %s:%d failed", host, port) + self._mark_failed(host, port) return None - async def _connect(self, host: str, port: int) -> H3Transport: + async def _connect(self, host: str, port: int) -> _H3Transport: configuration = QuicConfiguration( is_client=True, alpn_protocols=H3_ALPN, @@ -240,33 +346,55 @@ async def _connect(self, host: str, port: int) -> H3Transport: host, port, configuration=configuration, - create_protocol=H3Transport, + create_protocol=_H3Transport, ) transport = await ctx.__aenter__() with self._lock: self._connect_contexts[(host, port)] = ctx return transport # type: ignore[return-value] - # -- sync API (dispatches to bg loop) ------------------------------------ + # -- public async API ---------------------------------------------------- + + async def get_async_transport( + self, host: str, port: int = 443 + ) -> AsyncH3FallbackTransport | None: + """Get an H3 transport with automatic HTTP/2 fallback for the host. + + Returns None if the QUIC handshake fails (caller should fall back + to HTTP/2 or use ``get_async_transport_for_url`` which does this + automatically). + """ + if self._is_failed(host, port): + return None + raw = await self._get_or_connect(host, port) + if raw is None: + return None + return AsyncH3FallbackTransport(raw, host, port) + + # -- public sync API (dispatches to bg loop) ----------------------------- def get_sync_transport( self, host: str, port: int = 443 - ) -> SyncH3Transport | None: - """Get or create a SyncH3Transport for the given host. + ) -> SyncH3FallbackTransport | None: + """Get a sync H3 transport with automatic HTTP/2 fallback. - Returns None on failure (caller should fall back to TCP). + Returns None on failure (caller should fall back to HTTP/2 or use + ``get_sync_transport_for_url``). """ + if self._is_failed(host, port): + return None loop = self._ensure_bg_loop() future = asyncio.run_coroutine_threadsafe( - self.get_async_transport(host, port), loop + self._get_or_connect(host, port), loop ) try: - async_transport = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) + raw = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) except Exception: + self._mark_failed(host, port) return None - if async_transport is None: + if raw is None: return None - return SyncH3Transport(async_transport, loop) + return SyncH3FallbackTransport(_SyncH3Transport(raw, loop), host, port) # -- shutdown ------------------------------------------------------------ @@ -298,9 +426,10 @@ def close_all_sync(self) -> None: # --------------------------------------------------------------------------- -# Helpers +# Helpers — return the best available transport (H3 → HTTP/2 → None) # --------------------------------------------------------------------------- + def _parse_host_port(url: str) -> tuple[str, int]: parsed = urlparse(url) host = parsed.hostname or "" @@ -308,15 +437,27 @@ def _parse_host_port(url: str) -> tuple[str, int]: return host, port -async def get_async_transport_for_url(url: str) -> H3Transport | None: +async def get_async_transport_for_url(url: str) -> httpx.AsyncBaseTransport | None: + """Best-effort transport for *url*: H3 with fallback, else HTTP/2, else None.""" host, port = _parse_host_port(url) if not host: return None - return await pool.get_async_transport(host, port) + transport = await pool.get_async_transport(host, port) + if transport is not None: + return transport + if HTTP2_AVAILABLE: + return httpx.AsyncHTTPTransport(http2=True) + return None -def get_sync_transport_for_url(url: str) -> SyncH3Transport | None: +def get_sync_transport_for_url(url: str) -> httpx.BaseTransport | None: + """Best-effort transport for *url*: H3 with fallback, else HTTP/2, else None.""" host, port = _parse_host_port(url) if not host: return None - return pool.get_sync_transport(host, port) + transport = pool.get_sync_transport(host, port) + if transport is not None: + return transport + if HTTP2_AVAILABLE: + return httpx.HTTPTransport(http2=True) + return None diff --git a/src/blaxel/core/sandbox/default/action.py b/src/blaxel/core/sandbox/default/action.py index 69b7d95b..be717108 100644 --- a/src/blaxel/core/sandbox/default/action.py +++ b/src/blaxel/core/sandbox/default/action.py @@ -1,5 +1,6 @@ import httpx +from ...common.h3transport import HTTP2_AVAILABLE from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings from ..types import ResponseError, SandboxConfiguration @@ -61,6 +62,8 @@ def get_client(self) -> httpx.AsyncClient: kwargs: dict = {} if transport is not None: kwargs["transport"] = transport + elif HTTP2_AVAILABLE: + kwargs["http2"] = True self._client = httpx.AsyncClient( base_url=base_url, headers=self.sandbox_config.headers diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index dad59843..597e30d7 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -24,7 +24,7 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3transport import H3Transport, pool as h3_pool +from ...common.h3transport import pool as h3_pool from ...common.settings import settings from ..types import ( SandboxConfiguration, @@ -264,7 +264,7 @@ async def create( region = getattr(sandbox.spec, "region", None) or settings.region # Pre-warm H3 transport in parallel with sandbox creation - h3_warm_task: asyncio.Task[H3Transport | None] | None = None + h3_warm_task: asyncio.Task | None = None if region: edge_domain = f"any.{region}.bl.run" h3_warm_task = asyncio.create_task(h3_pool.get_async_transport(edge_domain)) @@ -285,7 +285,7 @@ async def create( assert response is not None # Await H3 transport so the pool is warm for subsequent data-plane calls - h3_transport: H3Transport | None = None + h3_transport = None if h3_warm_task is not None: try: h3_transport = await h3_warm_task diff --git a/src/blaxel/core/sandbox/sync/action.py b/src/blaxel/core/sandbox/sync/action.py index 19894520..860154ef 100644 --- a/src/blaxel/core/sandbox/sync/action.py +++ b/src/blaxel/core/sandbox/sync/action.py @@ -1,5 +1,6 @@ import httpx +from ...common.h3transport import HTTP2_AVAILABLE from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings from ..types import ResponseError, SandboxConfiguration @@ -53,6 +54,8 @@ def get_client(self) -> httpx.Client: kwargs: dict = {} if transport is not None: kwargs["transport"] = transport + elif HTTP2_AVAILABLE: + kwargs["http2"] = True if self.sandbox_config.force_url: return httpx.Client( base_url=self.sandbox_config.force_url, diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index b05ed87a..46e913df 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -16,7 +16,6 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3transport import SyncH3Transport from ...common.h3transport import pool as h3_pool from ...common.settings import settings from ..default.sandbox import SandboxAPIError @@ -227,7 +226,7 @@ def create( region = getattr(sandbox.spec, "region", None) or settings.region # Pre-warm H3 transport in a background thread - h3_result: dict[str, SyncH3Transport | None] = {"transport": None} + h3_result: dict = {"transport": None} h3_thread: threading.Thread | None = None if region: edge_domain = f"any.{region}.bl.run" @@ -253,7 +252,7 @@ def _warm_h3() -> None: raise SandboxAPIError(message, status_code=status_code, code=code) # Wait for H3 warmup to finish - h3_transport: SyncH3Transport | None = None + h3_transport = None if h3_thread is not None: h3_thread.join(timeout=5) h3_transport = h3_result["transport"] From a42a765b61fc8cae849592d15f02b97aa79f5cb4 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Thu, 5 Mar 2026 20:28:40 +0000 Subject: [PATCH 07/16] fix: lazy h3transport imports + fresh headers via event hooks - Make all h3transport/aioquic imports lazy (inside functions) so that aioquic is never loaded unless sandbox operations actually need it. This prevents the H3Pool background event loop from interfering with pytest-asyncio during integration tests. - Use httpx event hooks to inject fresh settings.headers on every request instead of baking them at client creation time. This ensures token refreshes are picked up automatically without recreating the client. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/autoload.py | 6 ++-- src/blaxel/core/sandbox/default/action.py | 30 +++++++++++++--- src/blaxel/core/sandbox/default/filesystem.py | 3 +- src/blaxel/core/sandbox/default/process.py | 5 ++- src/blaxel/core/sandbox/default/sandbox.py | 3 +- src/blaxel/core/sandbox/sync/action.py | 36 ++++++++++++++----- src/blaxel/core/sandbox/sync/filesystem.py | 3 +- src/blaxel/core/sandbox/sync/process.py | 5 ++- src/blaxel/core/sandbox/sync/sandbox.py | 3 +- 9 files changed, 72 insertions(+), 22 deletions(-) diff --git a/src/blaxel/core/common/autoload.py b/src/blaxel/core/common/autoload.py index f0d43e14..46a5c9a4 100644 --- a/src/blaxel/core/common/autoload.py +++ b/src/blaxel/core/common/autoload.py @@ -8,7 +8,6 @@ response_interceptors_sync, ) from ..sandbox.client import client as client_sandbox -from .h3transport import pool as h3_pool from .sentry import init_sentry from .settings import settings @@ -48,7 +47,8 @@ def autoload() -> None: except Exception: pass - # Pre-warm H3 connection to API endpoint in background + # Pre-warm H3 connection to API endpoint in background. + # Import is lazy so that aioquic is not loaded unless really needed. try: api_hostname = urlparse(settings.base_url).hostname if api_hostname: @@ -62,6 +62,8 @@ def _warm_api_h3(hostname: str) -> None: def _do_warm() -> None: try: + from .h3transport import pool as h3_pool + h3_pool.get_sync_transport(hostname, 443) except Exception: pass diff --git a/src/blaxel/core/sandbox/default/action.py b/src/blaxel/core/sandbox/default/action.py index be717108..0a0f0a2c 100644 --- a/src/blaxel/core/sandbox/default/action.py +++ b/src/blaxel/core/sandbox/default/action.py @@ -1,8 +1,14 @@ import httpx -from ...common.h3transport import HTTP2_AVAILABLE from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False from ..types import ResponseError, SandboxConfiguration @@ -55,7 +61,11 @@ def fallback_url(self) -> str | None: return None def get_client(self) -> httpx.AsyncClient: - """Get persistent HTTP client for this sandbox instance.""" + """Get persistent HTTP client for this sandbox instance. + + Headers are injected via an event hook so that token refreshes are + picked up automatically without recreating the client. + """ if self._client is None: base_url = self.sandbox_config.force_url or self.url transport = getattr(self.sandbox_config, "h3_transport", None) @@ -64,11 +74,21 @@ def get_client(self) -> httpx.AsyncClient: kwargs["transport"] = transport elif HTTP2_AVAILABLE: kwargs["http2"] = True + + sandbox_config = self.sandbox_config + + async def _inject_headers(request: httpx.Request) -> None: + """Inject fresh headers before every request.""" + if sandbox_config.force_url: + fresh = sandbox_config.headers + else: + fresh = {**settings.headers, **sandbox_config.headers} + for key, value in fresh.items(): + request.headers[key] = value + self._client = httpx.AsyncClient( base_url=base_url, - headers=self.sandbox_config.headers - if self.sandbox_config.force_url - else {**settings.headers, **self.sandbox_config.headers}, + event_hooks={"request": [_inject_headers]}, limits=httpx.Limits(max_connections=100, max_keepalive_connections=20), timeout=httpx.Timeout(300.0, connect=10.0), **kwargs, diff --git a/src/blaxel/core/sandbox/default/filesystem.py b/src/blaxel/core/sandbox/default/filesystem.py index aca4f4d0..12da6f22 100644 --- a/src/blaxel/core/sandbox/default/filesystem.py +++ b/src/blaxel/core/sandbox/default/filesystem.py @@ -7,7 +7,6 @@ import httpx -from ...common.h3transport import get_async_transport_for_url from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse from ..types import ( @@ -395,6 +394,8 @@ async def start_watching(): url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} + from ...common.h3transport import get_async_transport_for_url + transport = await get_async_transport_for_url(url) watch_kwargs: dict = {} if transport is not None: diff --git a/src/blaxel/core/sandbox/default/process.py b/src/blaxel/core/sandbox/default/process.py index 7f5e949e..5ee38a4e 100644 --- a/src/blaxel/core/sandbox/default/process.py +++ b/src/blaxel/core/sandbox/default/process.py @@ -3,7 +3,6 @@ import httpx -from ...common.h3transport import get_async_transport_for_url from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest @@ -155,6 +154,8 @@ async def start_streaming(): headers = {**settings.headers, **self.sandbox_config.headers} try: + from ...common.h3transport import get_async_transport_for_url + transport = await get_async_transport_for_url(url) kwargs: dict = {} if transport is not None: @@ -301,6 +302,8 @@ async def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) + from ...common.h3transport import get_async_transport_for_url + transport = await get_async_transport_for_url(self.url) stream_kwargs: dict = {} if transport is not None: diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 597e30d7..16ae4bd6 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -24,7 +24,6 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3transport import pool as h3_pool from ...common.settings import settings from ..types import ( SandboxConfiguration, @@ -266,6 +265,8 @@ async def create( # Pre-warm H3 transport in parallel with sandbox creation h3_warm_task: asyncio.Task | None = None if region: + from ...common.h3transport import pool as h3_pool + edge_domain = f"any.{region}.bl.run" h3_warm_task = asyncio.create_task(h3_pool.get_async_transport(edge_domain)) diff --git a/src/blaxel/core/sandbox/sync/action.py b/src/blaxel/core/sandbox/sync/action.py index 860154ef..63baef10 100644 --- a/src/blaxel/core/sandbox/sync/action.py +++ b/src/blaxel/core/sandbox/sync/action.py @@ -1,8 +1,14 @@ import httpx -from ...common.h3transport import HTTP2_AVAILABLE from ...common.internal import get_forced_url, get_global_unique_hash from ...common.settings import settings + +try: + import h2 as _h2 # noqa: F401 + + HTTP2_AVAILABLE = True +except ImportError: + HTTP2_AVAILABLE = False from ..types import ResponseError, SandboxConfiguration @@ -50,21 +56,33 @@ def fallback_url(self) -> str | None: return None def get_client(self) -> httpx.Client: + """Get an HTTP client for this sandbox instance. + + Headers are injected via an event hook so that token refreshes are + picked up automatically without recreating the client. + """ transport = getattr(self.sandbox_config, "h3_transport", None) kwargs: dict = {} if transport is not None: kwargs["transport"] = transport elif HTTP2_AVAILABLE: kwargs["http2"] = True - if self.sandbox_config.force_url: - return httpx.Client( - base_url=self.sandbox_config.force_url, - headers=self.sandbox_config.headers, - **kwargs, - ) + + sandbox_config = self.sandbox_config + + def _inject_headers(request: httpx.Request) -> None: + """Inject fresh headers before every request.""" + if sandbox_config.force_url: + fresh = sandbox_config.headers + else: + fresh = {**settings.headers, **sandbox_config.headers} + for key, value in fresh.items(): + request.headers[key] = value + + base_url = self.sandbox_config.force_url or self.url return httpx.Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, + base_url=base_url, + event_hooks={"request": [_inject_headers]}, **kwargs, ) diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index 6704b1d1..662b1be1 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -7,7 +7,6 @@ import httpx -from ...common.h3transport import get_sync_transport_for_url from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse from ..types import ( @@ -192,6 +191,8 @@ def run(): params["ignore"] = ",".join(options["ignore"]) url = f"{self.url}/watch/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} + from ...common.h3transport import get_sync_transport_for_url + transport = get_sync_transport_for_url(url) watch_kw: dict = {} if transport is not None: diff --git a/src/blaxel/core/sandbox/sync/process.py b/src/blaxel/core/sandbox/sync/process.py index 9e772138..d6e67bae 100644 --- a/src/blaxel/core/sandbox/sync/process.py +++ b/src/blaxel/core/sandbox/sync/process.py @@ -4,7 +4,6 @@ import httpx -from ...common.h3transport import get_sync_transport_for_url from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest @@ -124,6 +123,8 @@ def run(): url = f"{self.url}/process/{identifier}/logs/stream" headers = {**settings.headers, **self.sandbox_config.headers} try: + from ...common.h3transport import get_sync_transport_for_url + transport = get_sync_transport_for_url(url) stream_kw: dict = {} if transport is not None: @@ -247,6 +248,8 @@ def _exec_with_streaming( else {**settings.headers, **self.sandbox_config.headers} ) + from ...common.h3transport import get_sync_transport_for_url + transport = get_sync_transport_for_url(self.url) exec_kw: dict = {} if transport is not None: diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 46e913df..16d117a4 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -16,7 +16,6 @@ from ...client.models.error import Error from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET -from ...common.h3transport import pool as h3_pool from ...common.settings import settings from ..default.sandbox import SandboxAPIError from ..types import ( @@ -233,6 +232,8 @@ def create( def _warm_h3() -> None: try: + from ...common.h3transport import pool as h3_pool + h3_result["transport"] = h3_pool.get_sync_transport(edge_domain) except Exception: pass From f87ae369964d947dcc5bede5a33d7707ba44b9d5 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Thu, 5 Mar 2026 20:55:52 +0000 Subject: [PATCH 08/16] fix: remove API endpoint H3 warming from autoload() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The _warm_api_h3() background thread creates a persistent event loop via H3Pool._ensure_bg_loop() that interferes with pytest-asyncio in integration tests. The API endpoint (api.blaxel.ai) is never accessed through h3transport anyway — only sandbox data-plane endpoints use H3. Sandbox edge-domain warming already happens in SandboxInstance.create(). Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/autoload.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/blaxel/core/common/autoload.py b/src/blaxel/core/common/autoload.py index 46a5c9a4..1119bd86 100644 --- a/src/blaxel/core/common/autoload.py +++ b/src/blaxel/core/common/autoload.py @@ -1,6 +1,4 @@ import logging -import threading -from urllib.parse import urlparse from ..client import client from ..client.response_interceptor import ( @@ -46,27 +44,3 @@ def autoload() -> None: telemetry() except Exception: pass - - # Pre-warm H3 connection to API endpoint in background. - # Import is lazy so that aioquic is not loaded unless really needed. - try: - api_hostname = urlparse(settings.base_url).hostname - if api_hostname: - _warm_api_h3(api_hostname) - except Exception: - pass - - -def _warm_api_h3(hostname: str) -> None: - """Pre-warm the H3 pool for the API endpoint in a background thread.""" - - def _do_warm() -> None: - try: - from .h3transport import pool as h3_pool - - h3_pool.get_sync_transport(hostname, 443) - except Exception: - pass - - thread = threading.Thread(target=_do_warm, daemon=True) - thread.start() From 6d195abacb88e6620e41759598eae7a03f2fcdf6 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Thu, 5 Mar 2026 21:24:24 +0000 Subject: [PATCH 09/16] fix: remove H3 warming from SandboxInstance.create() to fix integration tests The H3 warming task (asyncio.create_task for QUIC connection) was interfering with the MCP client's streamablehttp_client during integration tests, causing 'Session terminated' errors. The warming imported aioquic and created QUIC protocol handlers on the event loop that disrupted anyio task groups used by the MCP library. The h3transport module and fallback transports remain available for future use. Sandbox data-plane calls still benefit from HTTP/2 via the event-hook headers pattern. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/sandbox/default/sandbox.py | 24 +------------------ src/blaxel/core/sandbox/sync/sandbox.py | 28 +--------------------- 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 16ae4bd6..311972ee 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -1,4 +1,3 @@ -import asyncio import logging import uuid import warnings @@ -262,14 +261,6 @@ async def create( # Extract region from existing Sandbox spec region = getattr(sandbox.spec, "region", None) or settings.region - # Pre-warm H3 transport in parallel with sandbox creation - h3_warm_task: asyncio.Task | None = None - if region: - from ...common.h3transport import pool as h3_pool - - edge_domain = f"any.{region}.bl.run" - h3_warm_task = asyncio.create_task(h3_pool.get_async_transport(edge_domain)) - response = await create_sandbox( client=client, body=sandbox, @@ -279,23 +270,10 @@ async def create( status_code = response.status_code if response.status_code is not UNSET else None code = response.code if response.code else None message = response.message if response.message else str(response) - if h3_warm_task is not None: - h3_warm_task.cancel() raise SandboxAPIError(message, status_code=status_code, code=code) assert response is not None - - # Await H3 transport so the pool is warm for subsequent data-plane calls - h3_transport = None - if h3_warm_task is not None: - try: - h3_transport = await h3_warm_task - except Exception: - h3_transport = None - - config = SandboxConfiguration(sandbox=response) - config.h3_transport = h3_transport - instance = cls(config) + instance = cls(response) if safe: try: diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 16d117a4..f30a1a06 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -1,5 +1,4 @@ import logging -import threading import uuid import warnings from typing import Any, Callable, Dict, List, Union @@ -224,23 +223,6 @@ def create( # Extract region from existing Sandbox spec region = getattr(sandbox.spec, "region", None) or settings.region - # Pre-warm H3 transport in a background thread - h3_result: dict = {"transport": None} - h3_thread: threading.Thread | None = None - if region: - edge_domain = f"any.{region}.bl.run" - - def _warm_h3() -> None: - try: - from ...common.h3transport import pool as h3_pool - - h3_result["transport"] = h3_pool.get_sync_transport(edge_domain) - except Exception: - pass - - h3_thread = threading.Thread(target=_warm_h3, daemon=True) - h3_thread.start() - response = create_sandbox( client=client, body=sandbox, @@ -252,15 +234,7 @@ def _warm_h3() -> None: message = response.message if response.message else str(response) raise SandboxAPIError(message, status_code=status_code, code=code) - # Wait for H3 warmup to finish - h3_transport = None - if h3_thread is not None: - h3_thread.join(timeout=5) - h3_transport = h3_result["transport"] - - config = SandboxConfiguration(sandbox=response) - config.h3_transport = h3_transport - instance = cls(config) + instance = cls(response) if safe: try: From 956883b364e7646dffa19f641747329d039504f3 Mon Sep 17 00:00:00 2001 From: Joffref Date: Fri, 6 Mar 2026 21:54:25 -0800 Subject: [PATCH 10/16] feat: enhance PersistentMcpClient to resolve sandbox URLs - Added logic to resolve the MCP server URL for sandbox environments by fetching metadata from the management API. - Introduced a new property `_resolved_url` to store the resolved URL and updated the transport selection logic to prioritize this URL when available. - Updated integration tests to include a wait for sandbox deployment to ensure proper setup before tests run. This change improves the handling of sandbox connections, ensuring accurate URL resolution for better reliability. --- src/blaxel/core/tools/__init__.py | 37 +++++++++++++++++++++++--- tests/integration/openai/test_tools.py | 8 +++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/src/blaxel/core/tools/__init__.py b/src/blaxel/core/tools/__init__.py index 8f0df278..18b26607 100644 --- a/src/blaxel/core/tools/__init__.py +++ b/src/blaxel/core/tools/__init__.py @@ -56,6 +56,7 @@ def __init__( self.use_fallback_url = False self.transport_name = transport self.metas = {} + self._resolved_url: str | None = None @property def _internal_url(self): @@ -169,7 +170,30 @@ async def _get_transport_type(self) -> str: if self.transport_name: return self.transport_name - # Make a request to the / endpoint to determine transport type + if self.type == "sandbox" and not self._resolved_url: + # For sandboxes, the MCP server runs at sandbox.metadata.url/mcp — a direct URL + # like https://sbx-{name}-{workspace}.{region}.bl.run/mcp, NOT at the API gateway + # path used by PersistentMcpClient._external_url. Fetch the sandbox metadata via + # the management API to get the correct direct URL. + try: + from ..client.api.compute.get_sandbox import asyncio as get_sandbox_api + from ..client.client import client as bl_client + from ..client.models.error import Error as BLError + from ..client.types import UNSET + + sandbox_response = await get_sandbox_api(self.name, client=bl_client) + if not isinstance(sandbox_response, BLError) and sandbox_response is not None: + meta = getattr(sandbox_response, "metadata", None) + url = getattr(meta, "url", None) if meta else None + if url and url is not UNSET and url != "": + self._resolved_url = str(url).rstrip("/") + logger.debug(f"Resolved sandbox MCP URL for {self.name}: {self._resolved_url}") + except Exception as e: + logger.warning(f"Failed to resolve sandbox URL for {self.name}: {e}") + self.transport_name = "http-stream" + return self.transport_name + + # Make a request to the / endpoint to determine transport type for non-sandbox resources try: async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as http_client: # Make a GET request to the root endpoint @@ -192,11 +216,16 @@ async def _get_transport_type(self) -> str: async def _get_transport(self, url: str = None): """Get the appropriate transport for the connection.""" - if url is None: - url = self._url - transport_type = await self._get_transport_type() + if url is None: + # Use the resolved URL if available (e.g. sandbox direct URL from metadata), + # falling back to the computed URL. Skip resolved URL when in fallback mode. + if self._resolved_url and not self.use_fallback_url: + url = self._resolved_url + else: + url = self._url + if transport_type == "http-stream": # Use streamablehttp_client for http-stream result = await self.client_exit_stack.enter_async_context( diff --git a/tests/integration/openai/test_tools.py b/tests/integration/openai/test_tools.py index 1f0989b1..89459f77 100644 --- a/tests/integration/openai/test_tools.py +++ b/tests/integration/openai/test_tools.py @@ -10,7 +10,12 @@ from blaxel.core.sandbox import SandboxInstance # noqa: E402 from blaxel.openai import bl_model, bl_tools # noqa: E402 -from tests.helpers import default_image, default_labels, unique_name # noqa: E402 +from tests.helpers import ( # noqa: E402 + default_image, + default_labels, + unique_name, + wait_for_sandbox_deployed, +) @pytest.mark.asyncio(loop_scope="class") @@ -32,6 +37,7 @@ async def setup_sandbox(self, request): "labels": default_labels, } ) + await wait_for_sandbox_deployed(request.cls.sandbox_name) yield From 9f90722235d2c76b9c6c77b4e95b3409925a93ee Mon Sep 17 00:00:00 2001 From: mjoffre Date: Mon, 9 Mar 2026 21:22:15 +0000 Subject: [PATCH 11/16] refactor: replace getattr with direct attribute access in sandbox URL resolution Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/tools/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/blaxel/core/tools/__init__.py b/src/blaxel/core/tools/__init__.py index 18b26607..7b29b16f 100644 --- a/src/blaxel/core/tools/__init__.py +++ b/src/blaxel/core/tools/__init__.py @@ -183,8 +183,7 @@ async def _get_transport_type(self) -> str: sandbox_response = await get_sandbox_api(self.name, client=bl_client) if not isinstance(sandbox_response, BLError) and sandbox_response is not None: - meta = getattr(sandbox_response, "metadata", None) - url = getattr(meta, "url", None) if meta else None + url = sandbox_response.metadata.url if url and url is not UNSET and url != "": self._resolved_url = str(url).rstrip("/") logger.debug(f"Resolved sandbox MCP URL for {self.name}: {self._resolved_url}") From cbfb12f2eb1ba4ac90494e76295a4c67c04b4832 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Mon, 9 Mar 2026 22:13:03 +0000 Subject: [PATCH 12/16] fix: probe resolved sandbox URL for transport type instead of hard-coding http-stream Sandboxes now resolve their direct data-plane URL from metadata, then probe that URL for transport type (websocket vs http-stream) dynamically. This preserves backward compatibility with sandboxes that may still use the gateway URL or WebSocket transport. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/tools/__init__.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/blaxel/core/tools/__init__.py b/src/blaxel/core/tools/__init__.py index 7b29b16f..ca9182c6 100644 --- a/src/blaxel/core/tools/__init__.py +++ b/src/blaxel/core/tools/__init__.py @@ -171,10 +171,10 @@ async def _get_transport_type(self) -> str: return self.transport_name if self.type == "sandbox" and not self._resolved_url: - # For sandboxes, the MCP server runs at sandbox.metadata.url/mcp — a direct URL - # like https://sbx-{name}-{workspace}.{region}.bl.run/mcp, NOT at the API gateway - # path used by PersistentMcpClient._external_url. Fetch the sandbox metadata via - # the management API to get the correct direct URL. + # For sandboxes, the MCP server may run at sandbox.metadata.url — a direct + # data-plane URL like https://sbx-{name}-{workspace}.{region}.bl.run, which + # differs from the API gateway path in _external_url. Resolve the direct URL + # from sandbox metadata so transport detection probes the right endpoint. try: from ..client.api.compute.get_sandbox import asyncio as get_sandbox_api from ..client.client import client as bl_client @@ -189,14 +189,15 @@ async def _get_transport_type(self) -> str: logger.debug(f"Resolved sandbox MCP URL for {self.name}: {self._resolved_url}") except Exception as e: logger.warning(f"Failed to resolve sandbox URL for {self.name}: {e}") - self.transport_name = "http-stream" - return self.transport_name - # Make a request to the / endpoint to determine transport type for non-sandbox resources + # Determine the URL to probe for transport type: use the resolved sandbox URL + # if available, otherwise fall back to the standard computed URL. + probe_url = self._resolved_url if self._resolved_url else self._url + + # Make a request to the / endpoint to determine transport type try: async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as http_client: - # Make a GET request to the root endpoint - response = await http_client.get(self._url + "/", headers=settings.headers) + response = await http_client.get(probe_url + "/", headers=settings.headers) if "websocket" in response.text.lower(): self.transport_name = "websocket" else: From caf8e5ac5675b0e4dcfd30336d065247025b2c04 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Mon, 9 Mar 2026 22:22:13 +0000 Subject: [PATCH 13/16] fix: address Mendral review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. tools/__init__.py: probe_url now respects use_fallback_url so fallback probes the gateway URL instead of the unreachable direct sandbox URL 2. h3transport.py: move _async_transports.pop() inside the lock in _mark_failed() to fix data race with _get_or_connect() 3. pyproject.toml: move aioquic to optional [h3] extras group instead of core dependency — users who don't need H3 transport avoid the heavy transitive deps (pyopenssl, pylsqpack, service-identity) Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- pyproject.toml | 4 +++- src/blaxel/core/common/h3transport.py | 2 +- src/blaxel/core/tools/__init__.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0da7158d..0488fb6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "websockets<16.0.0", "attrs>=21.3.0", "httpx>=0.27.0", - "aioquic>=1.2.0", "mcp>=1.9.4", "dockerfile-parse>=2.0.0", ] @@ -25,6 +24,9 @@ dependencies = [ # Core functionality (always include blaxel.core dependencies) core = [] +# HTTP/3 (QUIC) transport support +h3 = ["aioquic>=1.2.0"] + # Telemetry module telemetry = [ "opentelemetry-exporter-otlp>=1.28.0", diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py index af80411e..dad3b136 100644 --- a/src/blaxel/core/common/h3transport.py +++ b/src/blaxel/core/common/h3transport.py @@ -294,7 +294,7 @@ def _mark_failed(self, host: str, port: int) -> None: with self._lock: self._failed_hosts[key] = time.monotonic() self._connect_contexts.pop(key, None) - self._async_transports.pop(key, None) + self._async_transports.pop(key, None) # -- background event loop for sync callers ------------------------------ diff --git a/src/blaxel/core/tools/__init__.py b/src/blaxel/core/tools/__init__.py index ca9182c6..e75cccbf 100644 --- a/src/blaxel/core/tools/__init__.py +++ b/src/blaxel/core/tools/__init__.py @@ -191,8 +191,8 @@ async def _get_transport_type(self) -> str: logger.warning(f"Failed to resolve sandbox URL for {self.name}: {e}") # Determine the URL to probe for transport type: use the resolved sandbox URL - # if available, otherwise fall back to the standard computed URL. - probe_url = self._resolved_url if self._resolved_url else self._url + # if available and not in fallback mode, otherwise fall back to the standard computed URL. + probe_url = (self._resolved_url if self._resolved_url and not self.use_fallback_url else None) or self._url # Make a request to the / endpoint to determine transport type try: From 8f55cd72c9f3e1349c7c08284ab9c81dcefe8fc7 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Mon, 9 Mar 2026 22:34:20 +0000 Subject: [PATCH 14/16] fix: make aioquic imports conditional for optional dependency support Wrap aioquic imports in try/except so h3transport module can be imported even when aioquic is not installed. The pool singleton and helper functions gracefully degrade to HTTP/2 or HTTP/1.1 when aioquic is unavailable. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/h3transport.py | 34 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py index dad3b136..164d9cfa 100644 --- a/src/blaxel/core/common/h3transport.py +++ b/src/blaxel/core/common/h3transport.py @@ -19,12 +19,18 @@ from urllib.parse import urlparse import httpx -from aioquic.asyncio.client import connect -from aioquic.asyncio.protocol import QuicConnectionProtocol -from aioquic.h3.connection import H3_ALPN, H3Connection -from aioquic.h3.events import DataReceived, H3Event, HeadersReceived -from aioquic.quic.configuration import QuicConfiguration -from aioquic.quic.events import QuicEvent + +try: + from aioquic.asyncio.client import connect + from aioquic.asyncio.protocol import QuicConnectionProtocol + from aioquic.h3.connection import H3_ALPN, H3Connection + from aioquic.h3.events import DataReceived, H3Event, HeadersReceived + from aioquic.quic.configuration import QuicConfiguration + from aioquic.quic.events import QuicEvent + + AIOQUIC_AVAILABLE = True +except ImportError: + AIOQUIC_AVAILABLE = False logging.getLogger("quic").setLevel(logging.WARNING) logger = logging.getLogger(__name__) @@ -422,7 +428,7 @@ def close_all_sync(self) -> None: # Module-level singleton # --------------------------------------------------------------------------- -pool = H3Pool() +pool = H3Pool() if AIOQUIC_AVAILABLE else None # --------------------------------------------------------------------------- @@ -442,9 +448,10 @@ async def get_async_transport_for_url(url: str) -> httpx.AsyncBaseTransport | No host, port = _parse_host_port(url) if not host: return None - transport = await pool.get_async_transport(host, port) - if transport is not None: - return transport + if pool is not None: + transport = await pool.get_async_transport(host, port) + if transport is not None: + return transport if HTTP2_AVAILABLE: return httpx.AsyncHTTPTransport(http2=True) return None @@ -455,9 +462,10 @@ def get_sync_transport_for_url(url: str) -> httpx.BaseTransport | None: host, port = _parse_host_port(url) if not host: return None - transport = pool.get_sync_transport(host, port) - if transport is not None: - return transport + if pool is not None: + transport = pool.get_sync_transport(host, port) + if transport is not None: + return transport if HTTP2_AVAILABLE: return httpx.HTTPTransport(http2=True) return None From 881c74dfa14fe7a8f5097258476f588f2038dfef Mon Sep 17 00:00:00 2001 From: mjoffre Date: Mon, 9 Mar 2026 22:47:41 +0000 Subject: [PATCH 15/16] fix: guard all aioquic-dependent classes in if AIOQUIC_AVAILABLE block When aioquic is not installed (optional dependency), class definitions that inherit from QuicConnectionProtocol or reference aioquic types caused NameError at import time. Now all H3 transport classes, fallback transports, connection pool, and the pool singleton are inside the AIOQUIC_AVAILABLE guard. Helper functions gracefully return None. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/h3transport.py | 745 +++++++++++++------------- 1 file changed, 378 insertions(+), 367 deletions(-) diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py index 164d9cfa..e3e8667e 100644 --- a/src/blaxel/core/common/h3transport.py +++ b/src/blaxel/core/common/h3transport.py @@ -47,388 +47,399 @@ # --------------------------------------------------------------------------- -# Async H3 transport (one QUIC connection per instance) +# All aioquic-dependent classes are guarded so the module can be imported +# even when aioquic is not installed (optional dependency). # --------------------------------------------------------------------------- - -class _H3ByteStream(httpx.AsyncByteStream): - def __init__(self, aiterator: AsyncIterator[bytes]): - self._aiterator = aiterator - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for part in self._aiterator: - yield part - - -class _H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): - """httpx async transport over a single QUIC/H3 connection.""" - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._http = H3Connection(self._quic) - self._read_queue: dict[int, Deque[H3Event]] = {} - self._read_ready: dict[int, asyncio.Event] = {} - - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - assert isinstance(request.stream, httpx.AsyncByteStream) - - stream_id = self._quic.get_next_available_stream_id() - self._read_queue[stream_id] = deque() - self._read_ready[stream_id] = asyncio.Event() - - self._http.send_headers( - stream_id=stream_id, - headers=[ - (b":method", request.method.encode()), - (b":scheme", request.url.raw_scheme), - (b":authority", request.url.netloc), - (b":path", request.url.raw_path), - ] - + [ - (k.lower(), v) - for (k, v) in request.headers.raw - if k.lower() not in (b"connection", b"host") - ], - ) - async for data in request.stream: - self._http.send_data(stream_id=stream_id, data=data, end_stream=False) - self._http.send_data(stream_id=stream_id, data=b"", end_stream=True) - self.transmit() - - status_code, headers, stream_ended = await self._receive_response(stream_id) - - return httpx.Response( - status_code=status_code, - headers=headers, - stream=_H3ByteStream(self._receive_response_data(stream_id, stream_ended)), - extensions={"http_version": b"HTTP/3"}, - ) - - # -- aioquic protocol callbacks ------------------------------------------ - - def http_event_received(self, event: H3Event) -> None: - if isinstance(event, (HeadersReceived, DataReceived)): - stream_id = event.stream_id - if stream_id in self._read_queue: - self._read_queue[stream_id].append(event) - self._read_ready[stream_id].set() - - def quic_event_received(self, event: QuicEvent) -> None: - if self._http is not None: - for http_event in self._http.handle_event(event): - self.http_event_received(http_event) - - # -- internal helpers ---------------------------------------------------- - - async def _receive_response(self, stream_id: int): - stream_ended = False - while True: - event = await self._wait_for_http_event(stream_id) - if isinstance(event, HeadersReceived): - stream_ended = event.stream_ended - break - - headers = [] - status_code = 0 - for header, value in event.headers: - if header == b":status": - status_code = int(value.decode()) - else: - headers.append((header, value)) - return status_code, headers, stream_ended - - async def _receive_response_data( - self, stream_id: int, stream_ended: bool - ) -> AsyncIterator[bytes]: - while not stream_ended: - event = await self._wait_for_http_event(stream_id) - if isinstance(event, DataReceived): - stream_ended = event.stream_ended - yield event.data - elif isinstance(event, HeadersReceived): - stream_ended = event.stream_ended - - async def _wait_for_http_event(self, stream_id: int) -> H3Event: - if not self._read_queue[stream_id]: - await self._read_ready[stream_id].wait() - event = self._read_queue[stream_id].popleft() - if not self._read_queue[stream_id]: - self._read_ready[stream_id].clear() - return event - - -# --------------------------------------------------------------------------- -# Sync H3 bridge (delegates to the async _H3Transport via a bg loop) -# --------------------------------------------------------------------------- - - -class _SyncH3Transport(httpx.BaseTransport): - """Sync httpx transport that delegates to an async _H3Transport.""" - - def __init__(self, async_transport: _H3Transport, loop: asyncio.AbstractEventLoop): - self._async_transport = async_transport - self._loop = loop - - def handle_request(self, request: httpx.Request) -> httpx.Response: - future = asyncio.run_coroutine_threadsafe( - self._async_transport.handle_async_request(request), - self._loop, - ) - return future.result(timeout=300) - - def close(self) -> None: - pass - - -# --------------------------------------------------------------------------- -# Fallback transports: try H3, auto-downgrade to HTTP/2 on failure -# --------------------------------------------------------------------------- - - -class AsyncH3FallbackTransport(httpx.AsyncBaseTransport): - """Async transport that tries H3 and falls back to HTTP/2 on failure.""" - - def __init__(self, h3: _H3Transport, host: str, port: int): - self._h3 = h3 - self._host = host - self._port = port - self._h2_fallback: httpx.AsyncHTTPTransport | None = None - self._use_fallback = False - - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - if self._use_fallback: - return await self._ensure_h2().handle_async_request(request) - try: - return await self._h3.handle_async_request(request) - except Exception: - logger.info( - "H3 request to %s:%d failed, downgrading to HTTP/2", - self._host, - self._port, +if AIOQUIC_AVAILABLE: + + class _H3ByteStream(httpx.AsyncByteStream): + def __init__(self, aiterator: AsyncIterator[bytes]): + self._aiterator = aiterator + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._aiterator: + yield part + + class _H3Transport(QuicConnectionProtocol, httpx.AsyncBaseTransport): + """httpx async transport over a single QUIC/H3 connection.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._http = H3Connection(self._quic) + self._read_queue: dict[int, Deque[H3Event]] = {} + self._read_ready: dict[int, asyncio.Event] = {} + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + + stream_id = self._quic.get_next_available_stream_id() + self._read_queue[stream_id] = deque() + self._read_ready[stream_id] = asyncio.Event() + + self._http.send_headers( + stream_id=stream_id, + headers=[ + (b":method", request.method.encode()), + (b":scheme", request.url.raw_scheme), + (b":authority", request.url.netloc), + (b":path", request.url.raw_path), + ] + + [ + (k.lower(), v) + for (k, v) in request.headers.raw + if k.lower() not in (b"connection", b"host") + ], ) - self._use_fallback = True - pool._mark_failed(self._host, self._port) - return await self._ensure_h2().handle_async_request(request) - - def _ensure_h2(self) -> httpx.AsyncHTTPTransport: - if self._h2_fallback is None: - self._h2_fallback = httpx.AsyncHTTPTransport(http2=HTTP2_AVAILABLE) - return self._h2_fallback - - async def aclose(self) -> None: - if self._h2_fallback is not None: - await self._h2_fallback.aclose() - - -class SyncH3FallbackTransport(httpx.BaseTransport): - """Sync transport that tries H3 and falls back to HTTP/2 on failure.""" - - def __init__(self, sync_h3: _SyncH3Transport, host: str, port: int): - self._sync_h3 = sync_h3 - self._host = host - self._port = port - self._h2_fallback: httpx.HTTPTransport | None = None - self._use_fallback = False - - def handle_request(self, request: httpx.Request) -> httpx.Response: - if self._use_fallback: - return self._ensure_h2().handle_request(request) - try: - return self._sync_h3.handle_request(request) - except Exception: - logger.info( - "H3 request to %s:%d failed, downgrading to HTTP/2", - self._host, - self._port, + async for data in request.stream: + self._http.send_data(stream_id=stream_id, data=data, end_stream=False) + self._http.send_data(stream_id=stream_id, data=b"", end_stream=True) + self.transmit() + + status_code, headers, stream_ended = await self._receive_response(stream_id) + + return httpx.Response( + status_code=status_code, + headers=headers, + stream=_H3ByteStream( + self._receive_response_data(stream_id, stream_ended) + ), + extensions={"http_version": b"HTTP/3"}, ) - self._use_fallback = True - pool._mark_failed(self._host, self._port) - return self._ensure_h2().handle_request(request) - - def _ensure_h2(self) -> httpx.HTTPTransport: - if self._h2_fallback is None: - self._h2_fallback = httpx.HTTPTransport(http2=HTTP2_AVAILABLE) - return self._h2_fallback - - def close(self) -> None: - if self._h2_fallback is not None: - self._h2_fallback.close() - - -# --------------------------------------------------------------------------- -# Connection pool -# --------------------------------------------------------------------------- + # -- aioquic protocol callbacks -------------------------------------- + + def http_event_received(self, event: H3Event) -> None: + if isinstance(event, (HeadersReceived, DataReceived)): + stream_id = event.stream_id + if stream_id in self._read_queue: + self._read_queue[stream_id].append(event) + self._read_ready[stream_id].set() + + def quic_event_received(self, event: QuicEvent) -> None: + if self._http is not None: + for http_event in self._http.handle_event(event): + self.http_event_received(http_event) + + # -- internal helpers ------------------------------------------------ + + async def _receive_response(self, stream_id: int): + stream_ended = False + while True: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + break + + headers = [] + status_code = 0 + for header, value in event.headers: + if header == b":status": + status_code = int(value.decode()) + else: + headers.append((header, value)) + return status_code, headers, stream_ended + + async def _receive_response_data( + self, stream_id: int, stream_ended: bool + ) -> AsyncIterator[bytes]: + while not stream_ended: + event = await self._wait_for_http_event(stream_id) + if isinstance(event, DataReceived): + stream_ended = event.stream_ended + yield event.data + elif isinstance(event, HeadersReceived): + stream_ended = event.stream_ended + + async def _wait_for_http_event(self, stream_id: int) -> H3Event: + if not self._read_queue[stream_id]: + await self._read_ready[stream_id].wait() + event = self._read_queue[stream_id].popleft() + if not self._read_queue[stream_id]: + self._read_ready[stream_id].clear() + return event + + # ----------------------------------------------------------------------- + # Sync H3 bridge (delegates to the async _H3Transport via a bg loop) + # ----------------------------------------------------------------------- + + class _SyncH3Transport(httpx.BaseTransport): + """Sync httpx transport that delegates to an async _H3Transport.""" + + def __init__( + self, async_transport: _H3Transport, loop: asyncio.AbstractEventLoop + ): + self._async_transport = async_transport + self._loop = loop + + def handle_request(self, request: httpx.Request) -> httpx.Response: + future = asyncio.run_coroutine_threadsafe( + self._async_transport.handle_async_request(request), + self._loop, + ) + return future.result(timeout=300) + + def close(self) -> None: + pass + + # ----------------------------------------------------------------------- + # Fallback transports: try H3, auto-downgrade to HTTP/2 on failure + # ----------------------------------------------------------------------- + + class AsyncH3FallbackTransport(httpx.AsyncBaseTransport): + """Async transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, h3: _H3Transport, host: str, port: int): + self._h3 = h3 + self._host = host + self._port = port + self._h2_fallback: httpx.AsyncHTTPTransport | None = None + self._use_fallback = False + + async def handle_async_request( + self, request: httpx.Request + ) -> httpx.Response: + if self._use_fallback: + return await self._ensure_h2().handle_async_request(request) + try: + return await self._h3.handle_async_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return await self._ensure_h2().handle_async_request(request) + + def _ensure_h2(self) -> httpx.AsyncHTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.AsyncHTTPTransport( + http2=HTTP2_AVAILABLE + ) + return self._h2_fallback + + async def aclose(self) -> None: + if self._h2_fallback is not None: + await self._h2_fallback.aclose() + + class SyncH3FallbackTransport(httpx.BaseTransport): + """Sync transport that tries H3 and falls back to HTTP/2 on failure.""" + + def __init__(self, sync_h3: _SyncH3Transport, host: str, port: int): + self._sync_h3 = sync_h3 + self._host = host + self._port = port + self._h2_fallback: httpx.HTTPTransport | None = None + self._use_fallback = False + + def handle_request(self, request: httpx.Request) -> httpx.Response: + if self._use_fallback: + return self._ensure_h2().handle_request(request) + try: + return self._sync_h3.handle_request(request) + except Exception: + logger.info( + "H3 request to %s:%d failed, downgrading to HTTP/2", + self._host, + self._port, + ) + self._use_fallback = True + pool._mark_failed(self._host, self._port) + return self._ensure_h2().handle_request(request) + + def _ensure_h2(self) -> httpx.HTTPTransport: + if self._h2_fallback is None: + self._h2_fallback = httpx.HTTPTransport(http2=HTTP2_AVAILABLE) + return self._h2_fallback + + def close(self) -> None: + if self._h2_fallback is not None: + self._h2_fallback.close() + + # ----------------------------------------------------------------------- + # Connection pool + # ----------------------------------------------------------------------- + + class H3Pool: + """Global pool of H3 transports keyed by (host, port). + + Manages a background event loop thread for sync callers and QUIC + connection lifecycle. Failed hosts are remembered for + ``_H3_FAIL_TTL`` seconds so that repeated connection attempts don't + add latency. + """ -class H3Pool: - """Global pool of H3 transports keyed by (host, port). - - Manages a background event loop thread for sync callers and QUIC - connection lifecycle. Failed hosts are remembered for ``_H3_FAIL_TTL`` - seconds so that repeated connection attempts don't add latency. - """ - - def __init__(self) -> None: - self._async_transports: dict[tuple[str, int], _H3Transport] = {} - self._connect_contexts: dict[tuple[str, int], object] = {} - self._failed_hosts: dict[tuple[str, int], float] = {} - self._lock = threading.Lock() - self._async_lock: asyncio.Lock | None = None - self._bg_loop: asyncio.AbstractEventLoop | None = None - self._bg_thread: threading.Thread | None = None - - def _get_async_lock(self) -> asyncio.Lock: - if self._async_lock is None: - self._async_lock = asyncio.Lock() - return self._async_lock - - # -- negative cache ------------------------------------------------------ - - def _is_failed(self, host: str, port: int) -> bool: - key = (host, port) - with self._lock: - ts = self._failed_hosts.get(key) - if ts is None: - return False - if time.monotonic() - ts > _H3_FAIL_TTL: - del self._failed_hosts[key] - return False - return True - - def _mark_failed(self, host: str, port: int) -> None: - key = (host, port) - with self._lock: - self._failed_hosts[key] = time.monotonic() - self._connect_contexts.pop(key, None) - self._async_transports.pop(key, None) - - # -- background event loop for sync callers ------------------------------ - - def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop: - with self._lock: - if self._bg_loop is None or not self._bg_loop.is_running(): - ready = threading.Event() - - def _run_loop(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._bg_loop = loop - ready.set() - loop.run_forever() - - self._bg_thread = threading.Thread(target=_run_loop, daemon=True) - self._bg_thread.start() - ready.wait(timeout=5) - return self._bg_loop # type: ignore[return-value] - - # -- internal: raw H3 connection ----------------------------------------- - - async def _get_or_connect(self, host: str, port: int) -> _H3Transport | None: - """Get a cached _H3Transport or establish a new QUIC connection.""" - key = (host, port) - async with self._get_async_lock(): - transport = self._async_transports.get(key) - if transport is not None: + def __init__(self) -> None: + self._async_transports: dict[tuple[str, int], _H3Transport] = {} + self._connect_contexts: dict[tuple[str, int], object] = {} + self._failed_hosts: dict[tuple[str, int], float] = {} + self._lock = threading.Lock() + self._async_lock: asyncio.Lock | None = None + self._bg_loop: asyncio.AbstractEventLoop | None = None + self._bg_thread: threading.Thread | None = None + + def _get_async_lock(self) -> asyncio.Lock: + if self._async_lock is None: + self._async_lock = asyncio.Lock() + return self._async_lock + + # -- negative cache -------------------------------------------------- + + def _is_failed(self, host: str, port: int) -> bool: + key = (host, port) + with self._lock: + ts = self._failed_hosts.get(key) + if ts is None: + return False + if time.monotonic() - ts > _H3_FAIL_TTL: + del self._failed_hosts[key] + return False + return True + + def _mark_failed(self, host: str, port: int) -> None: + key = (host, port) + with self._lock: + self._failed_hosts[key] = time.monotonic() + self._connect_contexts.pop(key, None) + self._async_transports.pop(key, None) + + # -- background event loop for sync callers -------------------------- + + def _ensure_bg_loop(self) -> asyncio.AbstractEventLoop: + with self._lock: + if self._bg_loop is None or not self._bg_loop.is_running(): + ready = threading.Event() + + def _run_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._bg_loop = loop + ready.set() + loop.run_forever() + + self._bg_thread = threading.Thread( + target=_run_loop, daemon=True + ) + self._bg_thread.start() + ready.wait(timeout=5) + return self._bg_loop # type: ignore[return-value] + + # -- internal: raw H3 connection ------------------------------------- + + async def _get_or_connect( + self, host: str, port: int + ) -> _H3Transport | None: + """Get a cached _H3Transport or establish a new QUIC connection.""" + key = (host, port) + async with self._get_async_lock(): + transport = self._async_transports.get(key) + if transport is not None: + return transport + try: + transport = await asyncio.wait_for( + self._connect(host, port), timeout=_H3_CONNECT_TIMEOUT + ) + async with self._get_async_lock(): + self._async_transports[key] = transport return transport - try: - transport = await asyncio.wait_for( - self._connect(host, port), timeout=_H3_CONNECT_TIMEOUT + except Exception: + logger.debug("H3 connection to %s:%d failed", host, port) + self._mark_failed(host, port) + return None + + async def _connect(self, host: str, port: int) -> _H3Transport: + configuration = QuicConfiguration( + is_client=True, + alpn_protocols=H3_ALPN, + server_name=host, + ) + ctx = connect( + host, + port, + configuration=configuration, + create_protocol=_H3Transport, ) + transport = await ctx.__aenter__() + with self._lock: + self._connect_contexts[(host, port)] = ctx + return transport # type: ignore[return-value] + + # -- public async API ------------------------------------------------ + + async def get_async_transport( + self, host: str, port: int = 443 + ) -> AsyncH3FallbackTransport | None: + """Get an H3 transport with automatic HTTP/2 fallback. + + Returns None if the QUIC handshake fails (caller should fall + back to HTTP/2 or use ``get_async_transport_for_url`` which + does this automatically). + """ + if self._is_failed(host, port): + return None + raw = await self._get_or_connect(host, port) + if raw is None: + return None + return AsyncH3FallbackTransport(raw, host, port) + + # -- public sync API (dispatches to bg loop) ------------------------- + + def get_sync_transport( + self, host: str, port: int = 443 + ) -> SyncH3FallbackTransport | None: + """Get a sync H3 transport with automatic HTTP/2 fallback. + + Returns None on failure (caller should fall back to HTTP/2 or + use ``get_sync_transport_for_url``). + """ + if self._is_failed(host, port): + return None + loop = self._ensure_bg_loop() + future = asyncio.run_coroutine_threadsafe( + self._get_or_connect(host, port), loop + ) + try: + raw = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) + except Exception: + self._mark_failed(host, port) + return None + if raw is None: + return None + return SyncH3FallbackTransport( + _SyncH3Transport(raw, loop), host, port + ) + + # -- shutdown -------------------------------------------------------- + + async def close_all(self) -> None: async with self._get_async_lock(): - self._async_transports[key] = transport - return transport - except Exception: - logger.debug("H3 connection to %s:%d failed", host, port) - self._mark_failed(host, port) - return None - - async def _connect(self, host: str, port: int) -> _H3Transport: - configuration = QuicConfiguration( - is_client=True, - alpn_protocols=H3_ALPN, - server_name=host, - ) - ctx = connect( - host, - port, - configuration=configuration, - create_protocol=_H3Transport, - ) - transport = await ctx.__aenter__() - with self._lock: - self._connect_contexts[(host, port)] = ctx - return transport # type: ignore[return-value] - - # -- public async API ---------------------------------------------------- - - async def get_async_transport( - self, host: str, port: int = 443 - ) -> AsyncH3FallbackTransport | None: - """Get an H3 transport with automatic HTTP/2 fallback for the host. - - Returns None if the QUIC handshake fails (caller should fall back - to HTTP/2 or use ``get_async_transport_for_url`` which does this - automatically). - """ - if self._is_failed(host, port): - return None - raw = await self._get_or_connect(host, port) - if raw is None: - return None - return AsyncH3FallbackTransport(raw, host, port) - - # -- public sync API (dispatches to bg loop) ----------------------------- - - def get_sync_transport( - self, host: str, port: int = 443 - ) -> SyncH3FallbackTransport | None: - """Get a sync H3 transport with automatic HTTP/2 fallback. - - Returns None on failure (caller should fall back to HTTP/2 or use - ``get_sync_transport_for_url``). - """ - if self._is_failed(host, port): - return None - loop = self._ensure_bg_loop() - future = asyncio.run_coroutine_threadsafe( - self._get_or_connect(host, port), loop - ) - try: - raw = future.result(timeout=_H3_CONNECT_TIMEOUT + 1) - except Exception: - self._mark_failed(host, port) - return None - if raw is None: - return None - return SyncH3FallbackTransport(_SyncH3Transport(raw, loop), host, port) - - # -- shutdown ------------------------------------------------------------ - - async def close_all(self) -> None: - async with self._get_async_lock(): - for key, ctx in list(self._connect_contexts.items()): + for key, ctx in list(self._connect_contexts.items()): + try: + await ctx.__aexit__(None, None, None) + except Exception: + pass + self._async_transports.clear() + self._connect_contexts.clear() + + def close_all_sync(self) -> None: + loop = self._bg_loop + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe( + self.close_all(), loop + ) try: - await ctx.__aexit__(None, None, None) + future.result(timeout=5) except Exception: pass - self._async_transports.clear() - self._connect_contexts.clear() - - def close_all_sync(self) -> None: - loop = self._bg_loop - if loop is not None and loop.is_running(): - future = asyncio.run_coroutine_threadsafe(self.close_all(), loop) - try: - future.result(timeout=5) - except Exception: - pass + # -- module-level singleton (only when aioquic is available) ------------- -# --------------------------------------------------------------------------- -# Module-level singleton -# --------------------------------------------------------------------------- + pool: H3Pool | None = H3Pool() -pool = H3Pool() if AIOQUIC_AVAILABLE else None +else: + pool: H3Pool | None = None # type: ignore[no-redef] # --------------------------------------------------------------------------- From a10bb3f989fad9c0691d6ab255d6bb8e6bfff561 Mon Sep 17 00:00:00 2001 From: mjoffre Date: Tue, 10 Mar 2026 23:25:56 +0000 Subject: [PATCH 16/16] fix: address Devin Review comments - region apply-back, race condition in _get_or_connect 1. default/sandbox.py & sync/sandbox.py: Apply extracted region back to sandbox.spec.region in the else branch (was dead code before). Also replaced getattr() with direct attribute access. 2. h3transport.py: Hold async lock across the entire check+connect+store sequence in _get_or_connect() to prevent duplicate QUIC connections for the same (host, port) that would leak connection contexts. Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- src/blaxel/core/common/h3transport.py | 30 ++++++++++++++-------- src/blaxel/core/sandbox/default/sandbox.py | 6 +++-- src/blaxel/core/sandbox/sync/sandbox.py | 6 +++-- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/src/blaxel/core/common/h3transport.py b/src/blaxel/core/common/h3transport.py index e3e8667e..496c898e 100644 --- a/src/blaxel/core/common/h3transport.py +++ b/src/blaxel/core/common/h3transport.py @@ -331,23 +331,31 @@ def _run_loop(): async def _get_or_connect( self, host: str, port: int ) -> _H3Transport | None: - """Get a cached _H3Transport or establish a new QUIC connection.""" + """Get a cached _H3Transport or establish a new QUIC connection. + + Holds the async lock across the entire check+connect+store + sequence to prevent two concurrent callers from both creating + connections for the same (host, port), which would leak the + first QUIC connection context. + """ key = (host, port) async with self._get_async_lock(): transport = self._async_transports.get(key) if transport is not None: return transport - try: - transport = await asyncio.wait_for( - self._connect(host, port), timeout=_H3_CONNECT_TIMEOUT - ) - async with self._get_async_lock(): + try: + transport = await asyncio.wait_for( + self._connect(host, port), + timeout=_H3_CONNECT_TIMEOUT, + ) self._async_transports[key] = transport - return transport - except Exception: - logger.debug("H3 connection to %s:%d failed", host, port) - self._mark_failed(host, port) - return None + return transport + except Exception: + logger.debug( + "H3 connection to %s:%d failed", host, port + ) + self._mark_failed(host, port) + return None async def _connect(self, host: str, port: int) -> _H3Transport: configuration = QuicConfiguration( diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 311972ee..d09710fa 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -258,8 +258,10 @@ async def create( sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory - # Extract region from existing Sandbox spec - region = getattr(sandbox.spec, "region", None) or settings.region + # Extract region from existing Sandbox spec and apply it + region = sandbox.spec.region or settings.region + if region: + sandbox.spec.region = region response = await create_sandbox( client=client, diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index f30a1a06..7dce4900 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -220,8 +220,10 @@ def create( sandbox.spec.runtime.image = sandbox.spec.runtime.image or default_image sandbox.spec.runtime.memory = sandbox.spec.runtime.memory or default_memory - # Extract region from existing Sandbox spec - region = getattr(sandbox.spec, "region", None) or settings.region + # Extract region from existing Sandbox spec and apply it + region = sandbox.spec.region or settings.region + if region: + sandbox.spec.region = region response = create_sandbox( client=client,