diff --git a/src/blaxel/core/sandbox/default/filesystem.py b/src/blaxel/core/sandbox/default/filesystem.py index c5e8a1d..7d3293f 100644 --- a/src/blaxel/core/sandbox/default/filesystem.py +++ b/src/blaxel/core/sandbox/default/filesystem.py @@ -2,6 +2,7 @@ import io import json import logging +import os from pathlib import Path from typing import Any, Callable, Dict, List, Union @@ -81,15 +82,19 @@ async def write_binary( """ path = self.format_path(path) - # If content is a string, treat it as a file path and read it + # If content is a string, treat it as a file path if isinstance(content, str): local_path = Path(content) + file_size = local_path.stat().st_size + # Stream from disk for large files to avoid loading into memory + if file_size > MULTIPART_THRESHOLD: + return await self._upload_file_with_multipart(path, local_path, "0644") content = local_path.read_bytes() # Convert bytearray to bytes if necessary elif isinstance(content, bytearray): content = bytes(content) - # Use multipart upload for large files + # Use multipart upload for large in-memory data if len(content) > MULTIPART_THRESHOLD: return await self._upload_with_multipart(path, content, "0644") @@ -535,6 +540,64 @@ async def _abort_multipart_upload(self, upload_id: str) -> None: finally: await response.aclose() + async def _upload_file_with_multipart( + self, path: str, local_path: Path, permissions: str = "0644" + ) -> SuccessResponse: + """Upload a local file using streaming multipart upload. + + Reads chunks directly from disk without loading the entire file into memory. + At most MAX_PARALLEL_UPLOADS * CHUNK_SIZE bytes are held in memory at once. + """ + file_size = local_path.stat().st_size + + init_response = await self._initiate_multipart_upload(path, permissions) + upload_id = init_response.get("uploadId") + + if not upload_id: + raise Exception("Failed to get upload ID from initiate response") + + try: + num_parts = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE + parts: List[Dict[str, Any]] = [] + + fd = os.open(str(local_path), os.O_RDONLY) + try: + for i in range(0, num_parts, MAX_PARALLEL_UPLOADS): + batch_tasks = [] + + for j in range(MAX_PARALLEL_UPLOADS): + if i + j >= num_parts: + break + + part_number = i + j + 1 + offset = (part_number - 1) * CHUNK_SIZE + read_size = min(CHUNK_SIZE, file_size - offset) + chunk = os.pread(fd, read_size, offset) + + batch_tasks.append(self._upload_part(upload_id, part_number, chunk)) + + batch_results = await asyncio.gather(*batch_tasks) + parts.extend( + [ + { + "partNumber": r.get("partNumber"), + "etag": r.get("etag"), + } + for r in batch_results + ] + ) + finally: + os.close(fd) + + parts.sort(key=lambda p: p.get("partNumber", 0)) + return await self._complete_multipart_upload(upload_id, parts) + except Exception as error: + try: + await self._abort_multipart_upload(upload_id) + except Exception as abort_error: + logger.warning(f"Failed to abort multipart upload: {abort_error}") + raise error + async def _upload_with_multipart( self, path: str, data: bytes, permissions: str = "0644" ) -> SuccessResponse: diff --git a/src/blaxel/core/sandbox/sync/filesystem.py b/src/blaxel/core/sandbox/sync/filesystem.py index b263057..a92f25d 100644 --- a/src/blaxel/core/sandbox/sync/filesystem.py +++ b/src/blaxel/core/sandbox/sync/filesystem.py @@ -1,6 +1,7 @@ import io import json import logging +import os import threading from pathlib import Path from typing import Any, Callable, Dict, List, Union @@ -55,6 +56,10 @@ def write_binary(self, path: str, content: Union[bytes, bytearray, str]) -> Succ path = self.format_path(path) if isinstance(content, str): local_path = Path(content) + file_size = local_path.stat().st_size + # Stream from disk for large files to avoid loading into memory + if file_size > MULTIPART_THRESHOLD: + return self._upload_file_with_multipart(path, local_path, "0644") content = local_path.read_bytes() elif isinstance(content, bytearray): content = bytes(content) @@ -287,6 +292,65 @@ def _abort_multipart_upload(self, upload_id: str) -> None: if not response.is_success: logger.warning(f"Failed to abort multipart upload: {response.status_code}") + def _upload_file_with_multipart( + self, path: str, local_path: Path, permissions: str = "0644" + ) -> SuccessResponse: + """Upload a local file using streaming multipart upload. + + Reads chunks directly from disk without loading the entire file into memory. + """ + file_size = local_path.stat().st_size + + init_response = self._initiate_multipart_upload(path, permissions) + upload_id = init_response.get("uploadId") + if not upload_id: + raise Exception("Failed to get upload ID from initiate response") + + try: + num_parts = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE + parts: List[Dict[str, Any]] = [] + + fd = os.open(str(local_path), os.O_RDONLY) + try: + for i in range(0, num_parts, MAX_PARALLEL_UPLOADS): + threads = [] + results: Dict[int, Dict[str, Any]] = {} + exceptions: List[Exception] = [] + + def make_upload(part_number: int, chunk: bytes): + try: + results[part_number] = self._upload_part(upload_id, part_number, chunk) + except Exception as e: + exceptions.append(e) + + for j in range(MAX_PARALLEL_UPLOADS): + if i + j >= num_parts: + break + part_number = i + j + 1 + offset = (part_number - 1) * CHUNK_SIZE + read_size = min(CHUNK_SIZE, file_size - offset) + chunk = os.pread(fd, read_size, offset) + t = threading.Thread(target=make_upload, args=(part_number, chunk)) + threads.append(t) + t.start() + for t in threads: + t.join() + if exceptions: + raise exceptions[0] + for part_number, r in results.items(): + parts.append({"partNumber": part_number, "etag": r.get("etag")}) + finally: + os.close(fd) + + parts.sort(key=lambda p: p.get("partNumber", 0)) + return self._complete_multipart_upload(upload_id, parts) + except Exception as error: + try: + self._abort_multipart_upload(upload_id) + except Exception as abort_error: + logger.warning(f"Failed to abort multipart upload: {abort_error}") + raise error + def _upload_with_multipart( self, path: str, data: bytes, permissions: str = "0644" ) -> SuccessResponse: