From 5b5c177184fd0072d6101e027352b1d6d572fd3b Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 21 Apr 2026 19:23:30 +0800 Subject: [PATCH 1/5] support user-defined data parser callable Signed-off-by: 0oshowero0 --- tests/test_client.py | 4 +- tests/test_simple_storage_unit.py | 58 ++- transfer_queue/client.py | 17 +- transfer_queue/interface.py | 36 +- transfer_queue/storage/managers/base.py | 17 +- .../managers/simple_backend_manager.py | 12 +- transfer_queue/storage/simple_backend.py | 6 +- tutorial/basic.ipynb | 477 +++++++++++------- 8 files changed, 433 insertions(+), 194 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index a911511d..7ce8bcbf 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -418,7 +418,7 @@ def client_setup(mock_controller, mock_storage): client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations - async def mock_put_data(data, metadata): + async def mock_put_data(data, metadata, data_parser=None): pass # Just pretend to store the data async def mock_get_data(metadata): @@ -511,7 +511,7 @@ def test_single_controller_multiple_storages(): client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations - async def mock_put_data(data, metadata): + async def mock_put_data(data, metadata, data_parser=None): pass # Just pretend to store the data async def mock_get_data(metadata): diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index c553a1a7..a5fb507c 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address): self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout self.socket.connect(storage_put_get_address) - def send_put(self, client_id, global_indexes, field_data): + def send_put(self, client_id, global_indexes, 
field_data, data_parser=None): + body = {"global_indexes": global_indexes, "data": field_data} + if data_parser is not None: + body["data_parser"] = data_parser msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA, sender_id=f"mock_client_{client_id}", - body={"global_indexes": global_indexes, "data": field_data}, + body=body, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) @@ -434,3 +437,54 @@ def test_storage_unit_data_capacity_uses_active_keys(): assert len(storage._active_keys) == 2 storage.put_data({"f": [4]}, global_indexes=[3]) assert storage._active_keys == {0, 1, 3} + + +def test_storage_unit_data_parser(storage_setup): + """Test data_parser functionality in SimpleStorageUnit. + + Writes two columns: + - normal_data: regular tensors, should remain unchanged + - data_to_be_parsed: list of shape descriptors (list of ints) + + data_parser converts shape descriptors into random tensors of those shapes. + """ + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + def create_data_by_shape_parser(field_data): + if "data_to_be_parsed" in field_data: + shapes = field_data["data_to_be_parsed"] + field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes] + return field_data + + # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists + field_data = { + "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]], + } + global_indexes = [0, 1, 2] + + # Put with data_parser + response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}" + + # Get back + response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + + result = 
response.body["data"] + + # Verify normal_data is unchanged + torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0])) + torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0])) + torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0])) + + # Verify data_to_be_parsed shapes match the input shape descriptors + expected_shapes = [(2, 3), (1, 4), (3, 2)] + for i, expected_shape in enumerate(expected_shapes): + actual_shape = tuple(result["data_to_be_parsed"][i].shape) + assert actual_shape == expected_shape, ( + f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}" + ) + + client.close() diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 79b5c5d2..85df0fdc 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -324,6 +324,7 @@ async def async_put( data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. @@ -342,6 +343,9 @@ async def async_put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. 
Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -411,7 +415,7 @@ async def async_put( with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put" ): - await self.storage_manager.put_data(data, metadata) + await self.storage_manager.put_data(data, metadata, data_parser=data_parser) await self.async_set_custom_meta(metadata) @@ -1279,7 +1283,11 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: return self._set_custom_meta(metadata=metadata) def put( - self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None + self, + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> BatchMeta: """Synchronously write data to storage units based on metadata. @@ -1298,6 +1306,9 @@ def put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -1336,7 +1347,7 @@ def put( >>> # This will create metadata in "insert" mode internally. 
>>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ - return self._put(data=data, metadata=metadata, partition_id=partition_id) + return self._put(data=data, metadata=metadata, partition_id=partition_id, data_parser=data_parser) def get_data(self, metadata: BatchMeta) -> TensorDict: """Synchronously fetch data from storage units and organize into TensorDict. diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 21f1e8ae..93eb05d1 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -19,7 +19,7 @@ import subprocess import time from importlib import resources -from typing import Any, Optional +from typing import Any, Callable, Optional from urllib.parse import urlparse import ray @@ -388,6 +388,7 @@ def kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Put a single key-value pair to TransferQueue. @@ -403,6 +404,9 @@ def kv_put( If dict is provided, tensors will be unsqueezed to add batch dimension. If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. 
@@ -459,7 +463,7 @@ def kv_put( raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user - batch_meta = tq_client.put(fields, batch_meta) + batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tag) to controller tq_client.set_custom_meta(batch_meta) @@ -476,7 +480,11 @@ def kv_put( def kv_batch_put( - keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None + keys: list[str], + partition_id: str, + fields: Optional[TensorDict] = None, + tags: Optional[list[dict[str, Any]]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Put multiple key-value pairs to TransferQueue in batch. @@ -489,6 +497,9 @@ def kv_batch_put( fields: TensorDict containing data for all keys. Must have batch_size == len(keys). If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -542,7 +553,7 @@ def kv_batch_put( # 3. 
put data if fields is not None: # After put, batch_meta.field_names will include the new fields written by user - batch_meta = tq_client.put(fields, batch_meta) + batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tags) to controller tq_client.set_custom_meta(batch_meta) @@ -742,6 +753,7 @@ async def async_kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Asynchronously put a single key-value pair to TransferQueue. @@ -757,6 +769,9 @@ async def async_kv_put( If dict is provided, tensors will be unsqueezed to add batch dimension. If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. 
@@ -814,7 +829,7 @@ async def async_kv_put( raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user - batch_meta = await tq_client.async_put(fields, batch_meta) + batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tag) to controller await tq_client.async_set_custom_meta(batch_meta) @@ -831,7 +846,11 @@ async def async_kv_put( async def async_kv_batch_put( - keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None + keys: list[str], + partition_id: str, + fields: Optional[TensorDict] = None, + tags: Optional[list[dict[str, Any]]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Asynchronously put multiple key-value pairs to TransferQueue in batch. @@ -844,6 +863,9 @@ async def async_kv_batch_put( fields: TensorDict containing data for all keys. Must have batch_size == len(keys). If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Receives a dict of field_name -> batched values and should + return a dict with the same structure. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -896,7 +918,7 @@ async def async_kv_batch_put( # 3. 
put data if fields is not None: # After put, batch_meta.field_names will include the new fields written by user - batch_meta = await tq_client.async_put(fields, batch_meta) + batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tags) to controller await tq_client.async_set_custom_meta(batch_meta) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 42f17db8..168cd8ba 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -21,7 +21,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional +from typing import Any, Callable, Optional from uuid import uuid4 import ray @@ -306,13 +306,17 @@ async def notify_data_update( pass @abstractmethod - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Put data into the storage backend. Args: data: Data to be put into the storage. metadata: BatchMeta of the corresponding data. + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Only supported by SimpleStorage backend. """ raise NotImplementedError("Subclasses must implement put_data") @@ -561,10 +565,17 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): ) return shapes, dtypes, custom_backend_meta_list - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Store tensor data in the backend storage and notify the controller. 
""" + if data_parser is not None: + raise NotImplementedError( + "data_parser is not supported for KV-based backends (MooncakeStore, Yuanrong, RayStore)." + ) + num_samples = len(metadata.global_indexes) if data.batch_size[0] != num_samples: raise ValueError(f"Batch size of data ({data.batch_size[0]}) does not match expected ({num_samples})") diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 00d87822..71bcabb9 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -21,7 +21,7 @@ from collections.abc import Mapping from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple +from typing import Any, Callable, NamedTuple, Optional from uuid import uuid4 import torch @@ -285,7 +285,9 @@ def _select_by_positions(field_data, positions: list[int]): else: return field_data[positions] - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Send data to remote StorageUnit based on metadata. @@ -295,6 +297,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: Args: data: TensorDict containing the data to store. metadata: BatchMeta containing storage location information. + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. Executed distributedly on each SimpleStorageUnit. 
""" logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") @@ -312,6 +316,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: group.global_indexes, {f: self._select_by_positions(data[f], group.batch_positions) for f in data.keys()}, target_storage_unit=su_id, + data_parser=data_parser, ) for su_id, group in routing.items() ] @@ -341,6 +346,7 @@ async def _put_to_single_storage_unit( global_indexes: list[int], storage_data: dict[str, Any], target_storage_unit: str, + data_parser: Optional[Callable[[Any], Any]] = None, socket: zmq.Socket = None, ): """ @@ -351,7 +357,7 @@ async def _put_to_single_storage_unit( request_type=ZMQRequestType.PUT_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"global_indexes": global_indexes, "data": storage_data}, + body={"global_indexes": global_indexes, "data": storage_data, "data_parser": data_parser}, ) try: diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index e6908e53..049efc15 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -342,10 +342,14 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: """ try: global_indexes = data_parts.body["global_indexes"] - field_data = data_parts.body["data"] # field_data should be a TensorDict. + field_data = data_parts.body["data"] # field_data should be a dict. 
+ data_parser = data_parts.body.get("data_parser", None) + with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" ): + if data_parser is not None: + field_data = data_parser(field_data) self.storage_data.put_data(field_data, global_indexes) # After put operation finish, send a message to the client diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index 5e5944d2..bdf85a39 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -4,53 +4,35 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# TransferQueue Tutorial — Basic Usage\n", - "\n", - "This notebook walks through the core **Key-Value (KV) interface** of\n", - "[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\n", - "streaming data management module for efficient post-training workflows.\n", - "\n", - "**What you will learn:**\n", - "\n", - "1. Initialise TransferQueue (with Ray)\n", - "2. Store a single sample — `kv_put`\n", - "3. Store a batch of samples — `kv_batch_put`\n", - "4. Retrieve data — `kv_batch_get`\n", - "5. List stored keys & tags — `kv_list`\n", - "6. Partial-key and partial-field retrieval\n", - "7. Updating fields incrementally\n", - "8. Working with nested (variable-length) tensors\n", - "9. Storing variable-size image data\n", - "10. Storing non-tensor data (`NonTensorData` / `NonTensorStack`)\n", - "11. Multiple partitions\n", - "12. Clean up — `kv_clear` / `close`\n", - "\n", - "> **Prerequisites:** `pip install TransferQueue` (or install from source). \n", - "> Ray will be started automatically in this notebook." + "# TransferQueue Tutorial — Basic Usage\n\nThis notebook walks through the core **Key-Value (KV) interface** of\n[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\nstreaming data management module for efficient post-training workflows.\n\n**What you will learn:**\n\n1. Initialise TransferQueue (with Ray)\n2. 
Store a single sample — `kv_put`\n3. Store a batch of samples — `kv_batch_put`\n4. Retrieve data — `kv_batch_get`\n5. List stored keys & tags — `kv_list`\n6. Partial-key and partial-field retrieval\n7. Updating fields incrementally\n8. Working with nested (variable-length) tensors\n9. Storing variable-size image data\n10. Storing non-tensor data (`NonTensorData` / `NonTensorStack`)\n11. Multiple partitions\n12. Clean up — `kv_clear` / `close`\n\n> **Prerequisites:** `pip install TransferQueue` (or install from source). \n> Ray will be started automatically in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. Initialization\n", - "\n", - "TransferQueue runs on top of [Ray](https://www.ray.io/). \n", - "We start Ray, then call `tq.init()` with a minimal configuration that uses the\n", - "built-in **SimpleStorage** backend." + "## 1. Initialization\n\nTransferQueue runs on top of [Ray](https://www.ray.io/). \nWe start Ray, then call `tq.init()` with a minimal configuration that uses the\nbuilt-in **SimpleStorage** backend." ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": 2, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.241363Z", + "start_time": "2026-04-21T11:17:06.290229Z" + } + }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2026-03-27 23:06:04,557\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", - "/opt/miniconda3/envs/verl/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", + "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. 
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "2026-04-21 19:17:06,587\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", + "2026-04-21 19:17:11,180\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", + "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", " warnings.warn(\n" ] }, @@ -58,7 +40,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"798d49c0-a4e8-4877-813b-bce2d966c4eb\", ...)\n", "TransferQueue is ready!\n" ] } @@ -95,19 +76,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Store a Single Sample — `kv_put`\n", - "\n", - "`kv_put` stores **one** key-value pair. \n", - "- `key` — a unique string identifier for the sample \n", - "- `partition_id` — a logical namespace (like a table name) \n", - "- `fields` — a `dict` of tensors **or** a `TensorDict` \n", - "- `tag` — optional metadata dict attached to the key" + "## 2. Store a Single Sample — `kv_put`\n\n`kv_put` stores **one** key-value pair. 
\n- `key` — a unique string identifier for the sample \n- `partition_id` — a logical namespace (like a table name) \n- `fields` — a `dict` of tensors **or** a `TensorDict` \n- `tag` — optional metadata dict attached to the key" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, + "execution_count": 3, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.323637Z", + "start_time": "2026-04-21T11:17:15.269690Z" + } + }, "outputs": [ { "name": "stdout", @@ -131,14 +111,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can also pass a pre-built `TensorDict` directly (the batch dimension\n", - "must be 1 for `kv_put`):" + "You can also pass a pre-built `TensorDict` directly (the batch dimension\nmust be 1 for `kv_put`):" ] }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.372889Z", + "start_time": "2026-04-21T11:17:15.331445Z" + } + }, "outputs": [ { "name": "stdout", @@ -170,17 +154,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. Store a Batch of Samples — `kv_batch_put`\n", - "\n", - "When you have multiple samples, `kv_batch_put` is more efficient than\n", - "calling `kv_put` in a loop. The `fields` TensorDict must have\n", - "`batch_size == len(keys)`." + "## 3. Store a Batch of Samples — `kv_batch_put`\n\nWhen you have multiple samples, `kv_batch_put` is more efficient than\ncalling `kv_put` in a loop. The `fields` TensorDict must have\n`batch_size == len(keys)`." ] }, { "cell_type": "code", - "execution_count": 4, - "metadata": {}, + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.421902Z", + "start_time": "2026-04-21T11:17:15.382726Z" + } + }, "outputs": [ { "name": "stdout", @@ -215,15 +200,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Retrieve Data — `kv_batch_get`\n", - "\n", - "Retrieve samples by key(s). 
The result is always a `TensorDict`." + "## 4. Retrieve Data — `kv_batch_get`\n\nRetrieve samples by key(s). The result is always a `TensorDict`." ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.516473Z", + "start_time": "2026-04-21T11:17:15.437341Z" + } + }, "outputs": [ { "name": "stdout", @@ -248,8 +236,13 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.548120Z", + "start_time": "2026-04-21T11:17:15.518716Z" + } + }, "outputs": [ { "name": "stdout", @@ -285,19 +278,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. List Keys & Tags — `kv_list`\n", - "\n", - "`kv_list` returns a nested dict:\n", - "```\n", - "{ partition_id: { key: tag_dict, ... }, ... }\n", - "```\n", - "Pass `partition_id` to filter, or omit it to see everything." + "## 5. List Keys & Tags — `kv_list`\n\n`kv_list` returns a nested dict:\n```\n{ partition_id: { key: tag_dict, ... }, ... }\n```\nPass `partition_id` to filter, or omit it to see everything." ] }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.581880Z", + "start_time": "2026-04-21T11:17:15.549628Z" + } + }, "outputs": [ { "name": "stdout", @@ -326,24 +318,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 6. Partial-Key and Partial-Field Retrieval\n", - "\n", - "You don't have to retrieve *all* keys or *all* fields at once." + "## 6. Partial-Key and Partial-Field Retrieval\n\nYou don't have to retrieve *all* keys or *all* fields at once." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 6a. Partial Keys\n", - "\n", - "Just pass a subset of the keys you stored." + "### 6a. Partial Keys\n\nJust pass a subset of the keys you stored." 
] }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.630111Z", + "start_time": "2026-04-21T11:17:15.598888Z" + } + }, "outputs": [ { "name": "stdout", @@ -365,15 +358,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 6b. Partial Fields\n", - "\n", - "Use the `fields` argument to select specific columns." + "### 6b. Partial Fields\n\nUse the `fields` argument to select specific columns." ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": {}, + "execution_count": 10, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.675806Z", + "start_time": "2026-04-21T11:17:15.637681Z" + } + }, "outputs": [ { "name": "stdout", @@ -404,17 +400,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 7. Updating Fields Incrementally\n", - "\n", - "TransferQueue tracks each field (column) independently per key (row). \n", - "You can **add new fields** to existing keys with another `kv_put` /\n", - "`kv_batch_put` call — the earlier fields are preserved." + "## 7. Updating Fields Incrementally\n\nTransferQueue tracks each field (column) independently per key (row). \nYou can **add new fields** to existing keys with another `kv_put` /\n`kv_batch_put` call — the earlier fields are preserved." ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": {}, + "execution_count": 11, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.715913Z", + "start_time": "2026-04-21T11:17:15.676965Z" + } + }, "outputs": [ { "name": "stdout", @@ -446,22 +443,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 8. Working with Nested (Variable-Length) Tensors\n", - "\n", - "In many NLP and RL workloads each sample has a **different sequence length**\n", - "(e.g. generated responses). 
PyTorch represents these as\n", - "[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n", - "**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\n", - "natively.\n", - "\n", - "> **Note:** Because individual samples have different shapes, you must use\n", - "> `kv_batch_put` (not `kv_put`) to store nested tensors." + "## 8. Working with Nested (Variable-Length) Tensors\n\nIn many NLP and RL workloads each sample has a **different sequence length**\n(e.g. generated responses). PyTorch represents these as\n[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\nnatively.\n\n> **Note:** Because individual samples have different shapes, you must use\n> `kv_batch_put` (not `kv_put`) to store nested tensors." ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, + "execution_count": 12, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.764212Z", + "start_time": "2026-04-21T11:17:15.717341Z" + } + }, "outputs": [ { "name": "stdout", @@ -503,8 +496,13 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": {}, + "execution_count": 13, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.807560Z", + "start_time": "2026-04-21T11:17:15.772541Z" + } + }, "outputs": [ { "name": "stdout", @@ -535,14 +533,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Partial-key retrieval works the same way with nested tensors — only the\n", - "requested samples are returned:" + "Partial-key retrieval works the same way with nested tensors — only the\nrequested samples are returned:" ] }, { "cell_type": "code", - "execution_count": 13, - "metadata": {}, + "execution_count": 14, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.833638Z", + "start_time": "2026-04-21T11:17:15.808823Z" + } + }, "outputs": [ { "name": "stdout", @@ -578,15 +580,18 @@ "cell_type": 
"markdown", "metadata": {}, "source": [ - "Higher-dimensional nested tensors work too. Here each sample is a 3D\n", - "tensor with a variable first dimension (e.g. a different number of\n", - "attention heads or generated candidates):" + "Higher-dimensional nested tensors work too. Here each sample is a 3D\ntensor with a variable first dimension (e.g. a different number of\nattention heads or generated candidates):" ] }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, + "execution_count": 15, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.883307Z", + "start_time": "2026-04-21T11:17:15.834875Z" + } + }, "outputs": [ { "name": "stdout", @@ -633,23 +638,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 9. Storing Variable-Size Image Data\n", - "\n", - "A common multimodal scenario: each sample in a batch contains a\n", - "**different number of images**, and each image has a **different\n", - "resolution**. We can model this with a list of nested tensors — one\n", - "nested tensor per sample — wrapped inside a `TensorDict`.\n", - "\n", - "Since the data is doubly ragged (variable count *and* variable size),\n", - "we store each image as a flattened 1-D tensor and pack all images per\n", - "sample into a single jagged nested tensor. This way every sample is\n", - "one element of the batch, yet images retain their individual sizes." + "## 9. Storing Variable-Size Image Data\n\nA common multimodal scenario: each sample in a batch contains a\n**different number of images**, and each image has a **different\nresolution**. We can model this with a list of nested tensors — one\nnested tensor per sample — wrapped inside a `TensorDict`.\n\nSince the data is doubly ragged (variable count *and* variable size),\nwe store each image as a flattened 1-D tensor and pack all images per\nsample into a single jagged nested tensor. This way every sample is\none element of the batch, yet images retain their individual sizes." 
] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, + "execution_count": 16, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.922965Z", + "start_time": "2026-04-21T11:17:15.884483Z" + } + }, "outputs": [ { "name": "stdout", @@ -717,8 +717,13 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": {}, + "execution_count": 17, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.956379Z", + "start_time": "2026-04-21T11:17:15.924099Z" + } + }, "outputs": [ { "name": "stdout", @@ -758,22 +763,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n", - "\n", - "Not every field is a numeric tensor. Prompts, file paths, JSON metadata,\n", - "or arbitrary Python objects can be stored as **non-tensor data** using\n", - "tensordict's `NonTensorData` and `NonTensorStack`.\n", - "\n", - "- `NonTensorData` wraps a **single** Python object (string, dict, list, …).\n", - "- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n", - "\n", - "TransferQueue serialises them transparently alongside regular tensors." + "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n\nNot every field is a numeric tensor. Prompts, file paths, JSON metadata,\nor arbitrary Python objects can be stored as **non-tensor data** using\ntensordict's `NonTensorData` and `NonTensorStack`.\n\n- `NonTensorData` wraps a **single** Python object (string, dict, list, …).\n- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n\nTransferQueue serialises them transparently alongside regular tensors." 
] }, { "cell_type": "code", - "execution_count": 17, - "metadata": {}, + "execution_count": 18, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:15.991612Z", + "start_time": "2026-04-21T11:17:15.957250Z" + } + }, "outputs": [ { "name": "stdout", @@ -812,8 +813,13 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, + "execution_count": 19, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:16.059711Z", + "start_time": "2026-04-21T11:17:16.022937Z" + } + }, "outputs": [ { "name": "stdout", @@ -870,16 +876,146 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 11. Multiple Partitions\n", + "## 11. Lazy Data Parsing with `data_parser`\n", "\n", - "Partitions provide logical isolation — the same key name can exist in\n", - "different partitions without conflict." + "Sometimes you want to store lightweight **references** (e.g. URLs, file paths, or shape descriptors)\n", + "and defer the expensive loading / decoding until the data reaches the storage unit.\n", + "\n", + "`kv_put` / `kv_batch_put` accept an optional `data_parser` callable. It is executed **inside**\n", + "each `SimpleStorageUnit` at put time, and receives the raw `field_data` dict during this put request. The callable\n", + "should return a dict with the same structure, replacing reference values by parsed data.\n", + "\n", + "> **Note:** `data_parser` is only supported by the **SimpleStorage** backend." 
] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:16.088804Z", + "start_time": "2026-04-21T11:17:16.062652Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stored 3 samples with data parser\n" + ] + } + ], + "source": [ + "def create_data_by_shape_parser(field_data):\n", + " \"\"\"Convert shape descriptors in 'data_to_be_parsed' into random tensors.\"\"\"\n", + " if \"data_to_be_parsed\" in field_data:\n", + " shapes = field_data[\"data_to_be_parsed\"]\n", + " field_data[\"data_to_be_parsed\"] = [torch.randn(shape) for shape in shapes]\n", + " return field_data\n", + "\n", + "\n", + "parser_keys = [\"parser_0\", \"parser_1\", \"parser_2\"]\n", + "\n", + "parser_fields = TensorDict(\n", + " {\n", + " # Normal data column that will not be modified by data parser\n", + " \"normal_data\": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),\n", + " # Each sample carries a list of ints describing the desired tensor shape\n", + " \"data_to_be_parsed\": NonTensorStack([2, 3], [1, 4], [3, 2]),\n", + " },\n", + " batch_size=3,\n", + ")\n", + "\n", + "tq.kv_batch_put(\n", + " keys=parser_keys,\n", + " partition_id=\"train\",\n", + " fields=parser_fields,\n", + " data_parser=create_data_by_shape_parser,\n", + ")\n", + "print(\"Stored 3 samples with data parser\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:16.132204Z", + "start_time": "2026-04-21T11:17:16.089857Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-21 19:17:16,096 - WARNING - transfer_queue.storage.managers.simple_backend_manager - Failed to pack nested tensor with jagged layout. Falling back to strided layout. Detailed error: Cannot represent given tensor list as a nested tensor with the jagged layout. 
Note that the jagged layout only allows for a single ragged dimension. For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "normal_data:\n", + "tensor([[1., 2.],\n", + " [3., 4.],\n", + " [5., 6.]])\n", + "\n", + "data_to_be_parsed shapes:\n", + " sample 0: (2, 3)\n", + " sample 1: (1, 4)\n", + " sample 2: (3, 2)\n", + "\n", + "Assertions passed!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/torch/nested/__init__.py:120: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/NestedTensorImpl.cpp:182.)\n", + " return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)\n" + ] + } + ], + "source": [ + "result_parser = tq.kv_batch_get(keys=parser_keys, partition_id=\"train\")\n", + "\n", + "# normal_data should be unchanged\n", + "print(\"normal_data:\")\n", + "print(result_parser[\"normal_data\"])\n", + "\n", + "# data_to_be_parsed should now be tensors with the requested shapes\n", + "expected_shapes = [(2, 3), (1, 4), (3, 2)]\n", + "print(\"\\ndata_to_be_parsed shapes:\")\n", + "for i, expected in enumerate(expected_shapes):\n", + " actual = tuple(result_parser[\"data_to_be_parsed\"][i].shape)\n", + " assert actual == expected, f\"Mismatch at {i}: {actual} != {expected}\"\n", + " print(f\" sample {i}: {actual}\")\n", + "\n", + "print(\"\\nAssertions passed!\")\n", + "\n", + "# Clean up\n", + "tq.kv_clear(keys=parser_keys, partition_id=\"train\")" + ] + }, + { + "cell_type": "markdown", "metadata": {}, + "source": [ + "## 12. 
Multiple Partitions\n\nPartitions provide logical isolation — the same key name can exist in\ndifferent partitions without conflict." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:16.210207Z", + "start_time": "2026-04-21T11:17:16.142608Z" + } + }, "outputs": [ { "name": "stdout", @@ -905,15 +1041,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 12. Clean Up — `kv_clear` and `close`\n", - "\n", - "Remove specific keys with `kv_clear`, then shut down the system with `tq.close()`." + "## 13. Clean Up — `kv_clear` and `close`\n\nRemove specific keys with `kv_clear`, then shut down the system with `tq.close()`." ] }, { "cell_type": "code", - "execution_count": 20, - "metadata": {}, + "execution_count": 23, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:16.249278Z", + "start_time": "2026-04-21T11:17:16.211357Z" + } + }, "outputs": [ { "name": "stdout", @@ -939,8 +1078,13 @@ }, { "cell_type": "code", - "execution_count": 21, - "metadata": {}, + "execution_count": 24, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:17.385182Z", + "start_time": "2026-04-21T11:17:16.250683Z" + } + }, "outputs": [ { "name": "stdout", @@ -960,31 +1104,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "---\n", - "\n", - "## Summary\n", - "\n", - "| Operation | Function | Notes |\n", - "|---|---|---|\n", - "| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n", - "| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n", - "| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n", - "| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n", - "| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n", - "| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n", - "| 
Close | `tq.close()` | Tears down controller & storage |\n", - "\n", - "For **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n", - "`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n", - "\n", - "For low-level, metadata-based access, see `tq.get_client()` and the\n", - "[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." + "---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; fn receives dict |\n| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." 
] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 24, + "metadata": { + "ExecuteTime": { + "end_time": "2026-04-21T11:17:17.455073Z", + "start_time": "2026-04-21T11:17:17.409870Z" + } + }, "outputs": [], "source": [] } From c6d57e2bfbbf83d1501a43971b53840a73f70d13 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Tue, 21 Apr 2026 20:16:22 +0800 Subject: [PATCH 2/5] fix comments Signed-off-by: 0oshowero0 --- tests/test_simple_storage_unit.py | 81 ++++++++++++++++++++++++ transfer_queue/storage/simple_backend.py | 4 ++ transfer_queue/utils/serial_utils.py | 6 +- tutorial/basic.ipynb | 10 +-- 4 files changed, 90 insertions(+), 11 deletions(-) diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index a5fb507c..e0a54de5 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -488,3 +488,84 @@ def create_data_by_shape_parser(field_data): ) client.close() + + +def test_storage_unit_data_parser_callable_types(storage_setup): + """Test that various callable types (partial, callable class) work as data_parser.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + from functools import partial + + # 1. Test functools.partial + def _partial_parser(field_data, prefix): + if "text" in field_data: + field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]] + return field_data + + partial_parser = partial(_partial_parser, prefix="parsed_") + + response = client.send_put( + 0, + [0, 1], + {"text": ["a", "b"]}, + data_parser=partial_parser, + ) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}" + + response = client.send_get(0, [0, 1], ["text"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + assert response.body["data"]["text"] == ["parsed_a", "parsed_b"] + + # 2. 
Test callable class instance + class CallableParser: + def __call__(self, field_data): + if "value" in field_data: + field_data["value"] = [v * 2 for v in field_data["value"]] + return field_data + + callable_parser = CallableParser() + response = client.send_put( + 0, + [2, 3], + {"value": [1, 2]}, + data_parser=callable_parser, + ) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"callable class parser failed: {response.body}" + + response = client.send_get(0, [2, 3], ["value"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + assert response.body["data"]["value"] == [2, 4] + + client.close() + + +def test_storage_unit_data_parser_validation(storage_setup): + """Test that invalid data_parser inputs produce clear error messages.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + # 1. Non-callable data_parser should return a clear TypeError + response = client.send_put( + 0, + [0], + {"data": [1]}, + data_parser="not_callable", + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must be callable" in response.body["message"] + + # 2. 
data_parser returning non-dict should return a clear TypeError + def bad_parser(field_data): + return "not_a_dict" + + response = client.send_put( + 0, + [1], + {"data": [1]}, + data_parser=bad_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must return a dict" in response.body["message"] + + client.close() diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index 049efc15..d7dca175 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -349,7 +349,11 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" ): if data_parser is not None: + if not callable(data_parser): + raise TypeError(f"data_parser must be callable, got {type(data_parser).__name__}") field_data = data_parser(field_data) + if not isinstance(field_data, dict): + raise TypeError(f"data_parser must return a dict, got {type(field_data).__name__}") self.storage_data.put_data(field_data, global_indexes) # After put operation finish, send a message to the client diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index fd57b807..90f8592c 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -23,7 +23,6 @@ import warnings from collections.abc import Sequence from contextvars import ContextVar -from types import FunctionType from typing import Any, TypeAlias import cloudpickle @@ -118,8 +117,9 @@ def enc_hook(self, obj: Any) -> Any: # Only true object arrays (or structured dtypes with object fields) reach here return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) - if isinstance(obj, FunctionType): - # cloudpickle for functions/methods + if callable(obj): + # cloudpickle for arbitrary callables (functions, lambdas, functools.partial, + # callable class instances, bound 
methods, etc.) return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) # Fallback to pickle for unknown types diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index bdf85a39..a64ce008 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -28,12 +28,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "2026-04-21 19:17:06,587\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n", - "2026-04-21 19:17:11,180\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", - "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", - " warnings.warn(\n" + "2026-04-21 19:17:11,180\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { @@ -974,8 +969,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/hanzhenyu/anaconda3/envs/deep/lib/python3.11/site-packages/torch/nested/__init__.py:120: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. We recommend specifying layout=torch.jagged when constructing a nested tensor, as this layout receives active development, has better operator coverage, and works with torch.compile. 
(Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/NestedTensorImpl.cpp:182.)\n", - " return torch._nested_tensor_from_tensor_list(ts, dtype, None, device, None)\n" + "" ] } ], From 5fcd9fb6719975d9da946066168896bd8809b16f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 22 Apr 2026 14:52:49 +0800 Subject: [PATCH 3/5] add detailed input/output description & better example Signed-off-by: 0oshowero0 --- tests/test_simple_storage_unit.py | 42 ++ transfer_queue/client.py | 20 +- transfer_queue/interface.py | 44 +- transfer_queue/storage/managers/base.py | 9 +- .../managers/simple_backend_manager.py | 9 +- transfer_queue/storage/simple_backend.py | 36 ++ tutorial/basic.ipynb | 388 ++++++++---------- 7 files changed, 321 insertions(+), 227 deletions(-) diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index e0a54de5..2519b975 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -568,4 +568,46 @@ def bad_parser(field_data): assert response.request_type == ZMQRequestType.PUT_ERROR assert "data_parser must return a dict" in response.body["message"] + # 3. data_parser deleting a key should return a clear ValueError + def delete_key_parser(field_data): + del field_data["data"] + return field_data + + response = client.send_put( + 0, + [2], + {"data": [1], "extra": [2]}, + data_parser=delete_key_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must not change dict keys" in response.body["message"] + + # 4. data_parser adding a key should return a clear ValueError + def add_key_parser(field_data): + field_data["new_key"] = [999] + return field_data + + response = client.send_put( + 0, + [3], + {"data": [1]}, + data_parser=add_key_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must not change dict keys" in response.body["message"] + + # 5. 
data_parser changing element count should return a clear ValueError + def wrong_len_parser(field_data): + field_data["data"] = field_data["data"][:-1] + return field_data + + response = client.send_put( + 0, + [4, 5], + {"data": [1, 2]}, + data_parser=wrong_len_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser changed the number of elements" in response.body["message"] + client.close() diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 85df0fdc..df9825b7 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -344,8 +344,14 @@ async def async_put( storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. Only supported by SimpleStorage. + content. The input is a slice of the `data` parameter, in plain + dict format (not TensorDict), mapping field_name -> batched values. + For a regular tensor column the value is a batched tensor; for + nested tensors (jagged or strided) and NonTensorStack columns + the values are extracted into a list. It must return a dict of + the same format with the exact same keys and the same number of + elements per column; do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -1307,8 +1313,14 @@ def put( storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. 
Only supported by SimpleStorage. + content. The input is a slice of the `data` parameter, in plain + dict format (not TensorDict), mapping field_name -> batched values. + For a regular tensor column the value is a batched tensor; for + nested tensors (jagged or strided) and NonTensorStack columns + the values are extracted into a list. It must return a dict of + the same format with the exact same keys and the same number of + elements per column; do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 93eb05d1..627f5879 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -405,8 +405,15 @@ def kv_put( If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. Only supported by SimpleStorage. + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must return a dict of the same format with the exact + same keys and the same number of elements per column; do not + change the inner order of values within each column. Only + supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. @@ -498,8 +505,15 @@ def kv_batch_put( If not provided, will only update the newly given tags to the keys. 
tags: List of metadata tags, one for each key data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. Only supported by SimpleStorage. + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must return a dict of the same format with the exact + same keys and the same number of elements per column; do not + change the inner order of values within each column. Only + supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -770,8 +784,15 @@ async def async_kv_put( If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. Only supported by SimpleStorage. + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must return a dict of the same format with the exact + same keys and the same number of elements per column; do not + change the inner order of values within each column. Only + supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. 
@@ -864,8 +885,15 @@ async def async_kv_batch_put( If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Receives a dict of field_name -> batched values and should - return a dict with the same structure. Only supported by SimpleStorage. + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must return a dict of the same format with the exact + same keys and the same number of elements per column; do not + change the inner order of values within each column. Only + supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 168cd8ba..54622ffb 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -316,7 +316,14 @@ async def put_data( data: Data to be put into the storage. metadata: BatchMeta of the corresponding data. data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Only supported by SimpleStorage backend. + content. The input is a plain dict (not TensorDict) mapping + field_name -> batched values. For a regular tensor column the + value is a batched tensor; for nested tensors (jagged or strided) + and NonTensorStack columns the values are extracted into a list. + It must return a dict of the same format with the exact same keys + and the same number of elements per column; do not change the + inner order of values within each column. Only supported by + SimpleStorage backend. 
""" raise NotImplementedError("Subclasses must implement put_data") diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 71bcabb9..09ee54f6 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -298,7 +298,14 @@ async def put_data( data: TensorDict containing the data to store. metadata: BatchMeta containing storage location information. data_parser: Optional callable to parse reference data (e.g., URLs) into real - content. Executed distributedly on each SimpleStorageUnit. + content. The input is a plain dict (not TensorDict) mapping + field_name -> batched values. For a regular tensor column the + value is a batched tensor; for nested tensors (jagged or strided) + and NonTensorStack columns the values are extracted into a list. + It must return a dict of the same format with the exact same keys + and the same number of elements per column; do not change the + inner order of values within each column. Executed distributedly + on each SimpleStorageUnit. 
""" logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index d7dca175..7f0a0012 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -351,9 +351,45 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: if data_parser is not None: if not callable(data_parser): raise TypeError(f"data_parser must be callable, got {type(data_parser).__name__}") + + original_keys = set(field_data.keys()) + original_lengths = {} + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + original_lengths[k] = v.shape[0] + else: + try: + original_lengths[k] = len(v) + except Exception: + original_lengths[k] = None + field_data = data_parser(field_data) + if not isinstance(field_data, dict): raise TypeError(f"data_parser must return a dict, got {type(field_data).__name__}") + + new_keys = set(field_data.keys()) + if new_keys != original_keys: + raise ValueError( + f"data_parser must not change dict keys. 
" + f"Original keys: {sorted(original_keys)}, got: {sorted(new_keys)}" + ) + + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + new_len = v.shape[0] + else: + try: + new_len = len(v) + except Exception: + new_len = None + + orig_len = original_lengths[k] + if orig_len is not None and new_len is not None and orig_len != new_len: + raise ValueError( + f"data_parser changed the number of elements for key '{k}': " + f"expected {orig_len}, got {new_len}" + ) self.storage_data.put_data(field_data, global_indexes) # After put operation finish, send a message to the client diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index a64ce008..2edf4582 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -16,19 +16,14 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.241363Z", - "start_time": "2026-04-21T11:17:06.290229Z" - } - }, + "execution_count": 1, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2026-04-21 19:17:11,180\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" + "2026-04-22 14:15:11,786\tINFO worker.py:2014 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { @@ -76,13 +71,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.323637Z", - "start_time": "2026-04-21T11:17:15.269690Z" - } - }, + "execution_count": 2, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -111,13 +101,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.372889Z", - "start_time": "2026-04-21T11:17:15.331445Z" - } - }, + "execution_count": 3, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -154,13 +139,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.421902Z", - "start_time": "2026-04-21T11:17:15.382726Z" - } - }, + "execution_count": 4, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -200,13 +180,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.516473Z", - "start_time": "2026-04-21T11:17:15.437341Z" - } - }, + "execution_count": 5, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -231,13 +206,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.548120Z", - "start_time": "2026-04-21T11:17:15.518716Z" - } - }, + "execution_count": 6, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -278,13 +248,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.581880Z", - "start_time": "2026-04-21T11:17:15.549628Z" - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -325,13 +290,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.630111Z", - "start_time": "2026-04-21T11:17:15.598888Z" - } - }, + 
"execution_count": 8, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -358,13 +318,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.675806Z", - "start_time": "2026-04-21T11:17:15.637681Z" - } - }, + "execution_count": 9, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -400,13 +355,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.715913Z", - "start_time": "2026-04-21T11:17:15.676965Z" - } - }, + "execution_count": 10, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -443,13 +393,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.764212Z", - "start_time": "2026-04-21T11:17:15.717341Z" - } - }, + "execution_count": 11, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -491,13 +436,8 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.807560Z", - "start_time": "2026-04-21T11:17:15.772541Z" - } - }, + "execution_count": 12, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -533,13 +473,8 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.833638Z", - "start_time": "2026-04-21T11:17:15.808823Z" - } - }, + "execution_count": 13, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -580,13 +515,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.883307Z", - "start_time": "2026-04-21T11:17:15.834875Z" - } - }, + "execution_count": 14, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -638,13 +568,8 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.922965Z", - "start_time": "2026-04-21T11:17:15.884483Z" - } - }, + "execution_count": 
15, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -712,13 +637,8 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.956379Z", - "start_time": "2026-04-21T11:17:15.924099Z" - } - }, + "execution_count": 16, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -758,18 +678,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n\nNot every field is a numeric tensor. Prompts, file paths, JSON metadata,\nor arbitrary Python objects can be stored as **non-tensor data** using\ntensordict's `NonTensorData` and `NonTensorStack`.\n\n- `NonTensorData` wraps a **single** Python object (string, dict, list, …).\n- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n\nTransferQueue serialises them transparently alongside regular tensors." + "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n", + "\n", + "Not every field is a numeric tensor. Prompts, file paths, JSON metadata,\n", + "or arbitrary Python objects can be stored as **non-tensor data** using\n", + "tensordict's `NonTensorStack`.\n", + "\n", + "- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n", + "\n", + "TransferQueue serialises them transparently alongside regular tensors." 
] }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:15.991612Z", - "start_time": "2026-04-21T11:17:15.957250Z" - } - }, + "execution_count": 17, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -808,13 +731,8 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:16.059711Z", - "start_time": "2026-04-21T11:17:16.022937Z" - } - }, + "execution_count": 18, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -853,7 +771,7 @@ "for i, meta in enumerate(list(result_nt[\"metadata\"])):\n", " print(f\" [{i}] {meta}\")\n", "\n", - "# You can also add a NonTensorData field to a single key via kv_put\n", + "# You can also add a NonTensorStack field to a single key via kv_put\n", "tq.kv_put(\n", " key=\"single_nt\",\n", " partition_id=\"train\",\n", @@ -873,96 +791,155 @@ "source": [ "## 11. Lazy Data Parsing with `data_parser`\n", "\n", - "Sometimes you want to store lightweight **references** (e.g. URLs, file paths, or shape descriptors)\n", + "Sometimes you want to store lightweight **references** (e.g. URLs or file paths)\n", "and defer the expensive loading / decoding until the data reaches the storage unit.\n", "\n", "`kv_put` / `kv_batch_put` accept an optional `data_parser` callable. It is executed **inside**\n", - "each `SimpleStorageUnit` at put time, and receives the raw `field_data` dict during this put request. The callable\n", - "should return a dict with the same structure, replacing reference values by parsed data.\n", + "each `SimpleStorageUnit` at put time. The callable receives a plain `dict` (not a TensorDict)\n", + "mapping `field_name -> batched_values`. For a regular tensor column the value is a batched tensor;\n", + "for nested tensors (jagged or strided) and `NonTensorStack` columns the values are extracted into\n", + "a `list`. 
It must return a `dict` with the exact same keys and the same number of elements per\n", + "column; do not change the inner order of values within each column.\n", + "\n", + "> **Design tip:** Separate the **core single-sample parser** from the **batch concurrency wrapper**.\n", + "> The wrapper can use `asyncio` to process all samples in parallel while the parser function itself\n", + "> remains synchronous to the caller.\n", "\n", "> **Note:** `data_parser` is only supported by the **SimpleStorage** backend." ] }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:16.088804Z", - "start_time": "2026-04-21T11:17:16.062652Z" - } - }, + "execution_count": 19, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Stored 3 samples with data parser\n" + "Put succeeded. Fields: ['data_to_be_parsed', 'normal_data']\n", + "Total kv_batch_put time: 1.02s (concurrency keeps it ~1s, not 32s)\n", + "\n" ] } ], "source": [ - "def create_data_by_shape_parser(field_data):\n", - " \"\"\"Convert shape descriptors in 'data_to_be_parsed' into random tensors.\"\"\"\n", - " if \"data_to_be_parsed\" in field_data:\n", - " shapes = field_data[\"data_to_be_parsed\"]\n", - " field_data[\"data_to_be_parsed\"] = [torch.randn(shape) for shape in shapes]\n", + "import asyncio\n", + "import time\n", + "\n", + "\n", + "# 1. Define core single-sample parser (pure business logic, no asyncio, no batch)\n", + "def parse_url(url: str) -> torch.Tensor:\n", + " \"\"\"Parse a URL-like descriptor 'dtype:HxW' into a random tensor.\"\"\"\n", + " dtype_str, shape_str = url.split(\":\")\n", + " dtype = getattr(torch, dtype_str)\n", + " shape = [int(dim) for dim in shape_str.split(\"x\")]\n", + " return torch.randn(shape, dtype=dtype)\n", + "\n", + "\n", + "# 2. 
Define Batch-level parser: sync on the outside, async-parallel on the inside\n", + "def concurrent_batch_url_parser(field_data: dict) -> dict:\n", + " \"\"\"Batch-level data_parser executed inside SimpleStorageUnit.\n", + "\n", + " It receives a ``dict`` (not a TensorDict) where each value is a\n", + " batched column. For columns created from ``NonTensorStack`` the\n", + " value is a plain ``list`` of Python objects.\n", + "\n", + " Workflow:\n", + " 1. Spawns one async task per list element.\n", + " 2. Waits until *all* tasks finish (``asyncio.gather``).\n", + " 3. Replaces the list with the list of results.\n", + "\n", + " Because ``asyncio.run`` blocks until the loop finishes, this function\n", + " is **synchronous** to its caller: when it returns, every sample has\n", + " been processed.\n", + "\n", + " Args:\n", + " field_data: Mapping ``field_name -> batched_values``. The dict\n", + " keys must stay exactly the same; only values may be\n", + " transformed in-place.\n", + "\n", + " Returns:\n", + " The same dict with parsed values substituted.\n", + " \"\"\"\n", + " if \"data_to_be_parsed\" not in field_data:\n", + " return field_data\n", + "\n", + " urls: list[str] = field_data[\"data_to_be_parsed\"]\n", + "\n", + " async def _async_parse_single(url: str) -> torch.Tensor:\n", + " await asyncio.sleep(1.0) # Add fixed delay per sample\n", + " return parse_url(url)\n", + "\n", + " async def _process_all():\n", + " tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]\n", + " return await asyncio.gather(*tasks)\n", + "\n", + " start = time.perf_counter()\n", + " field_data[\"data_to_be_parsed\"] = asyncio.run(_process_all())\n", + " elapsed = time.perf_counter() - start\n", + "\n", + " print(f\"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s (serial would be ~{len(urls)}.0s)\")\n", " return field_data\n", "\n", "\n", - "parser_keys = [\"parser_0\", \"parser_1\", \"parser_2\"]\n", + "# 
---------------------------------------------------------------------------\n", + "# Build the batch\n", + "# ---------------------------------------------------------------------------\n", + "batch_size = 32\n", + "\n", + "normal_data = torch.randn(batch_size, 2)\n", + "\n", + "# URL-like strings: all use the same dtype so TQ can pack them on get\n", + "shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n", + "urls = [f\"float32:{h}x{w}\" for h, w in shapes]\n", "\n", "parser_fields = TensorDict(\n", " {\n", - " # Normal data column that will not be modified by data parser\n", - " \"normal_data\": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),\n", - " # Each sample carries a list of ints describing the desired tensor shape\n", - " \"data_to_be_parsed\": NonTensorStack([2, 3], [1, 4], [3, 2]),\n", + " \"normal_data\": normal_data,\n", + " \"data_to_be_parsed\": NonTensorStack(*urls),\n", " },\n", - " batch_size=3,\n", + " batch_size=batch_size,\n", ")\n", "\n", - "tq.kv_batch_put(\n", - " keys=parser_keys,\n", + "data_parser_keys = [f\"data_parser_sample_{i}\" for i in range(batch_size)]\n", + "\n", + "put_start_time = time.perf_counter()\n", + "meta = tq.kv_batch_put(\n", + " keys=data_parser_keys,\n", " partition_id=\"train\",\n", " fields=parser_fields,\n", - " data_parser=create_data_by_shape_parser,\n", + " data_parser=concurrent_batch_url_parser,\n", ")\n", - "print(\"Stored 3 samples with data parser\")" + "put_elapsed = time.perf_counter() - put_start_time\n", + "print(f\"Put succeeded. Fields: {meta.fields}\")\n", + "print(f\"Total kv_batch_put time: {put_elapsed:.2f}s (concurrency keeps it ~1s, not {batch_size}s)\\n\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "When `kv_batch_put` returns, the user-defined data parser has also finished executing. We can then safely call `kv_batch_get` to retrieve the parsed data." 
+ }, { "cell_type": "code", - "execution_count": 21, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:16.132204Z", - "start_time": "2026-04-21T11:17:16.089857Z" - } - }, + "execution_count": 20, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2026-04-21 19:17:16,096 - WARNING - transfer_queue.storage.managers.simple_backend_manager - Failed to pack nested tensor with jagged layout. Falling back to strided layout. Detailed error: Cannot represent given tensor list as a nested tensor with the jagged layout. Note that the jagged layout only allows for a single ragged dimension. For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim.\n" + "2026-04-22 14:15:17,477 - WARNING - transfer_queue.storage.managers.simple_backend_manager - Failed to pack nested tensor with jagged layout. Falling back to strided layout. Detailed error: Cannot represent given tensor list as a nested tensor with the jagged layout. Note that the jagged layout only allows for a single ragged dimension. 
For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "normal_data:\n", - "tensor([[1., 2.],\n", - " [3., 4.],\n", - " [5., 6.]])\n", - "\n", - "data_to_be_parsed shapes:\n", - " sample 0: (2, 3)\n", - " sample 1: (1, 4)\n", - " sample 2: (3, 2)\n", - "\n", - "Assertions passed!\n" + "[PASS] normal_data is unchanged.\n", + "[PASS] All 32 parsed tensors have correct dtype & shape.\n", + "[PASS] Timing looks concurrent: 1.02s < 2.0s\n" ] }, { @@ -974,24 +951,28 @@ } ], "source": [ - "result_parser = tq.kv_batch_get(keys=parser_keys, partition_id=\"train\")\n", + "result = tq.kv_batch_get(keys=data_parser_keys, partition_id=\"train\")\n", "\n", "# normal_data should be unchanged\n", - "print(\"normal_data:\")\n", - "print(result_parser[\"normal_data\"])\n", + "torch.testing.assert_close(result[\"normal_data\"], normal_data)\n", + "print(\"[PASS] normal_data is unchanged.\")\n", "\n", "# data_to_be_parsed should now be tensors with the requested shapes\n", - "expected_shapes = [(2, 3), (1, 4), (3, 2)]\n", - "print(\"\\ndata_to_be_parsed shapes:\")\n", + "expected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n", "for i, expected in enumerate(expected_shapes):\n", - " actual = tuple(result_parser[\"data_to_be_parsed\"][i].shape)\n", + " tensor = result[\"data_to_be_parsed\"][i]\n", + " assert tensor.dtype == torch.float32\n", + " actual = tuple(tensor.shape)\n", " assert actual == expected, f\"Mismatch at {i}: {actual} != {expected}\"\n", - " print(f\" sample {i}: {actual}\")\n", + "print(f\"[PASS] All {batch_size} parsed tensors have correct dtype & shape.\")\n", "\n", - "print(\"\\nAssertions passed!\")\n", + "# Timing sanity check: serial would be ~batch_size seconds.\n", + "# Because asyncio tasks run in parallel inside the parser, it should be ~1 s.\n", + "assert put_elapsed < 2.0, f\"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s.\"\n", + "print(f\"[PASS] Timing 
looks concurrent: {put_elapsed:.2f}s < 2.0s\")\n", "\n", - "# Clean up\n", - "tq.kv_clear(keys=parser_keys, partition_id=\"train\")" + "# wait for Ray log collect\n", + "time.sleep(2)" ] }, { @@ -1003,13 +984,8 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:16.210207Z", - "start_time": "2026-04-21T11:17:16.142608Z" - } - }, + "execution_count": 21, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1040,13 +1016,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:16.249278Z", - "start_time": "2026-04-21T11:17:16.211357Z" - } - }, + "execution_count": 22, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1062,6 +1033,7 @@ "tq.kv_clear(keys=\"sample_0\", partition_id=\"train\")\n", "tq.kv_clear(keys=\"sample_1\", partition_id=\"train\")\n", "tq.kv_clear(keys=keys, partition_id=\"train\")\n", + "tq.kv_clear(keys=data_parser_keys, partition_id=\"train\")\n", "tq.kv_clear(keys=\"val_sample_0\", partition_id=\"validation\")\n", "print(\"All keys cleared.\")\n", "\n", @@ -1072,13 +1044,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:17.385182Z", - "start_time": "2026-04-21T11:17:16.250683Z" - } - }, + "execution_count": 23, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -1098,18 +1065,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; fn receives dict |\n| Get | `tq.kv_batch_get(keys, partition_id, 
select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." + "---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; receives dict, can use asyncio inside |\n| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." 
] }, { "cell_type": "code", - "execution_count": 24, - "metadata": { - "ExecuteTime": { - "end_time": "2026-04-21T11:17:17.455073Z", - "start_time": "2026-04-21T11:17:17.409870Z" - } - }, + "execution_count": 23, + "metadata": {}, "outputs": [], "source": [] } From c031e3caf541ba3da4f6aecbaf24fe2039c29f03 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 22 Apr 2026 15:10:47 +0800 Subject: [PATCH 4/5] update Signed-off-by: 0oshowero0 --- tutorial/basic.ipynb | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index 2edf4582..e8f97e9b 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -4,7 +4,30 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# TransferQueue Tutorial — Basic Usage\n\nThis notebook walks through the core **Key-Value (KV) interface** of\n[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\nstreaming data management module for efficient post-training workflows.\n\n**What you will learn:**\n\n1. Initialise TransferQueue (with Ray)\n2. Store a single sample — `kv_put`\n3. Store a batch of samples — `kv_batch_put`\n4. Retrieve data — `kv_batch_get`\n5. List stored keys & tags — `kv_list`\n6. Partial-key and partial-field retrieval\n7. Updating fields incrementally\n8. Working with nested (variable-length) tensors\n9. Storing variable-size image data\n10. Storing non-tensor data (`NonTensorData` / `NonTensorStack`)\n11. Multiple partitions\n12. Clean up — `kv_clear` / `close`\n\n> **Prerequisites:** `pip install TransferQueue` (or install from source). \n> Ray will be started automatically in this notebook." 
+ "# TransferQueue Tutorial — Basic Usage\n", + "\n", + "This notebook walks through the core **Key-Value (KV) interface** of\n", + "[TransferQueue](https://github.com/Ascend/TransferQueue), an asynchronous\n", + "streaming data management module for efficient post-training workflows.\n", + "\n", + "**What you will learn:**\n", + "\n", + "1. Initialise TransferQueue (with Ray)\n", + "2. Store a single sample — `kv_put`\n", + "3. Store a batch of samples — `kv_batch_put`\n", + "4. Retrieve data — `kv_batch_get`\n", + "5. List stored keys & tags — `kv_list`\n", + "6. Partial-key and partial-field retrieval\n", + "7. Updating fields incrementally\n", + "8. Working with nested (variable-length) tensors\n", + "9. Storing variable-size image data\n", + "10. Storing non-tensor data (`NonTensorStack`)\n", + "11. Lazy Data Parsing with `data_parser`\n", + "12. Multiple partitions\n", + "13. Clean up — `kv_clear` / `close`\n", + "\n", + "> **Prerequisites:** `pip install TransferQueue` (or install from source). \n", + "> Ray will be started automatically in this notebook." ] }, { @@ -23,7 +46,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "2026-04-22 14:15:11,786\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" + "2026-04-22 14:15:11,786\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001B[1m\u001B[32mhttp://127.0.0.1:8265 \u001B[39m\u001B[22m\n" ] }, { @@ -678,7 +701,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n", + "## 10. Storing Non-Tensor Data — `NonTensorStack`\n", "\n", "Not every field is a numeric tensor. 
Prompts, file paths, JSON metadata,\n", "or arbitrary Python objects can be stored as **non-tensor data** using\n", From 45ca52d8ad493b5a9734c72ec5167fabb2a2cb3f Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Wed, 22 Apr 2026 15:52:23 +0800 Subject: [PATCH 5/5] change docstring Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 18 ++++++----- transfer_queue/interface.py | 32 +++++++++---------- transfer_queue/storage/managers/base.py | 8 ++--- .../managers/simple_backend_manager.py | 8 ++--- tutorial/basic.ipynb | 5 +-- 5 files changed, 37 insertions(+), 34 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index df9825b7..212c52f1 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -348,10 +348,11 @@ async def async_put( dict format (not TensorDict), mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns - the values are extracted into a list. It must return a dict of - the same format with the exact same keys and the same number of - elements per column; do not change the inner order of values - within each column. Only supported by SimpleStorage. + the values are extracted into a list. It must modify values + in-place based on the original keys; do not add or remove keys. + The number of elements per column must also remain unchanged. + Do not change the inner order of values within each column. + Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -1317,10 +1318,11 @@ def put( dict format (not TensorDict), mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns - the values are extracted into a list. 
It must return a dict of - the same format with the exact same keys and the same number of - elements per column; do not change the inner order of values - within each column. Only supported by SimpleStorage. + the values are extracted into a list. It must modify values + in-place based on the original keys; do not add or remove keys. + The number of elements per column must also remain unchanged. + Do not change the inner order of values within each column. + Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 627f5879..31b5a232 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -410,10 +410,10 @@ def kv_put( mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into - a list. It must return a dict of the same format with the exact - same keys and the same number of elements per column; do not - change the inner order of values within each column. Only - supported by SimpleStorage. + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. @@ -510,10 +510,10 @@ def kv_batch_put( mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into - a list. It must return a dict of the same format with the exact - same keys and the same number of elements per column; do not - change the inner order of values within each column. 
Only - supported by SimpleStorage. + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -789,10 +789,10 @@ async def async_kv_put( mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into - a list. It must return a dict of the same format with the exact - same keys and the same number of elements per column; do not - change the inner order of values within each column. Only - supported by SimpleStorage. + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. @@ -890,10 +890,10 @@ async def async_kv_batch_put( mapping field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into - a list. It must return a dict of the same format with the exact - same keys and the same number of elements per column; do not - change the inner order of values within each column. Only - supported by SimpleStorage. + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. 
diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 54622ffb..374db0a1 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -320,10 +320,10 @@ async def put_data( field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into a list. - It must return a dict of the same format with the exact same keys - and the same number of elements per column; do not change the - inner order of values within each column. Only supported by - SimpleStorage backend. + It must modify values in-place based on the original keys; do not + add or remove keys. The number of elements per column must also + remain unchanged. Do not change the inner order of values within + each column. Only supported by SimpleStorage backend. """ raise NotImplementedError("Subclasses must implement put_data") diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 09ee54f6..99c6c7d2 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -302,10 +302,10 @@ async def put_data( field_name -> batched values. For a regular tensor column the value is a batched tensor; for nested tensors (jagged or strided) and NonTensorStack columns the values are extracted into a list. - It must return a dict of the same format with the exact same keys - and the same number of elements per column; do not change the - inner order of values within each column. Executed distributedly - on each SimpleStorageUnit. + It must modify values in-place based on the original keys; do not + add or remove keys. The number of elements per column must also + remain unchanged. Do not change the inner order of values within + each column. 
Executed in a distributed manner on each SimpleStorageUnit. """ logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index e8f97e9b..90ad2411 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -821,8 +821,9 @@ "each `SimpleStorageUnit` at put time. The callable receives a plain `dict` (not a TensorDict)\n", "mapping `field_name -> batched_values`. For a regular tensor column the value is a batched tensor;\n", "for nested tensors (jagged or strided) and `NonTensorStack` columns the values are extracted into\n", - "a `list`. It must return a `dict` with the exact same keys and the same number of elements per\n", - "column; do not change the inner order of values within each column.\n", + "a `list`. It must modify values in-place based on the original keys; do not add or remove keys.\n", + "The number of elements per column must also remain unchanged. Do not change the inner order of\n", + "values within each column.\n", "\n", "> **Design tip:** Separate the **core single-sample parser** from the **batch concurrency wrapper**.\n", "> The wrapper can use `asyncio` to process all samples in parallel while the parser function itself\n",