diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index b263057..eaf3af6 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -302,9 +302,18 @@ def _upload_with_multipart( for i in range(0, num_parts, MAX_PARALLEL_UPLOADS): threads = [] results: Dict[int, Dict[str, Any]] = {} + exceptions: List[Exception] = [] + results_lock = threading.Lock() def make_upload(part_number: int, chunk: bytes): - results[part_number] = self._upload_part(upload_id, part_number, chunk) + try: + result = self._upload_part(upload_id, part_number, chunk) + except Exception as error: + with results_lock: + exceptions.append(error) + else: + with results_lock: + results[part_number] = result for j in range(MAX_PARALLEL_UPLOADS): if i + j >= num_parts: @@ -318,6 +327,8 @@ def make_upload(part_number: int, chunk: bytes): t.start() for t in threads: t.join() + if exceptions: + raise exceptions[0] for part_number, r in results.items(): parts.append({"partNumber": part_number, "etag": r.get("etag")}) parts.sort(key=lambda p: p.get("partNumber", 0)) diff --git a/tests/core/test_sandbox_filesystem.py b/tests/core/test_sandbox_filesystem.py new file mode 100644 index 0000000..5793683 --- /dev/null +++ b/tests/core/test_sandbox_filesystem.py @@ -0,0 +1,31 @@ +import pytest + +from blaxel.core.sandbox.sync.filesystem import SyncSandboxFileSystem + + +def test_sync_multipart_upload_aborts_when_part_thread_fails(): + filesystem = object.__new__(SyncSandboxFileSystem) + uploaded_parts = [] + aborted_uploads = [] + completed_parts = [] + + filesystem._initiate_multipart_upload = lambda path, permissions="0644": { + "uploadId": "upload-1" + } + + def upload_part(upload_id, part_number, data): + uploaded_parts.append(part_number) + if part_number == 2: + raise RuntimeError("part 2 failed") + return {"partNumber": part_number, "etag": f"etag-{part_number}"} + + filesystem._upload_part = upload_part + filesystem._abort_multipart_upload = lambda upload_id: aborted_uploads.append(upload_id) + filesystem._complete_multipart_upload = lambda upload_id, parts: completed_parts.append(parts) + + with pytest.raises(RuntimeError, match="part 2 failed"): + filesystem._upload_with_multipart("/tmp/large.bin", b"0" * (11 * 1024 * 1024)) + + assert 2 in uploaded_parts + assert aborted_uploads == ["upload-1"] + assert completed_parts == []