
Commit dc7dd7d

feat(s3,utils): use niquests on_upload hook and optimize stream chunking (#49)
1 parent 8463e30 commit dc7dd7d

5 files changed

Lines changed: 179 additions & 41 deletions


tests/s3/test_niquests.py

Lines changed: 7 additions & 6 deletions

```diff
@@ -256,11 +256,12 @@ async def data_stream():
         for i in range(0, len(test_data), chunk_size):
             yield test_data[i : i + chunk_size]
 
-    received_size = 0
+    upload_progress_called = False
 
-    def on_chunk(chunk: bytes):
-        nonlocal received_size
-        received_size += len(chunk)
+    def on_upload(req: niquests.PreparedRequest):
+        nonlocal upload_progress_called
+        if req.upload_progress is not None:
+            upload_progress_called = True
 
     await s3_file_upload(
         s3,
@@ -269,11 +270,11 @@ def on_chunk(chunk: bytes):
         key,
         data_stream(),
         min_part_size=5 * 1024 * 1024,
-        on_chunk_received=on_chunk,
+        on_upload=on_upload,
         content_length=content_length,
     )
 
-    assert received_size == data_size
+    assert upload_progress_called
     result = await s3_get_object(s3, client, s3_bucket, key)
     assert result == test_data
```
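
Beyond the boolean check above, the hook can drive real progress reporting. A minimal sketch of such a callback; the `total` and `content_length` attributes on `upload_progress` are assumptions based on niquests' `TransferProgress`, since only `upload_progress` itself appears in this diff:

```python
import niquests


def make_progress_hook(label: str):
    """Build an on_upload hook that logs byte-level transfer progress."""

    def on_upload(req: niquests.PreparedRequest) -> None:
        progress = req.upload_progress
        if progress is None:
            return
        # Assumed TransferProgress attributes: `total` is bytes sent so far,
        # `content_length` is the expected body size when known.
        if progress.content_length:
            pct = 100 * progress.total / progress.content_length
            print(f"{label}: {progress.total}/{progress.content_length} bytes ({pct:.1f}%)")
        else:
            print(f"{label}: {progress.total} bytes sent")

    return on_upload
```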

tests/test_utils.py

Lines changed: 46 additions & 0 deletions

```diff
@@ -176,3 +176,49 @@ async def async_data():
         assert extra_check(chunks)
 
     asyncio.run(_test())
+
+
+@pytest.mark.parametrize(
+    ("input_chunks", "min_size", "expected_total", "extra_check"),
+    [
+        pytest.param(
+            ["hello", "world", "12345"],
+            5,
+            15,
+            lambda chunks: all(len(c) >= 5 for c in chunks),
+            id="exact_chunks",
+        ),
+        pytest.param(
+            ["small"],
+            100,
+            5,
+            lambda chunks: chunks == ["small"],
+            id="single_small_chunk",
+        ),
+        pytest.param(
+            ["hello", "", "world"],
+            5,
+            10,
+            None,
+            id="empty_chunks_ignored",
+        ),
+    ],
+)
+def test_get_stream_chunk_str(input_chunks, min_size, expected_total, extra_check):
+    from tracktolib.utils import get_stream_chunk_str
+
+    async def _test():
+        async def async_data():
+            for chunk in input_chunks:
+                yield chunk
+
+        chunks = []
+        async for chunk in get_stream_chunk_str(async_data(), min_size=min_size):
+            chunks.append(chunk)
+
+        total_size = sum(len(c) for c in chunks)
+        assert total_size == expected_total
+        if extra_check:
+            assert extra_check(chunks)
+
+    asyncio.run(_test())
```
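
For a concrete read on what these cases assert, the same function can be run standalone; this trace follows the implementation added in tracktolib/utils.py below (chunks of at least `min_size` are emitted once twice that much is buffered, and the tail is flushed at the end):

```python
import asyncio

from tracktolib.utils import get_stream_chunk_str


async def main() -> None:
    async def words():
        for part in ("hel", "lo", "wor", "ld", "!"):
            yield part

    # With min_size=4, nothing is emitted until 8 chars are buffered;
    # then exactly 4 chars are yielded, and the remainder at the end.
    chunks = [c async for c in get_stream_chunk_str(words(), min_size=4)]
    assert chunks == ["hell", "oworld!"]


asyncio.run(main())
```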

tracktolib/s3/niquests.py

Lines changed: 35 additions & 24 deletions

```diff
@@ -31,6 +31,7 @@
 from ..utils import get_stream_chunk
 
 __all__ = (
+    "OnUpload",
     "S3MultipartUpload",
     "S3Object",
     "S3ObjectParams",
@@ -200,10 +201,10 @@ class S3Session:
     ...
     """
 
-    endpoint_url: str
-    access_key: str
-    secret_key: str
-    region: str
+    endpoint_url: str | None = None
+    access_key: str | None = None
+    secret_key: str | None = None
+    region: str | None = None
     s3_config: Config | None = None
     s3_client: botocore.client.BaseClient | None = None
     http_client: niquests.AsyncSession = field(default_factory=niquests.AsyncSession)
@@ -212,7 +213,8 @@ class S3Session:
     def __post_init__(self):
         if self.s3_client is None:
             self._botocore_session = botocore.session.Session()
-            self._botocore_session.set_credentials(self.access_key, self.secret_key)
+            if self.access_key is not None and self.secret_key is not None:
+                self._botocore_session.set_credentials(self.access_key, self.secret_key)
             self.s3_client = self._botocore_session.create_client(
                 "s3",
                 endpoint_url=self.endpoint_url,
@@ -307,7 +309,7 @@ async def file_upload(
         data: AsyncIterator[bytes],
         *,
         min_part_size: int = 5 * 1024 * 1024,
-        on_chunk_received: Callable[[bytes], None] | None = None,
+        on_upload: OnUpload | None = None,
         content_length: int | None = None,
         **kwargs: Unpack[S3ObjectParams],
     ) -> None:
@@ -319,7 +321,7 @@ async def file_upload(
             key,
             data,
             min_part_size=min_part_size,
-            on_chunk_received=on_chunk_received,
+            on_upload=on_upload,
             content_length=content_length,
             **kwargs,
         )
@@ -392,7 +394,7 @@ class UploadPart(TypedDict):
 class S3MultipartUpload(NamedTuple):
     fetch_create: Callable[[], Awaitable[str]]
     fetch_complete: Callable[[], Awaitable[niquests.Response]]
-    upload_part: Callable[[bytes], Awaitable[UploadPart]]
+    upload_part: Callable[[bytes | bytearray], Awaitable[UploadPart]]
     generate_presigned_url: Callable[..., str]
     fetch_abort: Callable[[], Awaitable[niquests.Response]]
 
@@ -508,12 +510,21 @@ async def s3_list_files(
         break
 
 
+type OnUpload = Callable[[niquests.PreparedRequest], None]
+
+
+def _upload_hooks(on_upload: OnUpload | None) -> dict | None:
+    return {"on_upload": [on_upload]} if on_upload else None
+
+
 async def s3_put_object(
     s3: botocore.client.BaseClient,
     client: niquests.AsyncSession,
     bucket: str,
     key: str,
-    data: bytes,
+    data: bytes | bytearray,
+    *,
+    on_upload: OnUpload | None = None,
     **kwargs: Unpack[S3ObjectParams],
 ) -> niquests.Response:
     """
@@ -529,7 +540,9 @@ async def s3_put_object(
         ClientMethod="put_object",
         Params=presigned_params,
     )
-    resp = (await client.put(url, data=data, headers=headers if headers else None)).raise_for_status()
+    resp = (
+        await client.put(url, data=data, headers=headers if headers else None, hooks=_upload_hooks(on_upload))
+    ).raise_for_status()
     return resp
 
 
@@ -638,6 +651,7 @@ async def s3_multipart_upload(
     key: str,
     *,
     expires_in: int = 3600,
+    on_upload: OnUpload | None = None,
     **kwargs: Unpack[S3ObjectParams],
 ) -> AsyncIterator[S3MultipartUpload]:
     """Async context manager for S3 multipart upload with automatic cleanup."""
@@ -670,12 +684,12 @@ async def fetch_abort():
         _has_been_aborted = True
         return abort_resp
 
-    async def upload_part(data: bytes) -> UploadPart:
+    async def upload_part(data: bytes | bytearray) -> UploadPart:
         nonlocal _part_number, _parts
         if upload_id is None:
             raise ValueError("Upload ID is not set")
         presigned_url = _generate_presigned_url("upload_part", UploadId=upload_id, PartNumber=_part_number)
-        upload_resp = (await client.put(presigned_url, data=data)).raise_for_status()
+        upload_resp = (await client.put(presigned_url, data=data, hooks=_upload_hooks(on_upload))).raise_for_status()
         _etag = upload_resp.headers.get("ETag")
         etag: str | None = _etag.decode() if isinstance(_etag, bytes) else _etag
         _part: UploadPart = {"PartNumber": _part_number, "ETag": etag}
@@ -723,38 +737,35 @@ async def s3_file_upload(
     *,
     # 5MB minimum for S3 parts
     min_part_size: int = 5 * 1024 * 1024,
-    on_chunk_received: Callable[[bytes], None] | None = None,
+    on_upload: OnUpload | None = None,
     content_length: int | None = None,
     **kwargs: Unpack[S3ObjectParams],
 ) -> None:
     """
     Upload a file to S3 from an async byte stream.
 
     Uses multipart upload for large files. If `content_length` is provided and smaller
-    than `min_part_size`, uses a single PUT instead. Use `on_chunk_received` callback
-    to track upload progress.
+    than `min_part_size`, uses a single PUT instead. The optional `on_upload` callback
+    receives a `niquests.PreparedRequest` with an `upload_progress` attribute for
+    fine-grained byte-level progress tracking.
     """
     if content_length is not None and content_length < min_part_size:
         # Small file - use single PUT operation
-        _data = b""
+        _data = bytearray()
         async for chunk in data:
-            _data += chunk
-            if on_chunk_received:
-                on_chunk_received(chunk)
-        await s3_put_object(s3, client, bucket=bucket, key=key, data=_data, **kwargs)
+            _data.extend(chunk)
+        await s3_put_object(s3, client, bucket=bucket, key=key, data=bytes(_data), on_upload=on_upload, **kwargs)
         return
 
-    async with s3_multipart_upload(s3, client, bucket=bucket, key=key, **kwargs) as mpart:
+    async with s3_multipart_upload(s3, client, bucket=bucket, key=key, on_upload=on_upload, **kwargs) as mpart:
         await mpart.fetch_create()
         has_uploaded_parts = False
         async for chunk in get_stream_chunk(data, min_size=min_part_size):
-            if on_chunk_received:
-                on_chunk_received(chunk)
             if len(chunk) < min_part_size:
                 if not has_uploaded_parts:
                     # No parts uploaded yet, abort multipart and use single PUT
                     await mpart.fetch_abort()
-                    await s3_put_object(s3, client, bucket=bucket, key=key, data=chunk, **kwargs)
+                    await s3_put_object(s3, client, bucket=bucket, key=key, data=chunk, on_upload=on_upload, **kwargs)
                 else:
                     # Parts already uploaded, upload final chunk as last part (S3 allows last part to be smaller)
                     await mpart.upload_part(chunk)
```
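
End to end, `on_upload` threads from `s3_file_upload` through `_upload_hooks` into every `client.put`. A hedged usage sketch: the bucket, key, region, and client construction here are illustrative assumptions; only the call signature comes from this diff:

```python
import asyncio
from collections.abc import AsyncIterator

import botocore.session
import niquests

from tracktolib.s3.niquests import s3_file_upload


async def main() -> None:
    # Assumed setup, mirroring S3Session.__post_init__ above: botocore
    # presigns the requests, niquests performs the actual transfers.
    s3 = botocore.session.Session().create_client("s3", region_name="us-east-1")
    client = niquests.AsyncSession()

    def on_upload(req: niquests.PreparedRequest) -> None:
        if req.upload_progress is not None:
            print("upload progress reported")

    async def data() -> AsyncIterator[bytes]:
        yield b"x" * (6 * 1024 * 1024)  # above min_part_size: multipart path

    await s3_file_upload(
        s3,
        client,
        "my-bucket",  # hypothetical bucket
        "big-file.bin",  # hypothetical key
        data(),
        min_part_size=5 * 1024 * 1024,
        on_upload=on_upload,
        content_length=6 * 1024 * 1024,
    )


asyncio.run(main())
```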

tracktolib/utils.py

Lines changed: 90 additions & 10 deletions

```diff
@@ -1,6 +1,8 @@
 import asyncio
+import collections
 import datetime as dt
 import importlib.util
+import io
 import itertools
 import mmap
 import os
@@ -112,28 +114,106 @@ def get_chunks[T](it: Iterable[T], size: int, *, as_list: bool = True) -> Iterat
     yield d if not as_list else list(d)
 
 
-async def get_stream_chunk[S: (bytes, str)](data_stream: AsyncIterable[S], min_size: int) -> AsyncIterator[S]:
-    """Buffers an async stream and yields chunks of at least `min_size`."""
-    buffer: S | None = None
+async def get_stream_chunk_str(
+    data_stream: AsyncIterable[str],
+    min_size: int,
+) -> AsyncIterator[str]:
+    """Buffers an async string stream and yields chunks of at least `min_size`."""
+    buffer = ""
     buffer_size = 0
-
     async for chunk in data_stream:
         if not chunk:
             continue
-        buffer = chunk if buffer is None else buffer + chunk  # type: ignore[operator]
+        buffer += chunk
         buffer_size += len(chunk)
-
-        # Yield chunks of min_size while we have enough data for at least 2 chunks
         while buffer_size >= min_size * 2:
             yield buffer[:min_size]
             buffer = buffer[min_size:]
             buffer_size -= min_size
-
-    # Handle the final chunk(s)
-    if buffer is not None and buffer_size > 0:
+    if buffer_size > 0:
         yield buffer
 
 
+class BytesBuffer:
+    """Memory-efficient bytes buffer using a deque of chunks.
+
+    Appends are O(1) (no copy). Reads only copy at chunk boundaries via memoryview.
+    Adapted from urllib3's BytesQueueBuffer.
+    """
+
+    __slots__ = ("buffer", "_size")
+
+    def __init__(self) -> None:
+        self.buffer: collections.deque[bytes | memoryview[bytes]] = collections.deque()
+        self._size: int = 0
+
+    def __len__(self) -> int:
+        return self._size
+
+    def put(self, data: bytes) -> None:
+        self.buffer.append(data)
+        self._size += len(data)
+
+    def get(self, n: int) -> bytes:
+        if not self.buffer:
+            raise RuntimeError("buffer is empty")
+
+        # Fast path: first chunk is exactly the right size
+        if len(self.buffer[0]) == n and isinstance(self.buffer[0], bytes):
+            self._size -= n
+            return self.buffer.popleft()  # type: ignore[return-value]
+
+        fetched = 0
+        ret = io.BytesIO()
+        while fetched < n:
+            remaining = n - fetched
+            chunk = self.buffer.popleft()
+            chunk_length = len(chunk)
+            if remaining < chunk_length:
+                mv = memoryview(chunk)
+                ret.write(mv[:remaining])
+                self.buffer.appendleft(mv[remaining:])  # type: ignore[arg-type]
+                self._size -= remaining
+                break
+            ret.write(chunk)
+            self._size -= chunk_length
+            fetched += chunk_length
+            if not self.buffer:
+                break
+        return ret.getvalue()
+
+    def get_all(self) -> bytes:
+        buffer = self.buffer
+        if not buffer:
+            return b""
+        if len(buffer) == 1:
+            result = buffer.pop()
+            if isinstance(result, memoryview):
+                result = result.tobytes()
+        else:
+            ret = io.BytesIO()
+            ret.writelines(buffer.popleft() for _ in range(len(buffer)))
+            result = ret.getvalue()
+        self._size = 0
+        return result  # type: ignore[return-value]
+
+
+async def get_stream_chunk(
+    data_stream: AsyncIterable[bytes],
+    min_size: int,
+) -> AsyncIterator[bytes]:
+    """Buffers an async byte stream and yields chunks of at least `min_size`."""
+    buffer = BytesBuffer()
+    async for chunk in data_stream:
+        if not chunk:
+            continue
+        buffer.put(chunk)
+        while len(buffer) >= min_size * 2:
+            yield buffer.get(min_size)
+    if len(buffer) > 0:
+        yield buffer.get_all()
+
+
 def json_serial(obj):
     """JSON serializer for objects not serializable by default json code"""
     if isinstance(obj, (dt.datetime, dt.date)):
```
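
A quick standalone check of the buffer semantics (all names come from the diff above; the assertions follow directly from `get`'s fast path, the memoryview split, and the `min_size * 2` buffering rule):

```python
import asyncio

from tracktolib.utils import BytesBuffer, get_stream_chunk

buf = BytesBuffer()
buf.put(b"hello")  # appends are O(1): no copy on put()
buf.put(b"world")
assert len(buf) == 10

# Fast path: the first chunk is exactly n bytes, popped without copying.
assert buf.get(5) == b"hello"

# Split path: 3 bytes are copied out of b"world"; the b"ld" tail stays
# in the deque as a memoryview, so the remainder is not copied again.
assert buf.get(3) == b"wor"
assert buf.get_all() == b"ld"
assert len(buf) == 0


async def main() -> None:
    async def stream():
        for _ in range(4):
            yield b"x" * 3  # 12 bytes arriving in small pieces

    # Chunks of exactly min_size are emitted while at least 2 * min_size
    # bytes are buffered; the remainder is flushed via get_all() at the end.
    chunks = [c async for c in get_stream_chunk(stream(), min_size=4)]
    assert [len(c) for c in chunks] == [4, 4, 4]


asyncio.run(main())
```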

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
