From 6bfe6c8b7db3bf4d5a9f2d0d6012d80a0da5a9e2 Mon Sep 17 00:00:00 2001 From: Michael Stolarz Date: Tue, 26 May 2026 15:57:41 -0700 Subject: [PATCH] fix sandbox create-if-not-exists races --- .../core/sandbox/default/interpreter.py | 7 +- src/blaxel/core/sandbox/default/sandbox.py | 70 ++-- src/blaxel/core/sandbox/sync/interpreter.py | 7 +- src/blaxel/core/sandbox/sync/sandbox.py | 44 +-- tests/core/test_sandbox.py | 355 +++++++++++++++++- 5 files changed, 429 insertions(+), 54 deletions(-) diff --git a/src/blaxel/core/sandbox/default/interpreter.py b/src/blaxel/core/sandbox/default/interpreter.py index afd28746..9d90389b 100644 --- a/src/blaxel/core/sandbox/default/interpreter.py +++ b/src/blaxel/core/sandbox/default/interpreter.py @@ -36,6 +36,7 @@ async def create( cls, sandbox: Sandbox | SandboxCreateConfiguration | Dict[str, Any] | None = None, safe: bool = True, + create_if_not_exist: bool = False, ) -> CodeInterpreter: """ Create a sandbox instance using the jupyter-server image. @@ -83,7 +84,11 @@ async def create( if sandbox.spec and getattr(sandbox.spec, "region", None): payload["region"] = sandbox.spec.region - base_instance = await SandboxInstance.create(payload, safe=safe) + base_instance = await SandboxInstance.create( + payload, + safe=safe, + create_if_not_exist=create_if_not_exist, + ) return cls( sandbox=base_instance.sandbox, force_url=base_instance.config.force_url, diff --git a/src/blaxel/core/sandbox/default/sandbox.py b/src/blaxel/core/sandbox/default/sandbox.py index 78ec28c9..a069c677 100644 --- a/src/blaxel/core/sandbox/default/sandbox.py +++ b/src/blaxel/core/sandbox/default/sandbox.py @@ -55,6 +55,35 @@ def __init__(self, message: str, status_code: int | None = None, code: str | Non logger = logging.getLogger(__name__) +NON_REUSABLE_SANDBOX_STATUSES = { + "FAILED", + "TERMINATED", + "TERMINATING", + "DELETING", + "DEACTIVATING", +} + + +def _is_sandbox_conflict(error: SandboxAPIError) -> bool: + return error.status_code == 409 or error.code in {409, "409", "SANDBOX_ALREADY_EXISTS"} + + +def _sandbox_name( + sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any]], +) -> str | None: + if isinstance(sandbox, SandboxCreateConfiguration): + return sandbox.name + if isinstance(sandbox, dict): + if "name" in sandbox: + return sandbox["name"] + metadata = sandbox.get("metadata") + if isinstance(metadata, dict): + return metadata.get("name") + return getattr(metadata, "name", None) + if isinstance(sandbox, Sandbox): + return sandbox.metadata.name if sandbox.metadata else None + return None + class _AsyncDeleteDescriptor: """Descriptor that provides both class-level and instance-level delete functionality.""" @@ -155,6 +184,7 @@ async def create( cls, sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any], None] = None, safe: bool = False, + create_if_not_exist: bool = False, ) -> "SandboxInstance": default_name = f"sandbox-{uuid.uuid4().hex[:8]}" default_image = "blaxel/base-image:latest" @@ -285,6 +315,7 @@ async def create( response = await create_sandbox( client=client, body=sandbox, + create_if_not_exist=create_if_not_exist, ) # Check if response is an error @@ -451,40 +482,23 @@ async def create_if_not_exists( cls, sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any]] ) -> "SandboxInstance": """Create a sandbox if it doesn't exist, otherwise return existing.""" - try: - return await cls.create(sandbox) - except SandboxAPIError as e: - # Check if it's a 409 conflict error (sandbox already exists) - if e.status_code == 409 or e.code in [409, "SANDBOX_ALREADY_EXISTS"]: - # Extract name from different configuration types - if isinstance(sandbox, SandboxCreateConfiguration): - name = sandbox.name - elif isinstance(sandbox, dict): - if "name" in sandbox: - name = sandbox["name"] - elif "metadata" in sandbox and isinstance(sandbox["metadata"], dict): - name = sandbox["metadata"].get("name") - else: - name = None - elif isinstance(sandbox, Sandbox): - name = sandbox.metadata.name if sandbox.metadata else None - else: - name = None + attempts = 3 + for _ in range(attempts): + try: + return await cls.create(sandbox, create_if_not_exist=True) + except SandboxAPIError as e: + if not _is_sandbox_conflict(e): + raise + name = _sandbox_name(sandbox) if not name: raise ValueError("Sandbox name is required") - # Get the existing sandbox to check its status sandbox_instance = await cls.get(name) + if str(sandbox_instance.status) not in NON_REUSABLE_SANDBOX_STATUSES: + return sandbox_instance - # If the sandbox is TERMINATED, treat it as not existing - if sandbox_instance.status == "TERMINATED": - # Create a new sandbox - backend will handle cleanup of the terminated one - return await cls.create(sandbox) - - # Otherwise return the existing active sandbox - return sandbox_instance - raise + raise RuntimeError(f"Unable to create sandbox after {attempts} attempts.") @classmethod async def from_session( diff --git a/src/blaxel/core/sandbox/sync/interpreter.py b/src/blaxel/core/sandbox/sync/interpreter.py index 6080d0d6..ff9dbd64 100644 --- a/src/blaxel/core/sandbox/sync/interpreter.py +++ b/src/blaxel/core/sandbox/sync/interpreter.py @@ -29,6 +29,7 @@ def create( cls, sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any], None] = None, safe: bool = True, + create_if_not_exist: bool = False, ) -> "SyncCodeInterpreter": """ Create a sandbox instance using the jupyter-server image. @@ -72,7 +73,11 @@ def create( if sandbox.spec and getattr(sandbox.spec, "region", None): payload["region"] = sandbox.spec.region - base_instance = SyncSandboxInstance.create(payload, safe=safe) + base_instance = SyncSandboxInstance.create( + payload, + safe=safe, + create_if_not_exist=create_if_not_exist, + ) return cls( sandbox=base_instance.sandbox, force_url=base_instance.config.force_url, diff --git a/src/blaxel/core/sandbox/sync/sandbox.py b/src/blaxel/core/sandbox/sync/sandbox.py index 65280ce7..9a76f503 100644 --- a/src/blaxel/core/sandbox/sync/sandbox.py +++ b/src/blaxel/core/sandbox/sync/sandbox.py @@ -27,7 +27,12 @@ from ...client.models.sandbox_error import SandboxError from ...client.types import UNSET from ...common.settings import settings -from ..default.sandbox import SandboxAPIError +from ..default.sandbox import ( + NON_REUSABLE_SANDBOX_STATUSES, + SandboxAPIError, + _is_sandbox_conflict, + _sandbox_name, +) from ..types import ( SandboxConfiguration, SandboxCreateConfiguration, @@ -138,6 +143,7 @@ def create( cls, sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any], None] = None, safe: bool = False, + create_if_not_exist: bool = False, ) -> "SyncSandboxInstance": default_name = f"sandbox-{uuid.uuid4().hex[:8]}" default_image = "blaxel/base-image:latest" @@ -251,6 +257,7 @@ def create( response = create_sandbox( client=client, body=sandbox, + create_if_not_exist=create_if_not_exist, ) # Check if response is an error @@ -389,30 +396,23 @@ def update_lifecycle( def create_if_not_exists( cls, sandbox: Union[Sandbox, SandboxCreateConfiguration, Dict[str, Any]] ) -> "SyncSandboxInstance": - try: - return cls.create(sandbox) - except SandboxAPIError as e: - if e.status_code == 409 or e.code in [409, "SANDBOX_ALREADY_EXISTS"]: - if isinstance(sandbox, SandboxCreateConfiguration): - name = sandbox.name - elif isinstance(sandbox, dict): - if "name" in sandbox: - name = sandbox["name"] - elif "metadata" in sandbox and isinstance(sandbox["metadata"], dict): - name = sandbox["metadata"].get("name") - else: - name = None - elif isinstance(sandbox, Sandbox): - name = sandbox.metadata.name if sandbox.metadata else None - else: - name = None + attempts = 3 + for _ in range(attempts): + try: + return cls.create(sandbox, create_if_not_exist=True) + except SandboxAPIError as e: + if not _is_sandbox_conflict(e): + raise + + name = _sandbox_name(sandbox) if not name: raise ValueError("Sandbox name is required") + sandbox_instance = cls.get(name) - if sandbox_instance.status == "TERMINATED": - return cls.create(sandbox) - return sandbox_instance - raise + if str(sandbox_instance.status) not in NON_REUSABLE_SANDBOX_STATUSES: + return sandbox_instance + + raise RuntimeError(f"Unable to create sandbox after {attempts} attempts.") @classmethod def from_session( diff --git a/tests/core/test_sandbox.py b/tests/core/test_sandbox.py index e5890799..d531aeb3 100644 --- a/tests/core/test_sandbox.py +++ b/tests/core/test_sandbox.py @@ -1,16 +1,38 @@ """Tests for sandbox functionality.""" import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest from blaxel.core.client.models import Metadata, Sandbox, SandboxSpec -from blaxel.core.sandbox import SandboxInstance +from blaxel.core.sandbox import ( + CodeInterpreter, + SandboxAPIError, + SandboxInstance, + SyncCodeInterpreter, + SyncSandboxInstance, +) from blaxel.core.sandbox.default.action import SandboxAction from blaxel.core.sandbox.types import ResponseError, SandboxConfiguration +def sandbox_instance(name: str, status: str = "DEPLOYED", cls=SandboxInstance): + sandbox_data = Sandbox(metadata=Metadata(name=name), spec=SandboxSpec()) + sandbox_data.status = status + return cls(sandbox_data) + + +def conflict_error() -> SandboxAPIError: + return SandboxAPIError("already exists", status_code=409) + + +def conflict_error_with_code(code) -> SandboxAPIError: + error = SandboxAPIError("already exists", code=code) + error.code = code + return error + + @pytest.mark.asyncio async def test_sandbox_creation(): """Test sandbox instance creation.""" @@ -154,3 +176,332 @@ async def test_sandbox_class_methods(): assert hasattr(SandboxInstance, "list") assert hasattr(SandboxInstance, "delete") assert hasattr(SandboxInstance, "wait") + + +@pytest.mark.asyncio +async def test_create_if_not_exists_uses_server_side_param(): + existing = sandbox_instance("existing") + + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.return_value = existing + + result = await SandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_create.assert_awaited_once_with({"name": "existing"}, create_if_not_exist=True) + mock_get.assert_not_called() + + +@pytest.mark.asyncio +async def test_create_forwards_create_if_not_exist_to_generated_client(): + created = sandbox_instance("created").sandbox + + with patch( + "blaxel.core.sandbox.default.sandbox.create_sandbox", + new_callable=AsyncMock, + ) as mock_create_sandbox: + mock_create_sandbox.return_value = created + + result = await SandboxInstance.create( + {"name": "created", "region": "us-pdx-1"}, + create_if_not_exist=True, + ) + + assert result.metadata.name == "created" + assert mock_create_sandbox.await_args.kwargs["create_if_not_exist"] is True + + +@pytest.mark.asyncio +async def test_create_if_not_exists_returns_existing_after_conflict(): + existing = sandbox_instance("existing") + + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.side_effect = [conflict_error()] + mock_get.return_value = existing + + result = await SandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_create.assert_awaited_once_with({"name": "existing"}, create_if_not_exist=True) + mock_get.assert_awaited_once_with("existing") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("code", ["SANDBOX_ALREADY_EXISTS", "409", 409]) +async def test_create_if_not_exists_accepts_conflict_error_codes(code): + existing = sandbox_instance("existing") + + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.side_effect = [conflict_error_with_code(code)] + mock_get.return_value = existing + + result = await SandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_get.assert_awaited_once_with("existing") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "status", + ["FAILED", "TERMINATED", "TERMINATING", "DELETING", "DEACTIVATING"], +) +async def test_create_if_not_exists_retries_for_non_reusable_statuses(status): + replacement = sandbox_instance("stale") + + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.side_effect = [conflict_error(), replacement] + mock_get.return_value = sandbox_instance("stale", status) + + result = await SandboxInstance.create_if_not_exists({"name": "stale"}) + + assert result is replacement + assert mock_create.await_args_list == [ + call({"name": "stale"}, create_if_not_exist=True), + call({"name": "stale"}, create_if_not_exist=True), + ] + + +@pytest.mark.asyncio +async def test_create_if_not_exists_handles_recreate_race_after_terminal_status(): + winner = sandbox_instance("race") + + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.side_effect = [conflict_error(), conflict_error()] + mock_get.side_effect = [sandbox_instance("race", "TERMINATED"), winner] + + result = await SandboxInstance.create_if_not_exists({"name": "race"}) + + assert result is winner + assert mock_create.await_count == 2 + assert mock_get.await_args_list == [call("race"), call("race")] + + +@pytest.mark.asyncio +async def test_create_if_not_exists_stops_after_bounded_attempts(): + with ( + patch.object(SandboxInstance, "create", new_callable=AsyncMock) as mock_create, + patch.object(SandboxInstance, "get", new_callable=AsyncMock) as mock_get, + ): + mock_create.side_effect = conflict_error() + mock_get.return_value = sandbox_instance("stuck", "TERMINATED") + + with pytest.raises(RuntimeError, match="Unable to create sandbox after 3 attempts"): + await SandboxInstance.create_if_not_exists({"name": "stuck"}) + + assert mock_create.await_count == 3 + assert mock_get.await_count == 3 + + +@pytest.mark.asyncio +async def test_code_interpreter_create_forwards_create_if_not_exist(): + with patch( + "blaxel.core.sandbox.default.interpreter.SandboxInstance.create", + new_callable=AsyncMock, + ) as mock_create: + mock_create.return_value = sandbox_instance("interpreter") + + result = await CodeInterpreter.create( + {"name": "interpreter"}, + safe=False, + create_if_not_exist=True, + ) + + assert isinstance(result, CodeInterpreter) + payload = mock_create.await_args.args[0] + assert payload["name"] == "interpreter" + assert mock_create.await_args.kwargs == { + "safe": False, + "create_if_not_exist": True, + } + + +@pytest.mark.asyncio +async def test_code_interpreter_create_if_not_exists_uses_server_side_param(): + with patch( + "blaxel.core.sandbox.default.interpreter.SandboxInstance.create", + new_callable=AsyncMock, + ) as mock_create: + mock_create.return_value = sandbox_instance("interpreter-existing") + + result = await CodeInterpreter.create_if_not_exists({"name": "interpreter-existing"}) + + assert isinstance(result, CodeInterpreter) + payload = mock_create.await_args.args[0] + assert payload["name"] == "interpreter-existing" + assert mock_create.await_args.kwargs == { + "safe": True, + "create_if_not_exist": True, + } + + +def test_sync_create_if_not_exists_uses_server_side_param(): + existing = sandbox_instance("existing", cls=SyncSandboxInstance) + + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.return_value = existing + + result = SyncSandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_create.assert_called_once_with({"name": "existing"}, create_if_not_exist=True) + mock_get.assert_not_called() + + +def test_sync_create_forwards_create_if_not_exist_to_generated_client(): + created = sandbox_instance("created", cls=SyncSandboxInstance).sandbox + + with patch("blaxel.core.sandbox.sync.sandbox.create_sandbox") as mock_create_sandbox: + mock_create_sandbox.return_value = created + + result = SyncSandboxInstance.create( + {"name": "created", "region": "us-pdx-1"}, + create_if_not_exist=True, + ) + + assert result.metadata.name == "created" + assert mock_create_sandbox.call_args.kwargs["create_if_not_exist"] is True + + +def test_sync_create_if_not_exists_returns_existing_after_conflict(): + existing = sandbox_instance("existing", cls=SyncSandboxInstance) + + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.side_effect = [conflict_error()] + mock_get.return_value = existing + + result = SyncSandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_create.assert_called_once_with({"name": "existing"}, create_if_not_exist=True) + mock_get.assert_called_once_with("existing") + + +@pytest.mark.parametrize("code", ["SANDBOX_ALREADY_EXISTS", "409", 409]) +def test_sync_create_if_not_exists_accepts_conflict_error_codes(code): + existing = sandbox_instance("existing", cls=SyncSandboxInstance) + + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.side_effect = [conflict_error_with_code(code)] + mock_get.return_value = existing + + result = SyncSandboxInstance.create_if_not_exists({"name": "existing"}) + + assert result is existing + mock_get.assert_called_once_with("existing") + + +@pytest.mark.parametrize( + "status", + ["FAILED", "TERMINATED", "TERMINATING", "DELETING", "DEACTIVATING"], +) +def test_sync_create_if_not_exists_retries_for_non_reusable_statuses(status): + replacement = sandbox_instance("stale", cls=SyncSandboxInstance) + + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.side_effect = [conflict_error(), replacement] + mock_get.return_value = sandbox_instance("stale", status, cls=SyncSandboxInstance) + + result = SyncSandboxInstance.create_if_not_exists({"name": "stale"}) + + assert result is replacement + assert mock_create.call_args_list == [ + call({"name": "stale"}, create_if_not_exist=True), + call({"name": "stale"}, create_if_not_exist=True), + ] + + +def test_sync_create_if_not_exists_handles_recreate_race_after_terminal_status(): + winner = sandbox_instance("race", cls=SyncSandboxInstance) + + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.side_effect = [conflict_error(), conflict_error()] + mock_get.side_effect = [ + sandbox_instance("race", "TERMINATED", cls=SyncSandboxInstance), + winner, + ] + + result = SyncSandboxInstance.create_if_not_exists({"name": "race"}) + + assert result is winner + assert mock_create.call_count == 2 + assert mock_get.call_args_list == [call("race"), call("race")] + + +def test_sync_create_if_not_exists_stops_after_bounded_attempts(): + with ( + patch.object(SyncSandboxInstance, "create") as mock_create, + patch.object(SyncSandboxInstance, "get") as mock_get, + ): + mock_create.side_effect = conflict_error() + mock_get.return_value = sandbox_instance("stuck", "TERMINATED", cls=SyncSandboxInstance) + + with pytest.raises(RuntimeError, match="Unable to create sandbox after 3 attempts"): + SyncSandboxInstance.create_if_not_exists({"name": "stuck"}) + + assert mock_create.call_count == 3 + assert mock_get.call_count == 3 + + +def test_sync_code_interpreter_create_forwards_create_if_not_exist(): + with patch("blaxel.core.sandbox.sync.interpreter.SyncSandboxInstance.create") as mock_create: + mock_create.return_value = sandbox_instance("interpreter", cls=SyncSandboxInstance) + + result = SyncCodeInterpreter.create( + {"name": "interpreter"}, + safe=False, + create_if_not_exist=True, + ) + + assert isinstance(result, SyncCodeInterpreter) + payload = mock_create.call_args.args[0] + assert payload["name"] == "interpreter" + assert mock_create.call_args.kwargs == { + "safe": False, + "create_if_not_exist": True, + } + + +def test_sync_code_interpreter_create_if_not_exists_uses_server_side_param(): + with patch("blaxel.core.sandbox.sync.interpreter.SyncSandboxInstance.create") as mock_create: + mock_create.return_value = sandbox_instance("interpreter-existing", cls=SyncSandboxInstance) + + result = SyncCodeInterpreter.create_if_not_exists({"name": "interpreter-existing"}) + + assert isinstance(result, SyncCodeInterpreter) + payload = mock_create.call_args.args[0] + assert payload["name"] == "interpreter-existing" + assert mock_create.call_args.kwargs == { + "safe": True, + "create_if_not_exist": True, + }