diff --git a/src/blaxel/core/common/settings.py b/src/blaxel/core/common/settings.py index c745238..34457c8 100644 --- a/src/blaxel/core/common/settings.py +++ b/src/blaxel/core/common/settings.py @@ -11,6 +11,16 @@ BLAXEL_API_VERSION = "2026-04-16" +def _get_int_env(name: str, default: int) -> int: + value = os.environ.get(name) + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + def _get_os_arch() -> str: """Get OS and architecture information.""" try: @@ -97,6 +107,16 @@ def api_version(self) -> str: """Get the API version sent in the Blaxel-Version header.""" return os.environ.get("BL_API_VERSION", BLAXEL_API_VERSION) + @property + def fs_part_retries(self) -> int: + """Retry budget for idempotent filesystem upload PUTs.""" + return _get_int_env("BL_FS_PART_RETRIES", 3) + + @property + def sandbox_read_retries(self) -> int: + """Retry budget for idempotent sandbox read/list operations.""" + return _get_int_env("BL_SANDBOX_READ_RETRIES", 5) + @property def headers(self) -> Dict[str, str]: """Get the headers for API requests.""" diff --git a/src/blaxel/core/sandbox/default/drive.py b/src/blaxel/core/sandbox/default/drive.py index 990ec3d..357ec1b 100644 --- a/src/blaxel/core/sandbox/default/drive.py +++ b/src/blaxel/core/sandbox/default/drive.py @@ -14,6 +14,7 @@ DriveUnmountResponse, ErrorResponse, ) +from ..transient_retry import retry_on_transient_reset_async from ..types import SandboxConfiguration from .action import SandboxAction @@ -100,15 +101,19 @@ async def list(self) -> List[DriveMountInfo]: Returns: List of DriveMountInfo for each mounted drive """ - client = Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, - ) - async with client: - response = await get_drives_mount(client=client) - if response is None: - raise Exception("Failed to list drives") - if isinstance(response, ErrorResponse): - raise Exception(f"List drives failed: {response.error}") - return list(response.mounts) if response.mounts else [] + async def list_once() -> List[DriveMountInfo]: + client = Client( + base_url=self.url, + headers={**settings.headers, **self.sandbox_config.headers}, + ) + + async with client: + response = await get_drives_mount(client=client) + if response is None: + raise Exception("Failed to list drives") + if isinstance(response, ErrorResponse): + raise Exception(f"List drives failed: {response.error}") + return list(response.mounts) if response.mounts else [] + + return await retry_on_transient_reset_async(list_once) diff --git a/src/blaxel/core/sandbox/default/filesystem.py b/src/blaxel/core/sandbox/default/filesystem.py index c5e8a1d..5e87819 100644 --- a/src/blaxel/core/sandbox/default/filesystem.py +++ b/src/blaxel/core/sandbox/default/filesystem.py @@ -9,6 +9,7 @@ from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse +from ..transient_retry import retry_on_transient_reset_async from ..types import ( AsyncWatchHandle, CopyResponse, @@ -93,34 +94,35 @@ async def write_binary( if len(content) > MULTIPART_THRESHOLD: return await self._upload_with_multipart(path, content, "0644") - # Use regular upload for small files - # Wrap binary content in BytesIO to provide file-like interface - binary_file = io.BytesIO(content) - - # Prepare multipart form data - files = { - "file": ( - "binary-file.bin", - binary_file, - "application/octet-stream", - ), - } - data = {"permissions": "0644", "path": path} - - # Use the fixed get_client method url = f"{self.url}/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - client = self.get_client() - response = await client.put(url, files=files, data=data, headers=headers) - try: - content_bytes = await response.aread() - if not response.is_success: - error_text = content_bytes.decode("utf-8", errors="ignore") - raise Exception(f"Failed to write binary: {response.status_code} {error_text}") - return SuccessResponse.from_dict(json.loads(content_bytes)) - finally: - await response.aclose() + async def put_once() -> SuccessResponse: + files = { + "file": ( + "binary-file.bin", + io.BytesIO(content), + "application/octet-stream", + ), + } + data = {"permissions": "0644", "path": path} + client = self.get_client() + response = await client.put(url, files=files, data=data, headers=headers) + try: + content_bytes = await response.aread() + if not response.is_success: + error_text = content_bytes.decode("utf-8", errors="ignore") + raise Exception(f"Failed to write binary: {response.status_code} {error_text}") + result = SuccessResponse.from_dict(json.loads(content_bytes)) + assert result is not None + return result + finally: + await response.aclose() + + return await retry_on_transient_reset_async( + put_once, + retries=settings.fs_part_retries, + ) async def write_tree( self, @@ -152,16 +154,19 @@ async def write_tree( async def read(self, path: str) -> str: path = self.format_path(path) - client = self.get_client() - response = await client.get(f"/filesystem/{path}") - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) - if "content" in data: - return data["content"] - raise Exception("Unsupported file type") - finally: - await response.aclose() + async def read_once() -> str: + client = self.get_client() + response = await client.get(f"/filesystem/{path}") + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) + if "content" in data: + return data["content"] + raise Exception("Unsupported file type") + finally: + await response.aclose() + + return await retry_on_transient_reset_async(read_once) async def read_binary(self, path: str) -> bytes: """Read binary content from a file. @@ -181,14 +186,17 @@ async def read_binary(self, path: str) -> bytes: "Accept": "application/octet-stream", } - client = self.get_client() - response = await client.get(url, headers=headers) - try: - content = await response.aread() - self.handle_response_error(response) - return content - finally: - await response.aclose() + async def read_once() -> bytes: + client = self.get_client() + response = await client.get(url, headers=headers) + try: + content = await response.aread() + self.handle_response_error(response) + return content + finally: + await response.aclose() + + return await retry_on_transient_reset_async(read_once) async def download(self, src: str, destination_path: str, mode: int = 0o644) -> None: """Download a file from the sandbox to the local filesystem. @@ -219,16 +227,21 @@ async def rm(self, path: str, recursive: bool = False) -> SuccessResponse: async def ls(self, path: str) -> Directory: path = self.format_path(path) - client = self.get_client() - response = await client.get(f"/filesystem/{path}") - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) - if not ("files" in data or "subdirectories" in data): - raise Exception('{"error": "Directory not found"}') - return Directory.from_dict(data) - finally: - await response.aclose() + async def ls_once() -> Directory: + client = self.get_client() + response = await client.get(f"/filesystem/{path}") + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) + if not ("files" in data or "subdirectories" in data): + raise Exception('{"error": "Directory not found"}') + result = Directory.from_dict(data) + assert result is not None + return result + finally: + await response.aclose() + + return await retry_on_transient_reset_async(ls_once) async def find( self, @@ -269,17 +282,20 @@ async def find( url = f"{self.url}/filesystem-find/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - client = self.get_client() - response = await client.get(url, params=params, headers=headers) - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) + async def find_once(): + client = self.get_client() + response = await client.get(url, params=params, headers=headers) + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) - from ..client.models.find_response import FindResponse + from ..client.models.find_response import FindResponse - return FindResponse.from_dict(data) - finally: - await response.aclose() + return FindResponse.from_dict(data) + finally: + await response.aclose() + + return await retry_on_transient_reset_async(find_once) async def grep( self, @@ -322,17 +338,20 @@ async def grep( url = f"{self.url}/filesystem-content-search/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - client = self.get_client() - response = await client.get(url, params=params, headers=headers) - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) + async def grep_once(): + client = self.get_client() + response = await client.get(url, params=params, headers=headers) + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) - from ..client.models.content_search_response import ContentSearchResponse + from ..client.models.content_search_response import ContentSearchResponse - return ContentSearchResponse.from_dict(data) - finally: - await response.aclose() + return ContentSearchResponse.from_dict(data) + finally: + await response.aclose() + + return await retry_on_transient_reset_async(grep_once) async def cp(self, source: str, destination: str, max_wait: int = 180000) -> CopyResponse: """Copy files or directories using the cp command. @@ -492,17 +511,21 @@ async def _upload_part(self, upload_id: str, part_number: int, data: bytes) -> D headers = {**settings.headers, **self.sandbox_config.headers} params = {"partNumber": part_number} - # Prepare multipart form data with the file chunk - files = {"file": ("part", io.BytesIO(data), "application/octet-stream")} + async def put_once() -> Dict[str, Any]: + files = {"file": ("part", io.BytesIO(data), "application/octet-stream")} + client = self.get_client() + response = await client.put(url, files=files, params=params, headers=headers) + try: + self.handle_response_error(response) + result = json.loads(await response.aread()) + return result + finally: + await response.aclose() - client = self.get_client() - response = await client.put(url, files=files, params=params, headers=headers) - try: - self.handle_response_error(response) - result = json.loads(await response.aread()) - return result - finally: - await response.aclose() + return await retry_on_transient_reset_async( + put_once, + retries=settings.fs_part_retries, + ) async def _complete_multipart_upload( self, upload_id: str, parts: List[Dict[str, Any]] diff --git a/src/blaxel/core/sandbox/default/process.py b/src/blaxel/core/sandbox/default/process.py index 8edd360..0f440f7 100644 --- a/src/blaxel/core/sandbox/default/process.py +++ b/src/blaxel/core/sandbox/default/process.py @@ -6,6 +6,7 @@ from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest +from ..transient_retry import retry_on_transient_reset_async from ..types import ( AsyncStreamHandle, ProcessRequestWithLog, @@ -417,33 +418,39 @@ async def wait( async def get(self, identifier: str) -> ProcessResponse: import json - client = self.get_client() - response = await client.get(f"/process/{identifier}") - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) - result = ProcessResponse.from_dict(data) - assert result is not None - return result - finally: - await response.aclose() + async def get_once() -> ProcessResponse: + client = self.get_client() + response = await client.get(f"/process/{identifier}") + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) + result = ProcessResponse.from_dict(data) + assert result is not None + return result + finally: + await response.aclose() + + return await retry_on_transient_reset_async(get_once) async def list(self) -> list[ProcessResponse]: import json - client = self.get_client() - response = await client.get("/process") - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) - results = [] - for item in data: - result = ProcessResponse.from_dict(item) - assert result is not None - results.append(result) - return results - finally: - await response.aclose() + async def list_once() -> list[ProcessResponse]: + client = self.get_client() + response = await client.get("/process") + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) + results = [] + for item in data: + result = ProcessResponse.from_dict(item) + assert result is not None + results.append(result) + return results + finally: + await response.aclose() + + return await retry_on_transient_reset_async(list_once) async def stop(self, identifier: str) -> SuccessResponse: import json @@ -480,18 +487,21 @@ async def logs( ) -> str: import json - client = self.get_client() - response = await client.get(f"/process/{identifier}/logs") - try: - data = json.loads(await response.aread()) - self.handle_response_error(response) - if log_type == "all": - return data.get("logs", "") - elif log_type == "stdout": - return data.get("stdout", "") - elif log_type == "stderr": - return data.get("stderr", "") - - raise Exception("Unsupported log type") - finally: - await response.aclose() + async def logs_once() -> str: + client = self.get_client() + response = await client.get(f"/process/{identifier}/logs") + try: + data = json.loads(await response.aread()) + self.handle_response_error(response) + if log_type == "all": + return data.get("logs", "") + elif log_type == "stdout": + return data.get("stdout", "") + elif log_type == "stderr": + return data.get("stderr", "") + + raise Exception("Unsupported log type") + finally: + await response.aclose() + + return await retry_on_transient_reset_async(logs_once) diff --git a/src/blaxel/core/sandbox/sync/drive.py b/src/blaxel/core/sandbox/sync/drive.py index cff9f15..748c650 100644 --- a/src/blaxel/core/sandbox/sync/drive.py +++ b/src/blaxel/core/sandbox/sync/drive.py @@ -14,6 +14,7 @@ DriveUnmountResponse, ErrorResponse, ) +from ..transient_retry import retry_on_transient_reset from ..types import SandboxConfiguration from .action import SyncSandboxAction @@ -100,15 +101,19 @@ def list(self) -> List[DriveMountInfo]: Returns: List of DriveMountInfo for each mounted drive """ - client = Client( - base_url=self.url, - headers={**settings.headers, **self.sandbox_config.headers}, - ) - with client: - response = get_drives_mount(client=client) - if response is None: - raise Exception("Failed to list drives") - if isinstance(response, ErrorResponse): - raise Exception(f"List drives failed: {response.error}") - return list(response.mounts) if response.mounts else [] + def list_once() -> List[DriveMountInfo]: + client = Client( + base_url=self.url, + headers={**settings.headers, **self.sandbox_config.headers}, + ) + + with client: + response = get_drives_mount(client=client) + if response is None: + raise Exception("Failed to list drives") + if isinstance(response, ErrorResponse): + raise Exception(f"List drives failed: {response.error}") + return list(response.mounts) if response.mounts else [] + + return retry_on_transient_reset(list_once) diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index eaf3af6..256eb26 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -9,6 +9,7 @@ from ...common.settings import settings from ..client.models import Directory, FileRequest, SuccessResponse +from ..transient_retry import retry_on_transient_reset from ..types import ( CopyResponse, SandboxConfiguration, @@ -60,22 +61,29 @@ def write_binary(self, path: str, content: Union[bytes, bytearray, str]) -> Succ content = bytes(content) if len(content) > MULTIPART_THRESHOLD: return self._upload_with_multipart(path, content, "0644") - binary_file = io.BytesIO(content) - files = { - "file": ( - "binary-file.bin", - binary_file, - "application/octet-stream", - ), - } - data = {"permissions": "0644", "path": path} url = f"{self.url}/filesystem/{path}" headers = {**settings.headers, **self.sandbox_config.headers} - with self.get_client() as client_instance: - response = client_instance.put(url, files=files, data=data, headers=headers) - if not response.is_success: - raise Exception(f"Failed to write binary: {response.status_code} {response.text}") - return SuccessResponse.from_dict(response.json()) + + def put_once() -> SuccessResponse: + files = { + "file": ( + "binary-file.bin", + io.BytesIO(content), + "application/octet-stream", + ), + } + data = {"permissions": "0644", "path": path} + with self.get_client() as client_instance: + response = client_instance.put(url, files=files, data=data, headers=headers) + if not response.is_success: + raise Exception( + f"Failed to write binary: {response.status_code} {response.text}" + ) + result = SuccessResponse.from_dict(response.json()) + assert result is not None + return result + + return retry_on_transient_reset(put_once, retries=settings.fs_part_retries) def write_tree( self, @@ -99,13 +107,17 @@ def write_tree( def read(self, path: str) -> str: path = self.format_path(path) - with self.get_client() as client_instance: - response = client_instance.get(f"/filesystem/{path}") - self.handle_response_error(response) - data = response.json() - if "content" in data: - return data["content"] - raise Exception("Unsupported file type") + + def read_once() -> str: + with self.get_client() as client_instance: + response = client_instance.get(f"/filesystem/{path}") + self.handle_response_error(response) + data = response.json() + if "content" in data: + return data["content"] + raise Exception("Unsupported file type") + + return retry_on_transient_reset(read_once) def read_binary(self, path: str) -> bytes: path = self.format_path(path) @@ -115,10 +127,14 @@ def read_binary(self, path: str) -> bytes: **self.sandbox_config.headers, "Accept": "application/octet-stream", } - with self.get_client() as client_instance: - response = client_instance.get(url, headers=headers) - self.handle_response_error(response) - return response.content + + def read_once() -> bytes: + with self.get_client() as client_instance: + response = client_instance.get(url, headers=headers) + self.handle_response_error(response) + return response.content + + return retry_on_transient_reset(read_once) def download(self, src: str, destination_path: str, mode: int = 0o644) -> None: content = self.read_binary(src) @@ -136,13 +152,19 @@ def rm(self, path: str, recursive: bool = False) -> SuccessResponse: def ls(self, path: str) -> Directory: path = self.format_path(path) - with self.get_client() as client_instance: - response = client_instance.get(f"/filesystem/{path}") - self.handle_response_error(response) - data = response.json() - if not ("files" in data or "subdirectories" in data): - raise Exception('{"error": "Directory not found"}') - return Directory.from_dict(data) + + def ls_once() -> Directory: + with self.get_client() as client_instance: + response = client_instance.get(f"/filesystem/{path}") + self.handle_response_error(response) + data = response.json() + if not ("files" in data or "subdirectories" in data): + raise Exception('{"error": "Directory not found"}') + result = Directory.from_dict(data) + assert result is not None + return result + + return retry_on_transient_reset(ls_once) def cp(self, source: str, destination: str, max_wait: int = 180000) -> CopyResponse: if not self.process: @@ -262,11 +284,15 @@ def _upload_part(self, upload_id: str, part_number: int, data: bytes) -> Dict[st url = f"{self.url}/filesystem-multipart/{upload_id}/part" headers = {**settings.headers, **self.sandbox_config.headers} params = {"partNumber": part_number} - files = {"file": ("part", io.BytesIO(data), "application/octet-stream")} - with self.get_client() as client_instance: - response = client_instance.put(url, files=files, params=params, headers=headers) - self.handle_response_error(response) - return response.json() + + def put_once() -> Dict[str, Any]: + files = {"file": ("part", io.BytesIO(data), "application/octet-stream")} + with self.get_client() as client_instance: + response = client_instance.put(url, files=files, params=params, headers=headers) + self.handle_response_error(response) + return response.json() + + return retry_on_transient_reset(put_once, retries=settings.fs_part_retries) def _complete_multipart_upload( self, upload_id: str, parts: List[Dict[str, Any]] diff --git a/src/blaxel/core/sandbox/sync/process.py b/src/blaxel/core/sandbox/sync/process.py index 7df1539..49952e7 100644 --- a/src/blaxel/core/sandbox/sync/process.py +++ b/src/blaxel/core/sandbox/sync/process.py @@ -7,6 +7,7 @@ from ...common.settings import settings from ..client.models import ProcessResponse, SuccessResponse from ..client.models.process_request import ProcessRequest +from ..transient_retry import retry_on_transient_reset from ..types import ( ProcessRequestWithLog, ProcessResponseWithLog, @@ -354,16 +355,29 @@ def wait(self, identifier: str, max_wait: int = 60000, interval: int = 1000) -> return data def get(self, identifier: str) -> ProcessResponse: - with self.get_client() as client_instance: - response = client_instance.get(f"/process/{identifier}") - self.handle_response_error(response) - return ProcessResponse.from_dict(response.json()) + def get_once() -> ProcessResponse: + with self.get_client() as client_instance: + response = client_instance.get(f"/process/{identifier}") + self.handle_response_error(response) + result = ProcessResponse.from_dict(response.json()) + assert result is not None + return result + + return retry_on_transient_reset(get_once) def list(self) -> list[ProcessResponse]: - with self.get_client() as client_instance: - response = client_instance.get("/process") - self.handle_response_error(response) - return [ProcessResponse.from_dict(item) for item in response.json()] + def list_once() -> list[ProcessResponse]: + with self.get_client() as client_instance: + response = client_instance.get("/process") + self.handle_response_error(response) + results = [] + for item in response.json(): + result = ProcessResponse.from_dict(item) + assert result is not None + results.append(result) + return results + + return retry_on_transient_reset(list_once) def stop(self, identifier: str) -> SuccessResponse: with self.get_client() as client_instance: @@ -382,14 +396,17 @@ def logs( identifier: str, log_type: Literal["stdout", "stderr", "all"] = "all", ) -> str: - with self.get_client() as client_instance: - response = client_instance.get(f"/process/{identifier}/logs") - self.handle_response_error(response) - data = response.json() - if log_type == "all": - return data.get("logs", "") - elif log_type == "stdout": - return data.get("stdout", "") - elif log_type == "stderr": - return data.get("stderr", "") - raise Exception("Unsupported log type") + def logs_once() -> str: + with self.get_client() as client_instance: + response = client_instance.get(f"/process/{identifier}/logs") + self.handle_response_error(response) + data = response.json() + if log_type == "all": + return data.get("logs", "") + elif log_type == "stdout": + return data.get("stdout", "") + elif log_type == "stderr": + return data.get("stderr", "") + raise Exception("Unsupported log type") + + return retry_on_transient_reset(logs_once) diff --git a/src/blaxel/core/sandbox/transient_retry.py b/src/blaxel/core/sandbox/transient_retry.py new file mode 100644 index 0000000..70a7c3f --- /dev/null +++ b/src/blaxel/core/sandbox/transient_retry.py @@ -0,0 +1,141 @@ +import asyncio +import random +import time +from collections.abc import Awaitable, Callable, Iterator +from typing import TypeVar + +import httpx + +from ..common.settings import settings + +T = TypeVar("T") + +TRANSIENT_RESET_MARKERS = ( + "ENHANCE_YOUR_CALM", + "NGHTTP2_INTERNAL_ERROR", + "ERR_HTTP2", + "GOAWAY", + "HTTP/2 session closed before response", + "HTTP/2 session sent GOAWAY before response", + "Connection reset by peer", + "Server disconnected without sending a response", +) + +TRANSIENT_ERROR_CODES = { + "ECONNRESET", + "ECONNREFUSED", + "ETIMEDOUT", + "EPIPE", + "ERR_HTTP2_STREAM_ERROR", + "ERR_HTTP2_GOAWAY_SESSION", + "ERR_HTTP2_SESSION_ERROR", +} + +DEFAULT_BASE_DELAY_SECONDS = 0.2 +DEFAULT_MAX_DELAY_SECONDS = 2.0 + + +def _walk_error_chain(error: BaseException) -> Iterator[BaseException]: + current: BaseException | None = error + seen: set[int] = set() + for _ in range(5): + if current is None or id(current) in seen: + break + seen.add(id(current)) + yield current + current = current.__cause__ or current.__context__ + + +def _has_http_response_status(error: BaseException) -> bool: + for node in _walk_error_chain(error): + response = getattr(node, "response", None) + status = getattr(response, "status_code", None) + if isinstance(status, int): + return True + status = getattr(node, "status_code", None) + if isinstance(status, int): + return True + return False + + +def _collect_error_text(error: BaseException) -> tuple[list[str], list[str]]: + messages: list[str] = [] + codes: list[str] = [] + for node in _walk_error_chain(error): + messages.append(str(node)) + code = getattr(node, "code", None) + if isinstance(code, str): + codes.append(code) + errno = getattr(node, "errno", None) + if isinstance(errno, str): + codes.append(errno) + return messages, codes + + +def is_transient_reset_error(error: BaseException) -> bool: + """True for transport-level drops that are safe to retry on idempotent calls.""" + if _has_http_response_status(error): + return False + if isinstance(error, httpx.TimeoutException | httpx.NetworkError | httpx.RemoteProtocolError): + return True + if not isinstance(error, httpx.TransportError): + return False + + messages, codes = _collect_error_text(error) + if any(code in TRANSIENT_ERROR_CODES for code in codes): + return True + return any(marker in message for message in messages for marker in TRANSIENT_RESET_MARKERS) + + +def _backoff_delay_seconds( + attempt: int, + base_delay_seconds: float, + max_delay_seconds: float, +) -> float: + if base_delay_seconds <= 0 or max_delay_seconds <= 0: + return 0 + exponential = base_delay_seconds * (2 ** (attempt - 1)) + capped = min(exponential, max_delay_seconds) + return capped + random.uniform(0, base_delay_seconds) + + +async def retry_on_transient_reset_async( + fn: Callable[[], Awaitable[T]], + *, + retries: int | None = None, + base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS, + max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS, +) -> T: + retry_budget = settings.sandbox_read_retries if retries is None else retries + attempt = 0 + while True: + try: + return await fn() + except Exception as error: + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget or not is_transient_reset_error(error): + raise + delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds) + if delay: + await asyncio.sleep(delay) + + +def retry_on_transient_reset( + fn: Callable[[], T], + *, + retries: int | None = None, + base_delay_seconds: float = DEFAULT_BASE_DELAY_SECONDS, + max_delay_seconds: float = DEFAULT_MAX_DELAY_SECONDS, +) -> T: + retry_budget = settings.sandbox_read_retries if retries is None else retries + attempt = 0 + while True: + try: + return fn() + except Exception as error: + attempt += 1 + if retry_budget <= 0 or attempt > retry_budget or not is_transient_reset_error(error): + raise + delay = _backoff_delay_seconds(attempt, base_delay_seconds, max_delay_seconds) + if delay: + time.sleep(delay) diff --git a/tests/core/test_sandbox_transient_retry.py b/tests/core/test_sandbox_transient_retry.py new file mode 100644 index 0000000..638f920 --- /dev/null +++ b/tests/core/test_sandbox_transient_retry.py @@ -0,0 +1,276 @@ +import asyncio +from typing import Any, cast + +import httpx +import pytest + +from blaxel.core.common.settings import settings +from blaxel.core.sandbox.default.filesystem import SandboxFileSystem +from blaxel.core.sandbox.default.process import SandboxProcess +from blaxel.core.sandbox.sync.filesystem import SyncSandboxFileSystem +from blaxel.core.sandbox.transient_retry import ( + is_transient_reset_error, + retry_on_transient_reset, + retry_on_transient_reset_async, +) +from blaxel.core.sandbox.types import ResponseError + + +class LoopbackFaultServer: + def __init__(self, *handlers): + self.handlers = handlers + self.requests = 0 + self.server: asyncio.Server | None = None + self.url = "" + + async def __aenter__(self): + self.server = await asyncio.start_server(self._handle, "127.0.0.1", 0) + socket = self.server.sockets[0] + host, port = socket.getsockname()[:2] + self.url = f"http://{host}:{port}" + return self + + async def __aexit__(self, *args): + if self.server is not None: + self.server.close() + await self.server.wait_closed() + + async def _handle( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + index = self.requests + self.requests += 1 + handler = self.handlers[min(index, len(self.handlers) - 1)] + await handler(reader, writer) + + +class AsyncSequenceClient: + def __init__(self, *results): + self.results = list(results) + self.calls = 0 + + async def get(self, *args, **kwargs): + self.calls += 1 + result = self.results.pop(0) + if isinstance(result, BaseException): + raise result + return result + + async def post(self, *args, **kwargs): + self.calls += 1 + result = self.results.pop(0) + if isinstance(result, BaseException): + raise result + return result + + +class SyncSequenceClient: + def __init__(self, *results): + self.results = list(results) + self.calls = 0 + + def __enter__(self): + return self + + def __exit__(self, *args): + return None + + def get(self, *args, **kwargs): + self.calls += 1 + result = self.results.pop(0) + if isinstance(result, BaseException): + raise result + return result + + +def ok_json_response(data): + return httpx.Response( + 200, + json=data, + request=httpx.Request("GET", "https://sandbox.test"), + ) + + +def app_error_response() -> ResponseError: + response = httpx.Response( + 500, + json={"error": "GOAWAY in an application body"}, + request=httpx.Request("GET", "https://sandbox.test"), + ) + return ResponseError(response) + + +async def close_without_response( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, +) -> None: + try: + await asyncio.wait_for(reader.read(1024), timeout=0.2) + except TimeoutError: + pass + writer.close() + await writer.wait_closed() + + +async def send_ok_response( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, +) -> None: + try: + await asyncio.wait_for(reader.readuntil(b"\r\n\r\n"), timeout=1.0) + except (TimeoutError, asyncio.IncompleteReadError, asyncio.LimitOverrunError): + pass + writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok") + await writer.drain() + writer.close() + await writer.wait_closed() + + +@pytest.fixture(autouse=True) +def no_retry_sleep(monkeypatch): + monkeypatch.setattr( + "blaxel.core.sandbox.transient_retry._backoff_delay_seconds", + lambda *args: 0, + ) + + +def test_retry_settings_defaults_and_env(monkeypatch): + monkeypatch.delenv("BL_FS_PART_RETRIES", raising=False) + monkeypatch.delenv("BL_SANDBOX_READ_RETRIES", raising=False) + assert settings.fs_part_retries == 3 + assert settings.sandbox_read_retries == 5 + + monkeypatch.setenv("BL_FS_PART_RETRIES", "1") + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "2") + assert settings.fs_part_retries == 1 + assert settings.sandbox_read_retries == 2 + + +def test_classifier_accepts_httpx_transport_drops(): + assert is_transient_reset_error(httpx.ConnectError("All connection attempts failed")) + assert is_transient_reset_error(httpx.RemoteProtocolError("GOAWAY received")) + assert is_transient_reset_error(httpx.ReadTimeout("timed out")) + + +def test_classifier_rejects_application_responses(): + assert not is_transient_reset_error(app_error_response()) + + +@pytest.mark.asyncio +async def test_real_httpx_transport_drop_is_classified_transient(): + async with LoopbackFaultServer(close_without_response) as server: + async with httpx.AsyncClient(timeout=2.0) as client: + with pytest.raises(httpx.TransportError) as exc_info: + await client.get(server.url) + + assert is_transient_reset_error(exc_info.value) + assert server.requests == 1 + + +@pytest.mark.asyncio +async def test_async_retry_counts_real_transport_fault_attempts(): + async with LoopbackFaultServer(close_without_response) as server: + async with httpx.AsyncClient(timeout=2.0) as client: + with pytest.raises(httpx.TransportError): + await retry_on_transient_reset_async( + lambda: client.get(server.url), + retries=2, + ) + + assert server.requests == 3 + + +@pytest.mark.asyncio +async def test_async_retry_self_heals_after_real_transport_fault_clears(): + async with LoopbackFaultServer(close_without_response, send_ok_response) as server: + async with httpx.AsyncClient(timeout=2.0) as client: + response = await retry_on_transient_reset_async( + lambda: client.get(server.url), + retries=1, + ) + + assert response.status_code == 200 + assert response.text == "ok" + assert server.requests == 2 + + +@pytest.mark.asyncio +async def test_async_retry_recovers_once(): + calls = 0 + + async def flaky(): + nonlocal calls + calls += 1 + if calls == 1: + raise httpx.ConnectError("All connection attempts failed") + return "ok" + + assert await retry_on_transient_reset_async(flaky, retries=1) == "ok" + assert calls == 2 + + +def test_sync_retry_recovers_once(): + calls = 0 + + def flaky(): + nonlocal calls + calls += 1 + if calls == 1: + raise httpx.ConnectError("All connection attempts failed") + return "ok" + + assert retry_on_transient_reset(flaky, retries=1) == "ok" + assert calls == 2 + + +def test_sync_retry_does_not_retry_application_response(): + calls = 0 + + def app_error(): + nonlocal calls + calls += 1 + raise app_error_response() + + with pytest.raises(ResponseError): + retry_on_transient_reset(app_error, retries=3) + assert calls == 1 + + +@pytest.mark.asyncio +async def test_async_filesystem_read_retries_transport_reset(monkeypatch): + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") + client = AsyncSequenceClient( + httpx.ConnectError("All connection attempts failed"), + ok_json_response({"content": "hello"}), + ) + filesystem = cast(Any, object.__new__(SandboxFileSystem)) + filesystem.get_client = lambda: client + + assert await filesystem.read("/file.txt") == "hello" + assert client.calls == 2 + + +def test_sync_filesystem_read_retries_transport_reset(monkeypatch): + monkeypatch.setenv("BL_SANDBOX_READ_RETRIES", "1") + client = SyncSequenceClient( + httpx.ConnectError("All connection attempts failed"), + ok_json_response({"content": "hello"}), + ) + filesystem = cast(Any, object.__new__(SyncSandboxFileSystem)) + filesystem.get_client = lambda: client + + assert filesystem.read("/file.txt") == "hello" + assert client.calls == 2 + + +@pytest.mark.asyncio +async def test_process_exec_is_not_retried_on_transport_reset(): + client = AsyncSequenceClient(httpx.ConnectError("All connection attempts failed")) + process = cast(Any, object.__new__(SandboxProcess)) + process.get_client = lambda: client + + with pytest.raises(httpx.ConnectError): + await process.exec({"command": "echo nope"}) + assert client.calls == 1