Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/blaxel/core/sandbox/sync/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,18 @@ def _upload_with_multipart(
for i in range(0, num_parts, MAX_PARALLEL_UPLOADS):
threads = []
results: Dict[int, Dict[str, Any]] = {}
exceptions: List[Exception] = []
results_lock = threading.Lock()

def make_upload(part_number: int, chunk: bytes):
results[part_number] = self._upload_part(upload_id, part_number, chunk)
try:
result = self._upload_part(upload_id, part_number, chunk)
except Exception as error:
with results_lock:
exceptions.append(error)
else:
with results_lock:
results[part_number] = result

for j in range(MAX_PARALLEL_UPLOADS):
if i + j >= num_parts:
Expand All @@ -318,6 +327,8 @@ def make_upload(part_number: int, chunk: bytes):
t.start()
for t in threads:
t.join()
if exceptions:
raise exceptions[0]
for part_number, r in results.items():
parts.append({"partNumber": part_number, "etag": r.get("etag")})
parts.sort(key=lambda p: p.get("partNumber", 0))
Expand Down
31 changes: 31 additions & 0 deletions tests/core/test_sandbox_filesystem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from blaxel.core.sandbox.sync.filesystem import SyncSandboxFileSystem


def test_sync_multipart_upload_aborts_when_part_thread_fails():
filesystem = object.__new__(SyncSandboxFileSystem)
uploaded_parts = []
aborted_uploads = []
completed_parts = []

filesystem._initiate_multipart_upload = lambda path, permissions="0644": {
"uploadId": "upload-1"
}

def upload_part(upload_id, part_number, data):
uploaded_parts.append(part_number)
if part_number == 2:
raise RuntimeError("part 2 failed")
return {"partNumber": part_number, "etag": f"etag-{part_number}"}

filesystem._upload_part = upload_part
filesystem._abort_multipart_upload = lambda upload_id: aborted_uploads.append(upload_id)
filesystem._complete_multipart_upload = lambda upload_id, parts: completed_parts.append(parts)

with pytest.raises(RuntimeError, match="part 2 failed"):
filesystem._upload_with_multipart("/tmp/large.bin", b"0" * (11 * 1024 * 1024))

assert 2 in uploaded_parts
assert aborted_uploads == ["upload-1"]
assert completed_parts == []
Loading