Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions src/blaxel/core/sandbox/default/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Union

Expand Down Expand Up @@ -81,15 +82,19 @@ async def write_binary(
"""
path = self.format_path(path)

# If content is a string, treat it as a file path and read it
# If content is a string, treat it as a file path
if isinstance(content, str):
local_path = Path(content)
file_size = local_path.stat().st_size
# Stream from disk for large files to avoid loading into memory
if file_size > MULTIPART_THRESHOLD:
return await self._upload_file_with_multipart(path, local_path, "0644")
content = local_path.read_bytes()
# Convert bytearray to bytes if necessary
elif isinstance(content, bytearray):
content = bytes(content)

# Use multipart upload for large files
# Use multipart upload for large in-memory data
if len(content) > MULTIPART_THRESHOLD:
return await self._upload_with_multipart(path, content, "0644")

Expand Down Expand Up @@ -535,6 +540,64 @@ async def _abort_multipart_upload(self, upload_id: str) -> None:
finally:
await response.aclose()

async def _upload_file_with_multipart(
    self, path: str, local_path: Path, permissions: str = "0644"
) -> SuccessResponse:
    """Upload a local file using streaming multipart upload.

    The file is read from disk one batch of chunks at a time instead of
    being loaded whole, so roughly MAX_PARALLEL_UPLOADS * CHUNK_SIZE bytes
    of file data are held in memory at once.

    Args:
        path: Destination path passed to the multipart initiate call.
        local_path: Local file to read and upload.
        permissions: Permission string passed to the multipart initiate call.

    Returns:
        The result of completing the multipart upload.

    Raises:
        Exception: If no upload ID is returned, if the file turns out to be
            shorter than its size at the start of the upload, or if any
            part upload fails. The upload is aborted (best effort) before
            the error is re-raised.
    """
    file_size = local_path.stat().st_size

    init_response = await self._initiate_multipart_upload(path, permissions)
    upload_id = init_response.get("uploadId")

    if not upload_id:
        raise Exception("Failed to get upload ID from initiate response")

    try:
        num_parts = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE
        parts: List[Dict[str, Any]] = []

        # A regular binary file object is used rather than os.open/os.pread:
        # os.pread only exists on Unix, and os.open without O_BINARY is
        # text-mode on Windows. Parts are read in ascending order, so
        # sequential read() calls return the same bytes pread would.
        with local_path.open("rb") as source:
            for batch_start in range(0, num_parts, MAX_PARALLEL_UPLOADS):
                batch_end = min(batch_start + MAX_PARALLEL_UPLOADS, num_parts)

                # Read the whole batch before creating any upload coroutine,
                # so a read failure never leaves un-awaited coroutines behind.
                batch: List[tuple] = []
                for part_index in range(batch_start, batch_end):
                    offset = part_index * CHUNK_SIZE
                    read_size = min(CHUNK_SIZE, file_size - offset)
                    chunk = source.read(read_size)
                    if len(chunk) != read_size:
                        # The file shrank after stat(); uploading a short
                        # part would silently corrupt the remote file.
                        raise Exception(
                            f"Short read at offset {offset}: expected "
                            f"{read_size} bytes, got {len(chunk)}"
                        )
                    batch.append((part_index + 1, chunk))

                batch_results = await asyncio.gather(
                    *(
                        self._upload_part(upload_id, part_number, chunk)
                        for part_number, chunk in batch
                    )
                )
                parts.extend(
                    {
                        "partNumber": r.get("partNumber"),
                        "etag": r.get("etag"),
                    }
                    for r in batch_results
                )

        parts.sort(key=lambda p: p.get("partNumber", 0))
        return await self._complete_multipart_upload(upload_id, parts)
    except Exception:
        # Best-effort cleanup: an abort failure is logged, never allowed to
        # mask the original error.
        try:
            await self._abort_multipart_upload(upload_id)
        except Exception as abort_error:
            logger.warning(f"Failed to abort multipart upload: {abort_error}")
        raise

async def _upload_with_multipart(
self, path: str, data: bytes, permissions: str = "0644"
) -> SuccessResponse:
Expand Down
64 changes: 64 additions & 0 deletions src/blaxel/core/sandbox/sync/filesystem.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import json
import logging
import os
import threading
from pathlib import Path
from typing import Any, Callable, Dict, List, Union
Expand Down Expand Up @@ -55,6 +56,10 @@ def write_binary(self, path: str, content: Union[bytes, bytearray, str]) -> Succ
path = self.format_path(path)
if isinstance(content, str):
local_path = Path(content)
file_size = local_path.stat().st_size
# Stream from disk for large files to avoid loading into memory
if file_size > MULTIPART_THRESHOLD:
return self._upload_file_with_multipart(path, local_path, "0644")
content = local_path.read_bytes()
elif isinstance(content, bytearray):
content = bytes(content)
Expand Down Expand Up @@ -287,6 +292,65 @@ def _abort_multipart_upload(self, upload_id: str) -> None:
if not response.is_success:
logger.warning(f"Failed to abort multipart upload: {response.status_code}")

def _upload_file_with_multipart(
    self, path: str, local_path: Path, permissions: str = "0644"
) -> SuccessResponse:
    """Upload a local file using streaming multipart upload.

    The file is read from disk one batch of chunks at a time instead of
    being loaded whole; each batch is uploaded by up to
    MAX_PARALLEL_UPLOADS threads, which are all joined before the next
    batch is read.

    Args:
        path: Destination path passed to the multipart initiate call.
        local_path: Local file to read and upload.
        permissions: Permission string passed to the multipart initiate call.

    Returns:
        The result of completing the multipart upload.

    Raises:
        Exception: If no upload ID is returned, if the file turns out to be
            shorter than its size at the start of the upload, or if any
            part upload fails (the first recorded failure of the batch is
            raised). The upload is aborted (best effort) before the error
            is re-raised.
    """
    file_size = local_path.stat().st_size

    init_response = self._initiate_multipart_upload(path, permissions)
    upload_id = init_response.get("uploadId")
    if not upload_id:
        raise Exception("Failed to get upload ID from initiate response")

    try:
        num_parts = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE
        parts: List[Dict[str, Any]] = []

        # A regular binary file object is used rather than os.open/os.pread:
        # os.pread only exists on Unix, and os.open without O_BINARY is
        # text-mode on Windows. Parts are read in ascending order, so
        # sequential read() calls return the same bytes pread would.
        with local_path.open("rb") as source:
            for batch_start in range(0, num_parts, MAX_PARALLEL_UPLOADS):
                batch_end = min(batch_start + MAX_PARALLEL_UPLOADS, num_parts)
                results: Dict[int, Dict[str, Any]] = {}
                exceptions: List[Exception] = []

                def upload_one(part_number: int, chunk: bytes) -> None:
                    # Runs in a worker thread; failures are collected so the
                    # main thread can raise after every thread has joined.
                    try:
                        results[part_number] = self._upload_part(
                            upload_id, part_number, chunk
                        )
                    except Exception as e:
                        exceptions.append(e)

                # Read the whole batch before starting any thread, so a read
                # failure never leaves un-joined uploads running.
                batch: List[tuple] = []
                for part_index in range(batch_start, batch_end):
                    offset = part_index * CHUNK_SIZE
                    read_size = min(CHUNK_SIZE, file_size - offset)
                    chunk = source.read(read_size)
                    if len(chunk) != read_size:
                        # The file shrank after stat(); uploading a short
                        # part would silently corrupt the remote file.
                        raise Exception(
                            f"Short read at offset {offset}: expected "
                            f"{read_size} bytes, got {len(chunk)}"
                        )
                    batch.append((part_index + 1, chunk))

                threads = [
                    threading.Thread(target=upload_one, args=item) for item in batch
                ]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

                if exceptions:
                    raise exceptions[0]
                parts.extend(
                    {"partNumber": part_number, "etag": r.get("etag")}
                    for part_number, r in results.items()
                )

        parts.sort(key=lambda p: p.get("partNumber", 0))
        return self._complete_multipart_upload(upload_id, parts)
    except Exception:
        # Best-effort cleanup: an abort failure is logged, never allowed to
        # mask the original error.
        try:
            self._abort_multipart_upload(upload_id)
        except Exception as abort_error:
            logger.warning(f"Failed to abort multipart upload: {abort_error}")
        raise

def _upload_with_multipart(
self, path: str, data: bytes, permissions: str = "0644"
) -> SuccessResponse:
Expand Down
Loading