diff --git a/tests/test_client.py b/tests/test_client.py index a911511..7ce8bcb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -418,7 +418,7 @@ def client_setup(mock_controller, mock_storage): client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations - async def mock_put_data(data, metadata): + async def mock_put_data(data, metadata, data_parser=None): pass # Just pretend to store the data async def mock_get_data(metadata): @@ -511,7 +511,7 @@ def test_single_controller_multiple_storages(): client.initialize_storage_manager(manager_type="SimpleStorage", config=config) # Mock all storage manager methods to avoid real ZMQ operations - async def mock_put_data(data, metadata): + async def mock_put_data(data, metadata, data_parser=None): pass # Just pretend to store the data async def mock_get_data(metadata): diff --git a/tests/test_simple_storage_unit.py b/tests/test_simple_storage_unit.py index c553a1a..2519b97 100644 --- a/tests/test_simple_storage_unit.py +++ b/tests/test_simple_storage_unit.py @@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address): self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout self.socket.connect(storage_put_get_address) - def send_put(self, client_id, global_indexes, field_data): + def send_put(self, client_id, global_indexes, field_data, data_parser=None): + body = {"global_indexes": global_indexes, "data": field_data} + if data_parser is not None: + body["data_parser"] = data_parser msg = ZMQMessage.create( request_type=ZMQRequestType.PUT_DATA, sender_id=f"mock_client_{client_id}", - body={"global_indexes": global_indexes, "data": field_data}, + body=body, ) self.socket.send_multipart(msg.serialize()) return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False)) @@ -434,3 +437,177 @@ def test_storage_unit_data_capacity_uses_active_keys(): assert len(storage._active_keys) == 2 storage.put_data({"f": [4]}, global_indexes=[3]) assert storage._active_keys == {0, 1, 3} + + +def test_storage_unit_data_parser(storage_setup): + """Test data_parser functionality in SimpleStorageUnit. + + Writes two columns: + - normal_data: regular tensors, should remain unchanged + - data_to_be_parsed: list of shape descriptors (list of ints) + + data_parser converts shape descriptors into random tensors of those shapes. + """ + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + def create_data_by_shape_parser(field_data): + if "data_to_be_parsed" in field_data: + shapes = field_data["data_to_be_parsed"] + field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes] + return field_data + + # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists + field_data = { + "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]], + } + global_indexes = [0, 1, 2] + + # Put with data_parser + response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}" + + # Get back + response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + + result = response.body["data"] + + # Verify normal_data is unchanged + torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0])) + torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0])) + torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0])) + + # Verify data_to_be_parsed shapes match the input shape descriptors + expected_shapes = [(2, 3), (1, 4), (3, 2)] + for i, expected_shape in enumerate(expected_shapes): + actual_shape = tuple(result["data_to_be_parsed"][i].shape) + assert actual_shape == expected_shape, ( + f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}" + ) + + client.close() + + +def test_storage_unit_data_parser_callable_types(storage_setup): + """Test that various callable types (partial, callable class) work as data_parser.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + from functools import partial + + # 1. Test functools.partial + def _partial_parser(field_data, prefix): + if "text" in field_data: + field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]] + return field_data + + partial_parser = partial(_partial_parser, prefix="parsed_") + + response = client.send_put( + 0, + [0, 1], + {"text": ["a", "b"]}, + data_parser=partial_parser, + ) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}" + + response = client.send_get(0, [0, 1], ["text"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + assert response.body["data"]["text"] == ["parsed_a", "parsed_b"] + + # 2. Test callable class instance + class CallableParser: + def __call__(self, field_data): + if "value" in field_data: + field_data["value"] = [v * 2 for v in field_data["value"]] + return field_data + + callable_parser = CallableParser() + response = client.send_put( + 0, + [2, 3], + {"value": [1, 2]}, + data_parser=callable_parser, + ) + assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"callable class parser failed: {response.body}" + + response = client.send_get(0, [2, 3], ["value"]) + assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE + assert response.body["data"]["value"] == [2, 4] + + client.close() + + +def test_storage_unit_data_parser_validation(storage_setup): + """Test that invalid data_parser inputs produce clear error messages.""" + _, put_get_address = storage_setup + client = MockStorageClient(put_get_address) + + # 1. Non-callable data_parser should return a clear TypeError + response = client.send_put( + 0, + [0], + {"data": [1]}, + data_parser="not_callable", + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must be callable" in response.body["message"] + + # 2. data_parser returning non-dict should return a clear TypeError + def bad_parser(field_data): + return "not_a_dict" + + response = client.send_put( + 0, + [1], + {"data": [1]}, + data_parser=bad_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must return a dict" in response.body["message"] + + # 3. data_parser deleting a key should return a clear ValueError + def delete_key_parser(field_data): + del field_data["data"] + return field_data + + response = client.send_put( + 0, + [2], + {"data": [1], "extra": [2]}, + data_parser=delete_key_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must not change dict keys" in response.body["message"] + + # 4. data_parser adding a key should return a clear ValueError + def add_key_parser(field_data): + field_data["new_key"] = [999] + return field_data + + response = client.send_put( + 0, + [3], + {"data": [1]}, + data_parser=add_key_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser must not change dict keys" in response.body["message"] + + # 5. data_parser changing element count should return a clear ValueError + def wrong_len_parser(field_data): + field_data["data"] = field_data["data"][:-1] + return field_data + + response = client.send_put( + 0, + [4, 5], + {"data": [1, 2]}, + data_parser=wrong_len_parser, + ) + assert response.request_type == ZMQRequestType.PUT_ERROR + assert "data_parser changed the number of elements" in response.body["message"] + + client.close() diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 79b5c5d..212c52f 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -324,6 +324,7 @@ async def async_put( data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> BatchMeta: """Asynchronously write data to storage units based on metadata. @@ -342,6 +343,16 @@ async def async_put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `data` parameter, in plain + dict format (not TensorDict), mapping field_name -> batched values. + For a regular tensor column the value is a batched tensor; for + nested tensors (jagged or strided) and NonTensorStack columns + the values are extracted into a list. It must modify values + in-place based on the original keys; do not add or remove keys. + The number of elements per column must also remain unchanged. + Do not change the inner order of values within each column. + Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -411,7 +422,7 @@ async def async_put( with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put" ): - await self.storage_manager.put_data(data, metadata) + await self.storage_manager.put_data(data, metadata, data_parser=data_parser) await self.async_set_custom_meta(metadata) @@ -1279,7 +1290,11 @@ def set_custom_meta(self, metadata: BatchMeta) -> None: return self._set_custom_meta(metadata=metadata) def put( - self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None + self, + data: TensorDict, + metadata: Optional[BatchMeta] = None, + partition_id: Optional[str] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> BatchMeta: """Synchronously write data to storage units based on metadata. @@ -1298,6 +1313,16 @@ def put( metadata: Records the metadata of a batch of data samples, containing index and storage unit information. If None, metadata will be auto-generated. partition_id: Target data partition id (required if metadata is not provided) + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `data` parameter, in plain + dict format (not TensorDict), mapping field_name -> batched values. + For a regular tensor column the value is a batched tensor; for + nested tensors (jagged or strided) and NonTensorStack columns + the values are extracted into a list. It must modify values + in-place based on the original keys; do not add or remove keys. + The number of elements per column must also remain unchanged. + Do not change the inner order of values within each column. + Only supported by SimpleStorage. Returns: BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved @@ -1336,7 +1361,7 @@ def put( >>> # This will create metadata in "insert" mode internally. >>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id) """ - return self._put(data=data, metadata=metadata, partition_id=partition_id) + return self._put(data=data, metadata=metadata, partition_id=partition_id, data_parser=data_parser) def get_data(self, metadata: BatchMeta) -> TensorDict: """Synchronously fetch data from storage units and organize into TensorDict. diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index 21f1e8a..31b5a23 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -19,7 +19,7 @@ import subprocess import time from importlib import resources -from typing import Any, Optional +from typing import Any, Callable, Optional from urllib.parse import urlparse import ray @@ -388,6 +388,7 @@ def kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Put a single key-value pair to TransferQueue. @@ -403,6 +404,16 @@ def kv_put( If dict is provided, tensors will be unsqueezed to add batch dimension. If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. @@ -459,7 +470,7 @@ def kv_put( raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user - batch_meta = tq_client.put(fields, batch_meta) + batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tag) to controller tq_client.set_custom_meta(batch_meta) @@ -476,7 +487,11 @@ def kv_put( def kv_batch_put( - keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None + keys: list[str], + partition_id: str, + fields: Optional[TensorDict] = None, + tags: Optional[list[dict[str, Any]]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Put multiple key-value pairs to TransferQueue in batch. @@ -489,6 +504,16 @@ def kv_batch_put( fields: TensorDict containing data for all keys. Must have batch_size == len(keys). If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -542,7 +567,7 @@ def kv_batch_put( # 3. put data if fields is not None: # After put, batch_meta.field_names will include the new fields written by user - batch_meta = tq_client.put(fields, batch_meta) + batch_meta = tq_client.put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tags) to controller tq_client.set_custom_meta(batch_meta) @@ -742,6 +767,7 @@ async def async_kv_put( partition_id: str, fields: Optional[TensorDict | dict[str, Any]] = None, tag: Optional[dict[str, Any]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Asynchronously put a single key-value pair to TransferQueue. @@ -757,6 +783,16 @@ async def async_kv_put( If dict is provided, tensors will be unsqueezed to add batch dimension. If not provided, will only update the newly given tag to the key. tag: Optional metadata tag to associate with the key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the key, tags, partition_id, and fields. @@ -814,7 +850,7 @@ async def async_kv_put( raise ValueError("`fields` can only be dict or TensorDict") # After put, batch_meta.field_names will include the new fields written by user - batch_meta = await tq_client.async_put(fields, batch_meta) + batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tag) to controller await tq_client.async_set_custom_meta(batch_meta) @@ -831,7 +867,11 @@ async def async_kv_put( async def async_kv_batch_put( - keys: list[str], partition_id: str, fields: Optional[TensorDict] = None, tags: Optional[list[dict[str, Any]]] = None + keys: list[str], + partition_id: str, + fields: Optional[TensorDict] = None, + tags: Optional[list[dict[str, Any]]] = None, + data_parser: Optional[Callable[[Any], Any]] = None, ) -> KVBatchMeta: """Asynchronously put multiple key-value pairs to TransferQueue in batch. @@ -844,6 +884,16 @@ async def async_kv_batch_put( fields: TensorDict containing data for all keys. Must have batch_size == len(keys). If not provided, will only update the newly given tags to the keys. tags: List of metadata tags, one for each key + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a slice of the `fields` parameter passed to + kv_put / kv_batch_put, in plain dict format (not TensorDict), + mapping field_name -> batched values. For a regular tensor column + the value is a batched tensor; for nested tensors (jagged or + strided) and NonTensorStack columns the values are extracted into + a list. It must modify values in-place based on the original keys; + do not add or remove keys. The number of elements per column must + also remain unchanged. Do not change the inner order of values + within each column. Only supported by SimpleStorage. Returns: KVBatchMeta: Metadata containing the keys, tags, partition_id, and fields. @@ -896,7 +946,7 @@ async def async_kv_batch_put( # 3. put data if fields is not None: # After put, batch_meta.field_names will include the new fields written by user - batch_meta = await tq_client.async_put(fields, batch_meta) + batch_meta = await tq_client.async_put(fields, batch_meta, data_parser=data_parser) else: # Directly update custom_meta (tags) to controller await tq_client.async_set_custom_meta(batch_meta) diff --git a/transfer_queue/storage/managers/base.py b/transfer_queue/storage/managers/base.py index 42f17db..374db0a 100644 --- a/transfer_queue/storage/managers/base.py +++ b/transfer_queue/storage/managers/base.py @@ -21,7 +21,7 @@ import weakref from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional +from typing import Any, Callable, Optional from uuid import uuid4 import ray @@ -306,13 +306,24 @@ async def notify_data_update( pass @abstractmethod - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Put data into the storage backend. Args: data: Data to be put into the storage. metadata: BatchMeta of the corresponding data. + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a plain dict (not TensorDict) mapping + field_name -> batched values. For a regular tensor column the + value is a batched tensor; for nested tensors (jagged or strided) + and NonTensorStack columns the values are extracted into a list. + It must modify values in-place based on the original keys; do not + add or remove keys. The number of elements per column must also + remain unchanged. Do not change the inner order of values within + each column. Only supported by SimpleStorage backend. """ raise NotImplementedError("Subclasses must implement put_data") @@ -561,10 +572,17 @@ def _get_shape_type_custom_backend_meta_list(metadata: BatchMeta): ) return shapes, dtypes, custom_backend_meta_list - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Store tensor data in the backend storage and notify the controller. """ + if data_parser is not None: + raise NotImplementedError( + "data_parser is not supported for KV-based backends (MooncakeStore, Yuanrong, RayStore)." + ) + num_samples = len(metadata.global_indexes) if data.batch_size[0] != num_samples: raise ValueError(f"Batch size of data ({data.batch_size[0]}) does not match expected ({num_samples})") diff --git a/transfer_queue/storage/managers/simple_backend_manager.py b/transfer_queue/storage/managers/simple_backend_manager.py index 00d8782..99c6c7d 100644 --- a/transfer_queue/storage/managers/simple_backend_manager.py +++ b/transfer_queue/storage/managers/simple_backend_manager.py @@ -21,7 +21,7 @@ from collections.abc import Mapping from functools import wraps from operator import itemgetter -from typing import Any, Callable, NamedTuple +from typing import Any, Callable, NamedTuple, Optional from uuid import uuid4 import torch @@ -285,7 +285,9 @@ def _select_by_positions(field_data, positions: list[int]): else: return field_data[positions] - async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: + async def put_data( + self, data: TensorDict, metadata: BatchMeta, data_parser: Optional[Callable[[Any], Any]] = None + ) -> None: """ Send data to remote StorageUnit based on metadata. @@ -295,6 +297,15 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: Args: data: TensorDict containing the data to store. metadata: BatchMeta containing storage location information. + data_parser: Optional callable to parse reference data (e.g., URLs) into real + content. The input is a plain dict (not TensorDict) mapping + field_name -> batched values. For a regular tensor column the + value is a batched tensor; for nested tensors (jagged or strided) + and NonTensorStack columns the values are extracted into a list. + It must modify values in-place based on the original keys; do not + add or remove keys. The number of elements per column must also + remain unchanged. Do not change the inner order of values within + each column. Executed distributedly on each SimpleStorageUnit. """ logger.debug(f"[{self.storage_manager_id}]: receive put_data request, putting {metadata.size} samples.") @@ -312,6 +323,7 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: group.global_indexes, {f: self._select_by_positions(data[f], group.batch_positions) for f in data.keys()}, target_storage_unit=su_id, + data_parser=data_parser, ) for su_id, group in routing.items() ] @@ -341,6 +353,7 @@ async def _put_to_single_storage_unit( global_indexes: list[int], storage_data: dict[str, Any], target_storage_unit: str, + data_parser: Optional[Callable[[Any], Any]] = None, socket: zmq.Socket = None, ): """ @@ -351,7 +364,7 @@ async def _put_to_single_storage_unit( request_type=ZMQRequestType.PUT_DATA, # type: ignore[arg-type] sender_id=self.storage_manager_id, receiver_id=target_storage_unit, - body={"global_indexes": global_indexes, "data": storage_data}, + body={"global_indexes": global_indexes, "data": storage_data, "data_parser": data_parser}, ) try: diff --git a/transfer_queue/storage/simple_backend.py b/transfer_queue/storage/simple_backend.py index e6908e5..7f0a001 100644 --- a/transfer_queue/storage/simple_backend.py +++ b/transfer_queue/storage/simple_backend.py @@ -342,10 +342,54 @@ def _handle_put(self, data_parts: ZMQMessage) -> ZMQMessage: """ try: global_indexes = data_parts.body["global_indexes"] - field_data = data_parts.body["data"] # field_data should be a TensorDict. + field_data = data_parts.body["data"] # field_data should be a dict. + data_parser = data_parts.body.get("data_parser", None) + with limit_pytorch_auto_parallel_threads( target_num_threads=TQ_NUM_THREADS, info=f"[{self.storage_unit_id}] _handle_put" ): + if data_parser is not None: + if not callable(data_parser): + raise TypeError(f"data_parser must be callable, got {type(data_parser).__name__}") + + original_keys = set(field_data.keys()) + original_lengths = {} + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + original_lengths[k] = v.shape[0] + else: + try: + original_lengths[k] = len(v) + except Exception: + original_lengths[k] = None + + field_data = data_parser(field_data) + + if not isinstance(field_data, dict): + raise TypeError(f"data_parser must return a dict, got {type(field_data).__name__}") + + new_keys = set(field_data.keys()) + if new_keys != original_keys: + raise ValueError( + f"data_parser must not change dict keys. " + f"Original keys: {sorted(original_keys)}, got: {sorted(new_keys)}" + ) + + for k, v in field_data.items(): + if hasattr(v, "shape") and isinstance(v.shape, tuple | list) and len(v.shape) > 0: + new_len = v.shape[0] + else: + try: + new_len = len(v) + except Exception: + new_len = None + + orig_len = original_lengths[k] + if orig_len is not None and new_len is not None and orig_len != new_len: + raise ValueError( + f"data_parser changed the number of elements for key '{k}': " + f"expected {orig_len}, got {new_len}" + ) self.storage_data.put_data(field_data, global_indexes) # After put operation finish, send a message to the client diff --git a/transfer_queue/utils/serial_utils.py b/transfer_queue/utils/serial_utils.py index fd57b80..90f8592 100644 --- a/transfer_queue/utils/serial_utils.py +++ b/transfer_queue/utils/serial_utils.py @@ -23,7 +23,6 @@ import warnings from collections.abc import Sequence from contextvars import ContextVar -from types import FunctionType from typing import Any, TypeAlias import cloudpickle @@ -118,8 +117,9 @@ def enc_hook(self, obj: Any) -> Any: # Only true object arrays (or structured dtypes with object fields) reach here return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)) - if isinstance(obj, FunctionType): - # cloudpickle for functions/methods + if callable(obj): + # cloudpickle for arbitrary callables (functions, lambdas, functools.partial, + # callable class instances, bound methods, etc.) return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) # Fallback to pickle for unknown types diff --git a/tutorial/basic.ipynb b/tutorial/basic.ipynb index 5e5944d..90ad241 100644 --- a/tutorial/basic.ipynb +++ b/tutorial/basic.ipynb @@ -21,9 +21,10 @@ "7. Updating fields incrementally\n", "8. Working with nested (variable-length) tensors\n", "9. Storing variable-size image data\n", - "10. Storing non-tensor data (`NonTensorData` / `NonTensorStack`)\n", - "11. Multiple partitions\n", - "12. Clean up — `kv_clear` / `close`\n", + "10. Storing non-tensor data (`NonTensorStack`)\n", + "11. Lazy Data Parsing with `data_parser`\n", + "12. Multiple partitions\n", + "13. Clean up — `kv_clear` / `close`\n", "\n", "> **Prerequisites:** `pip install TransferQueue` (or install from source). \n", "> Ray will be started automatically in this notebook." @@ -33,11 +34,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. Initialization\n", - "\n", - "TransferQueue runs on top of [Ray](https://www.ray.io/). \n", - "We start Ray, then call `tq.init()` with a minimal configuration that uses the\n", - "built-in **SimpleStorage** backend." + "## 1. Initialization\n\nTransferQueue runs on top of [Ray](https://www.ray.io/). \nWe start Ray, then call `tq.init()` with a minimal configuration that uses the\nbuilt-in **SimpleStorage** backend." ] }, { @@ -49,16 +46,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "2026-03-27 23:06:04,557\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", - "/opt/miniconda3/envs/verl/lib/python3.11/site-packages/ray/_private/worker.py:2062: FutureWarning: Tip: In future versions of Ray, Ray will no longer override accelerator visible devices env var if num_gpus=0 or num_gpus=None (default). To enable this behavior and turn off this error message, set RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO=0\n", - " warnings.warn(\n" + "2026-04-22 14:15:11,786\tINFO worker.py:2014 -- Started a local Ray instance. View the dashboard at \u001B[1m\u001B[32mhttp://127.0.0.1:8265 \u001B[39m\u001B[22m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33m(raylet)\u001b[0m It looks like you're creating a detached actor in an anonymous namespace. In order to access this actor in the future, you will need to explicitly connect to this namespace with ray.init(namespace=\"798d49c0-a4e8-4877-813b-bce2d966c4eb\", ...)\n", "TransferQueue is ready!\n" ] } @@ -95,13 +89,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Store a Single Sample — `kv_put`\n", - "\n", - "`kv_put` stores **one** key-value pair. \n", - "- `key` — a unique string identifier for the sample \n", - "- `partition_id` — a logical namespace (like a table name) \n", - "- `fields` — a `dict` of tensors **or** a `TensorDict` \n", - "- `tag` — optional metadata dict attached to the key" + "## 2. Store a Single Sample — `kv_put`\n\n`kv_put` stores **one** key-value pair. \n- `key` — a unique string identifier for the sample \n- `partition_id` — a logical namespace (like a table name) \n- `fields` — a `dict` of tensors **or** a `TensorDict` \n- `tag` — optional metadata dict attached to the key" ] }, { @@ -131,8 +119,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "You can also pass a pre-built `TensorDict` directly (the batch dimension\n", - "must be 1 for `kv_put`):" + "You can also pass a pre-built `TensorDict` directly (the batch dimension\nmust be 1 for `kv_put`):" ] }, { @@ -170,11 +157,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. Store a Batch of Samples — `kv_batch_put`\n", - "\n", - "When you have multiple samples, `kv_batch_put` is more efficient than\n", - "calling `kv_put` in a loop. The `fields` TensorDict must have\n", - "`batch_size == len(keys)`." + "## 3. Store a Batch of Samples — `kv_batch_put`\n\nWhen you have multiple samples, `kv_batch_put` is more efficient than\ncalling `kv_put` in a loop. The `fields` TensorDict must have\n`batch_size == len(keys)`." ] }, { @@ -215,9 +198,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. Retrieve Data — `kv_batch_get`\n", - "\n", - "Retrieve samples by key(s). The result is always a `TensorDict`." + "## 4. Retrieve Data — `kv_batch_get`\n\nRetrieve samples by key(s). The result is always a `TensorDict`." ] }, { @@ -285,13 +266,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. List Keys & Tags — `kv_list`\n", - "\n", - "`kv_list` returns a nested dict:\n", - "```\n", - "{ partition_id: { key: tag_dict, ... }, ... }\n", - "```\n", - "Pass `partition_id` to filter, or omit it to see everything." + "## 5. List Keys & Tags — `kv_list`\n\n`kv_list` returns a nested dict:\n```\n{ partition_id: { key: tag_dict, ... }, ... }\n```\nPass `partition_id` to filter, or omit it to see everything." ] }, { @@ -326,18 +301,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 6. Partial-Key and Partial-Field Retrieval\n", - "\n", - "You don't have to retrieve *all* keys or *all* fields at once." + "## 6. Partial-Key and Partial-Field Retrieval\n\nYou don't have to retrieve *all* keys or *all* fields at once." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 6a. Partial Keys\n", - "\n", - "Just pass a subset of the keys you stored." + "### 6a. Partial Keys\n\nJust pass a subset of the keys you stored." ] }, { @@ -365,9 +336,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### 6b. Partial Fields\n", - "\n", - "Use the `fields` argument to select specific columns." + "### 6b. Partial Fields\n\nUse the `fields` argument to select specific columns." ] }, { @@ -404,11 +373,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 7. Updating Fields Incrementally\n", - "\n", - "TransferQueue tracks each field (column) independently per key (row). \n", - "You can **add new fields** to existing keys with another `kv_put` /\n", - "`kv_batch_put` call — the earlier fields are preserved." + "## 7. Updating Fields Incrementally\n\nTransferQueue tracks each field (column) independently per key (row). \nYou can **add new fields** to existing keys with another `kv_put` /\n`kv_batch_put` call — the earlier fields are preserved." ] }, { @@ -446,16 +411,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 8. Working with Nested (Variable-Length) Tensors\n", - "\n", - "In many NLP and RL workloads each sample has a **different sequence length**\n", - "(e.g. generated responses). PyTorch represents these as\n", - "[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n", - "**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\n", - "natively.\n", - "\n", - "> **Note:** Because individual samples have different shapes, you must use\n", - "> `kv_batch_put` (not `kv_put`) to store nested tensors." + "## 8. Working with Nested (Variable-Length) Tensors\n\nIn many NLP and RL workloads each sample has a **different sequence length**\n(e.g. generated responses). PyTorch represents these as\n[nested tensors](https://pytorch.org/docs/stable/nested.html) with the\n**jagged layout** (`layout=torch.jagged`), and TransferQueue handles them\nnatively.\n\n> **Note:** Because individual samples have different shapes, you must use\n> `kv_batch_put` (not `kv_put`) to store nested tensors." ] }, { @@ -535,8 +491,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Partial-key retrieval works the same way with nested tensors — only the\n", - "requested samples are returned:" + "Partial-key retrieval works the same way with nested tensors — only the\nrequested samples are returned:" ] }, { @@ -578,9 +533,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Higher-dimensional nested tensors work too. Here each sample is a 3D\n", - "tensor with a variable first dimension (e.g. a different number of\n", - "attention heads or generated candidates):" + "Higher-dimensional nested tensors work too. Here each sample is a 3D\ntensor with a variable first dimension (e.g. a different number of\nattention heads or generated candidates):" ] }, { @@ -633,17 +586,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 9. Storing Variable-Size Image Data\n", - "\n", - "A common multimodal scenario: each sample in a batch contains a\n", - "**different number of images**, and each image has a **different\n", - "resolution**. We can model this with a list of nested tensors — one\n", - "nested tensor per sample — wrapped inside a `TensorDict`.\n", - "\n", - "Since the data is doubly ragged (variable count *and* variable size),\n", - "we store each image as a flattened 1-D tensor and pack all images per\n", - "sample into a single jagged nested tensor. This way every sample is\n", - "one element of the batch, yet images retain their individual sizes." + "## 9. Storing Variable-Size Image Data\n\nA common multimodal scenario: each sample in a batch contains a\n**different number of images**, and each image has a **different\nresolution**. We can model this with a list of nested tensors — one\nnested tensor per sample — wrapped inside a `TensorDict`.\n\nSince the data is doubly ragged (variable count *and* variable size),\nwe store each image as a flattened 1-D tensor and pack all images per\nsample into a single jagged nested tensor. This way every sample is\none element of the batch, yet images retain their individual sizes." ] }, { @@ -758,13 +701,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 10. Storing Non-Tensor Data — `NonTensorData` / `NonTensorStack`\n", + "## 10. Storing Non-Tensor Data — `NonTensorStack`\n", "\n", "Not every field is a numeric tensor. Prompts, file paths, JSON metadata,\n", "or arbitrary Python objects can be stored as **non-tensor data** using\n", - "tensordict's `NonTensorData` and `NonTensorStack`.\n", + "tensordict's `NonTensorStack`.\n", "\n", - "- `NonTensorData` wraps a **single** Python object (string, dict, list, …).\n", "- `NonTensorStack` wraps a **batch** of Python objects — one per sample.\n", "\n", "TransferQueue serialises them transparently alongside regular tensors." @@ -852,7 +794,7 @@ "for i, meta in enumerate(list(result_nt[\"metadata\"])):\n", " print(f\" [{i}] {meta}\")\n", "\n", - "# You can also add a NonTensorData field to a single key via kv_put\n", + "# You can also add a NonTensorStack field to a single key via kv_put\n", "tq.kv_put(\n", " key=\"single_nt\",\n", " partition_id=\"train\",\n", @@ -870,16 +812,204 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 11. Multiple Partitions\n", + "## 11. Lazy Data Parsing with `data_parser`\n", "\n", - "Partitions provide logical isolation — the same key name can exist in\n", - "different partitions without conflict." + "Sometimes you want to store lightweight **references** (e.g. URLs or file paths)\n", + "and defer the expensive loading / decoding until the data reaches the storage unit.\n", + "\n", + "`kv_put` / `kv_batch_put` accept an optional `data_parser` callable. It is executed **inside**\n", + "each `SimpleStorageUnit` at put time. The callable receives a plain `dict` (not a TensorDict)\n", + "mapping `field_name -> batched_values`. For a regular tensor column the value is a batched tensor;\n", + "for nested tensors (jagged or strided) and `NonTensorStack` columns the values are extracted into\n", + "a `list`. It must modify values in-place based on the original keys; do not add or remove keys.\n", + "The number of elements per column must also remain unchanged. Do not change the inner order of\n", + "values within each column.\n", + "\n", + "> **Design tip:** Separate the **core single-sample parser** from the **batch concurrency wrapper**.\n", + "> The wrapper can use `asyncio` to process all samples in parallel while the parser function itself\n", + "> remains synchronous to the caller.\n", + "\n", + "> **Note:** `data_parser` is only supported by the **SimpleStorage** backend." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Put succeeded. Fields: ['data_to_be_parsed', 'normal_data']\n", + "Total kv_batch_put time: 1.02s (concurrency keeps it ~1s, not 32s)\n", + "\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import time\n", + "\n", + "\n", + "# 1. Define core single-sample parser (pure business logic, no asyncio, no batch)\n", + "def parse_url(url: str) -> torch.Tensor:\n", + " \"\"\"Parse a URL-like descriptor 'dtype:HxW' into a random tensor.\"\"\"\n", + " dtype_str, shape_str = url.split(\":\")\n", + " dtype = getattr(torch, dtype_str)\n", + " shape = [int(dim) for dim in shape_str.split(\"x\")]\n", + " return torch.randn(shape, dtype=dtype)\n", + "\n", + "\n", + "# 2. Define Batch-level parser: sync on the outside, async-parallel on the inside\n", + "def concurrent_batch_url_parser(field_data: dict) -> dict:\n", + " \"\"\"Batch-level data_parser executed inside SimpleStorageUnit.\n", + "\n", + " It receives a ``dict`` (not a TensorDict) where each value is a\n", + " batched column. For columns created from ``NonTensorStack`` the\n", + " value is a plain ``list`` of Python objects.\n", + "\n", + " Workflow:\n", + " 1. Spawns one async task per list element.\n", + " 2. Waits until *all* tasks finish (``asyncio.gather``).\n", + " 3. Replaces the list with the list of results.\n", + "\n", + " Because ``asyncio.run`` blocks until the loop finishes, this function\n", + " is **synchronous** to its caller: when it returns, every sample has\n", + " been processed.\n", + "\n", + " Args:\n", + " field_data: Mapping ``field_name -> batched_values``. The dict\n", + " keys must stay exactly the same; only values may be\n", + " transformed in-place.\n", + "\n", + " Returns:\n", + " The same dict with parsed values substituted.\n", + " \"\"\"\n", + " if \"data_to_be_parsed\" not in field_data:\n", + " return field_data\n", + "\n", + " urls: list[str] = field_data[\"data_to_be_parsed\"]\n", + "\n", + " async def _async_parse_single(url: str) -> torch.Tensor:\n", + " await asyncio.sleep(1.0) # Add fixed delay per sample\n", + " return parse_url(url)\n", + "\n", + " async def _process_all():\n", + " tasks = [asyncio.create_task(_async_parse_single(url)) for url in urls]\n", + " return await asyncio.gather(*tasks)\n", + "\n", + " start = time.perf_counter()\n", + " field_data[\"data_to_be_parsed\"] = asyncio.run(_process_all())\n", + " elapsed = time.perf_counter() - start\n", + "\n", + " print(f\"[data_parser] Processed {len(urls)} samples in {elapsed:.2f}s (serial would be ~{len(urls)}.0s)\")\n", + " return field_data\n", + "\n", + "\n", + "# ---------------------------------------------------------------------------\n", + "# Build the batch\n", + "# ---------------------------------------------------------------------------\n", + "batch_size = 32\n", + "\n", + "normal_data = torch.randn(batch_size, 2)\n", + "\n", + "# URL-like strings: all use the same dtype so TQ can pack them on get\n", + "shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n", + "urls = [f\"float32:{h}x{w}\" for h, w in shapes]\n", + "\n", + "parser_fields = TensorDict(\n", + " {\n", + " \"normal_data\": normal_data,\n", + " \"data_to_be_parsed\": NonTensorStack(*urls),\n", + " },\n", + " batch_size=batch_size,\n", + ")\n", + "\n", + "data_parser_keys = [f\"data_parser_sample_{i}\" for i in range(batch_size)]\n", + "\n", + "put_start_time = time.perf_counter()\n", + "meta = tq.kv_batch_put(\n", + " keys=data_parser_keys,\n", + " partition_id=\"train\",\n", + " fields=parser_fields,\n", + " data_parser=concurrent_batch_url_parser,\n", + ")\n", + "put_elapsed = time.perf_counter() - put_start_time\n", + "print(f\"Put succeeded. Fields: {meta.fields}\")\n", + "print(f\"Total kv_batch_put time: {put_elapsed:.2f}s (concurrency keeps it ~1s, not {batch_size}s)\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": "When `kv_batch_put` returns, the user-defined data parser has also finished executing. We can then safely call `kv_batch_get` to retrieve the parsed data." + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2026-04-22 14:15:17,477 - WARNING - transfer_queue.storage.managers.simple_backend_manager - Failed to pack nested tensor with jagged layout. Falling back to strided layout. Detailed error: Cannot represent given tensor list as a nested tensor with the jagged layout. Note that the jagged layout only allows for a single ragged dimension. For example: (B, *, D_0, D_1, ..., D_N), with ragged * dim.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[PASS] normal_data is unchanged.\n", + "[PASS] All 32 parsed tensors have correct dtype & shape.\n", + "[PASS] Timing looks concurrent: 1.02s < 2.0s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "" + ] + } + ], + "source": [ + "result = tq.kv_batch_get(keys=data_parser_keys, partition_id=\"train\")\n", + "\n", + "# normal_data should be unchanged\n", + "torch.testing.assert_close(result[\"normal_data\"], normal_data)\n", + "print(\"[PASS] normal_data is unchanged.\")\n", + "\n", + "# data_to_be_parsed should now be tensors with the requested shapes\n", + "expected_shapes = [(i % 4 + 1, i % 3 + 2) for i in range(batch_size)]\n", + "for i, expected in enumerate(expected_shapes):\n", + " tensor = result[\"data_to_be_parsed\"][i]\n", + " assert tensor.dtype == torch.float32\n", + " actual = tuple(tensor.shape)\n", + " assert actual == expected, f\"Mismatch at {i}: {actual} != {expected}\"\n", + "print(f\"[PASS] All {batch_size} parsed tensors have correct dtype & shape.\")\n", + "\n", + "# Timing sanity check: serial would be ~batch_size seconds.\n", + "# Because asyncio tasks run in parallel inside the parser, it should be ~1 s.\n", + "assert put_elapsed < 2.0, f\"Expected concurrent execution (~1s), but took {put_elapsed:.2f}s.\"\n", + "print(f\"[PASS] Timing looks concurrent: {put_elapsed:.2f}s < 2.0s\")\n", + "\n", + "# wait for Ray log collect\n", + "time.sleep(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 12. Multiple Partitions\n\nPartitions provide logical isolation — the same key name can exist in\ndifferent partitions without conflict." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -905,14 +1035,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 12. Clean Up — `kv_clear` and `close`\n", - "\n", - "Remove specific keys with `kv_clear`, then shut down the system with `tq.close()`." + "## 13. Clean Up — `kv_clear` and `close`\n\nRemove specific keys with `kv_clear`, then shut down the system with `tq.close()`." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -929,6 +1057,7 @@ "tq.kv_clear(keys=\"sample_0\", partition_id=\"train\")\n", "tq.kv_clear(keys=\"sample_1\", partition_id=\"train\")\n", "tq.kv_clear(keys=keys, partition_id=\"train\")\n", + "tq.kv_clear(keys=data_parser_keys, partition_id=\"train\")\n", "tq.kv_clear(keys=\"val_sample_0\", partition_id=\"validation\")\n", "print(\"All keys cleared.\")\n", "\n", @@ -939,7 +1068,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -960,30 +1089,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "---\n", - "\n", - "## Summary\n", - "\n", - "| Operation | Function | Notes |\n", - "|---|---|---|\n", - "| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n", - "| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n", - "| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n", - "| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n", - "| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n", - "| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n", - "| Close | `tq.close()` | Tears down controller & storage |\n", - "\n", - "For **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n", - "`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n", - "\n", - "For low-level, metadata-based access, see `tq.get_client()` and the\n", - "[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." + "---\n\n## Summary\n\n| Operation | Function | Notes |\n|---|---|---|\n| Init | `tq.init(config)` | Call once; subsequent processes auto-connect |\n| Put single | `tq.kv_put(key, partition_id, fields, tag)` | `fields` can be a plain dict |\n| Put batch | `tq.kv_batch_put(keys, partition_id, fields, tags)` | `fields` must be a `TensorDict` |\n| Put with parser | `tq.kv_batch_put(..., data_parser=fn)` | Only for **SimpleStorage**; receives dict, can use asyncio inside |\n| Get | `tq.kv_batch_get(keys, partition_id, select_fields=None)` | Returns a `TensorDict` |\n| List | `tq.kv_list(partition_id=None)` | Returns `{partition: {key: tag}}` |\n| Clear | `tq.kv_clear(keys, partition_id)` | Removes keys + data |\n| Close | `tq.close()` | Tears down controller & storage |\n\nFor **async** variants, use `async_kv_put`, `async_kv_batch_put`,\n`async_kv_batch_get`, `async_kv_list`, and `async_kv_clear`.\n\nFor low-level, metadata-based access, see `tq.get_client()` and the\n[official tutorials](https://github.com/Ascend/TransferQueue/tree/main/tutorial)." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], "source": []