Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def client_setup(mock_controller, mock_storage):
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
async def mock_put_data(data, metadata, data_parser=None):
pass # Just pretend to store the data

async def mock_get_data(metadata):
Expand Down Expand Up @@ -511,7 +511,7 @@ def test_single_controller_multiple_storages():
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
async def mock_put_data(data, metadata, data_parser=None):
pass # Just pretend to store the data

async def mock_get_data(metadata):
Expand Down
181 changes: 179 additions & 2 deletions tests/test_simple_storage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,14 @@ def __init__(self, storage_put_get_address):
self.socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout
self.socket.connect(storage_put_get_address)

def send_put(self, client_id, global_indexes, field_data, data_parser=None):
    """Send a PUT_DATA request over the client socket and return the reply.

    Args:
        client_id: Identifier embedded in the sender id as
            ``mock_client_{client_id}``.
        global_indexes: Value sent under the ``"global_indexes"`` key of the
            request body.
        field_data: Value sent under the ``"data"`` key of the request body.
        data_parser: Optional parser sent under the ``"data_parser"`` key.
            When it is None the key is left out of the request body.

    Returns:
        The message deserialized from the multipart reply.
    """
    body = {"global_indexes": global_indexes, "data": field_data}
    if data_parser is not None:
        # Attach the parser only when supplied so the default request body
        # stays exactly as it was before data_parser support was added.
        body["data_parser"] = data_parser
    msg = ZMQMessage.create(
        request_type=ZMQRequestType.PUT_DATA,
        sender_id=f"mock_client_{client_id}",
        body=body,
    )
    self.socket.send_multipart(msg.serialize())
    return ZMQMessage.deserialize(self.socket.recv_multipart(copy=False))
Expand Down Expand Up @@ -434,3 +437,177 @@ def test_storage_unit_data_capacity_uses_active_keys():
assert len(storage._active_keys) == 2
storage.put_data({"f": [4]}, global_indexes=[3])
assert storage._active_keys == {0, 1, 3}


def test_storage_unit_data_parser(storage_setup):
    """Test data_parser functionality in SimpleStorageUnit.

    Writes two columns:
    - normal_data: regular tensors, should remain unchanged
    - data_to_be_parsed: list of shape descriptors (list of ints)

    data_parser converts shape descriptors into random tensors of those shapes.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def create_data_by_shape_parser(field_data: dict) -> dict:
        """Replace shape descriptors with random tensors of those shapes.

        Args:
            field_data: Plain dict mapping field name -> batched values. Only
                the ``"data_to_be_parsed"`` entry is touched; it is expected
                to hold one shape descriptor (list of ints) per sample.

        Returns:
            The same dict object, modified in place, with the same keys and
            the same number of elements per column.
        """
        if "data_to_be_parsed" in field_data:
            shapes = field_data["data_to_be_parsed"]
            field_data["data_to_be_parsed"] = [torch.randn(shape) for shape in shapes]
        return field_data

    # Close the client even when an assertion fails so the socket is not leaked.
    try:
        # Prepare data: normal_data is a batch tensor, data_to_be_parsed is a list of shape lists
        field_data = {
            "normal_data": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
            "data_to_be_parsed": [[2, 3], [1, 4], [3, 2]],
        }
        global_indexes = [0, 1, 2]

        # Put with data_parser
        response = client.send_put(0, global_indexes, field_data, data_parser=create_data_by_shape_parser)
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"Put failed: {response.body}"

        # Get back
        response = client.send_get(0, global_indexes, ["normal_data", "data_to_be_parsed"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE

        result = response.body["data"]

        # Verify normal_data is unchanged
        torch.testing.assert_close(result["normal_data"][0], torch.tensor([1.0, 2.0]))
        torch.testing.assert_close(result["normal_data"][1], torch.tensor([3.0, 4.0]))
        torch.testing.assert_close(result["normal_data"][2], torch.tensor([5.0, 6.0]))

        # Verify data_to_be_parsed shapes match the input shape descriptors
        expected_shapes = [(2, 3), (1, 4), (3, 2)]
        for i, expected_shape in enumerate(expected_shapes):
            actual_shape = tuple(result["data_to_be_parsed"][i].shape)
            assert actual_shape == expected_shape, (
                f"Shape mismatch at index {i}: expected {expected_shape}, got {actual_shape}"
            )
    finally:
        client.close()


def test_storage_unit_data_parser_callable_types(storage_setup):
    """Test that various callable types (partial, callable class) work as data_parser.

    Each callable is sent with a put request; the values read back must show
    the transformation that callable applies.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    from functools import partial

    def _partial_parser(field_data, prefix):
        """Prepend ``prefix`` to every entry of the ``"text"`` column, if present."""
        if "text" in field_data:
            field_data["text"] = [f"{prefix}{t}" for t in field_data["text"]]
        return field_data

    class CallableParser:
        """Parser implemented as a callable instance rather than a function."""

        def __call__(self, field_data):
            """Double every entry of the ``"value"`` column, if present."""
            if "value" in field_data:
                field_data["value"] = [v * 2 for v in field_data["value"]]
            return field_data

    # Close the client even when an assertion fails so the socket is not leaked.
    try:
        # 1. Test functools.partial
        partial_parser = partial(_partial_parser, prefix="parsed_")

        response = client.send_put(
            0,
            [0, 1],
            {"text": ["a", "b"]},
            data_parser=partial_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, f"partial parser failed: {response.body}"

        response = client.send_get(0, [0, 1], ["text"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["text"] == ["parsed_a", "parsed_b"]

        # 2. Test callable class instance
        callable_parser = CallableParser()
        response = client.send_put(
            0,
            [2, 3],
            {"value": [1, 2]},
            data_parser=callable_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_DATA_RESPONSE, (
            f"callable class parser failed: {response.body}"
        )

        response = client.send_get(0, [2, 3], ["value"])
        assert response.request_type == ZMQRequestType.GET_DATA_RESPONSE
        assert response.body["data"]["value"] == [2, 4]
    finally:
        client.close()


def test_storage_unit_data_parser_validation(storage_setup):
    """Test that invalid data_parser inputs produce clear error messages.

    Every case sends a put request with a misbehaving data_parser and expects
    a PUT_ERROR response whose message names the violated rule.
    """
    _, put_get_address = storage_setup
    client = MockStorageClient(put_get_address)

    def bad_parser(field_data):
        """Return a non-dict value."""
        return "not_a_dict"

    def delete_key_parser(field_data):
        """Remove an existing key from the dict."""
        del field_data["data"]
        return field_data

    def add_key_parser(field_data):
        """Add a key that was not in the input dict."""
        field_data["new_key"] = [999]
        return field_data

    def wrong_len_parser(field_data):
        """Drop the last element of the ``"data"`` column."""
        field_data["data"] = field_data["data"][:-1]
        return field_data

    # Close the client even when an assertion fails so the socket is not leaked.
    try:
        # 1. Non-callable data_parser should return a clear TypeError
        response = client.send_put(
            0,
            [0],
            {"data": [1]},
            data_parser="not_callable",
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must be callable" in response.body["message"]

        # 2. data_parser returning non-dict should return a clear TypeError
        response = client.send_put(
            0,
            [1],
            {"data": [1]},
            data_parser=bad_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must return a dict" in response.body["message"]

        # 3. data_parser deleting a key should return a clear ValueError
        response = client.send_put(
            0,
            [2],
            {"data": [1], "extra": [2]},
            data_parser=delete_key_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must not change dict keys" in response.body["message"]

        # 4. data_parser adding a key should return a clear ValueError
        response = client.send_put(
            0,
            [3],
            {"data": [1]},
            data_parser=add_key_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser must not change dict keys" in response.body["message"]

        # 5. data_parser changing element count should return a clear ValueError
        response = client.send_put(
            0,
            [4, 5],
            {"data": [1, 2]},
            data_parser=wrong_len_parser,
        )
        assert response.request_type == ZMQRequestType.PUT_ERROR
        assert "data_parser changed the number of elements" in response.body["message"]
    finally:
        client.close()
31 changes: 28 additions & 3 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ async def async_put(
data: TensorDict,
metadata: Optional[BatchMeta] = None,
partition_id: Optional[str] = None,
data_parser: Optional[Callable[[Any], Any]] = None,
) -> BatchMeta:
"""Asynchronously write data to storage units based on metadata.

Expand All @@ -342,6 +343,16 @@ async def async_put(
metadata: Records the metadata of a batch of data samples, containing index and
storage unit information. If None, metadata will be auto-generated.
partition_id: Target data partition id (required if metadata is not provided)
data_parser: Optional callable to parse reference data (e.g., URLs) into real
content. The input is a slice of the `data` parameter, in plain
dict format (not TensorDict), mapping field_name -> batched values.
For a regular tensor column the value is a batched tensor; for
nested tensors (jagged or strided) and NonTensorStack columns
the values are extracted into a list. It must modify values
in-place based on the original keys; do not add or remove keys.
The number of elements per column must also remain unchanged.
Do not change the inner order of values within each column.
Only supported by SimpleStorage.

Returns:
BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
Expand Down Expand Up @@ -411,7 +422,7 @@ async def async_put(
with limit_pytorch_auto_parallel_threads(
target_num_threads=TQ_NUM_THREADS, info=f"[{self.client_id}] async_put"
):
await self.storage_manager.put_data(data, metadata)
await self.storage_manager.put_data(data, metadata, data_parser=data_parser)

Comment thread
0oshowero0 marked this conversation as resolved.
await self.async_set_custom_meta(metadata)

Expand Down Expand Up @@ -1279,7 +1290,11 @@ def set_custom_meta(self, metadata: BatchMeta) -> None:
return self._set_custom_meta(metadata=metadata)

def put(
self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
self,
data: TensorDict,
metadata: Optional[BatchMeta] = None,
partition_id: Optional[str] = None,
data_parser: Optional[Callable[[Any], Any]] = None,
) -> BatchMeta:
"""Synchronously write data to storage units based on metadata.

Expand All @@ -1298,6 +1313,16 @@ def put(
metadata: Records the metadata of a batch of data samples, containing index and
storage unit information. If None, metadata will be auto-generated.
partition_id: Target data partition id (required if metadata is not provided)
data_parser: Optional callable to parse reference data (e.g., URLs) into real
content. The input is a slice of the `data` parameter, in plain
dict format (not TensorDict), mapping field_name -> batched values.
For a regular tensor column the value is a batched tensor; for
nested tensors (jagged or strided) and NonTensorStack columns
the values are extracted into a list. It must modify values
in-place based on the original keys; do not add or remove keys.
The number of elements per column must also remain unchanged.
Do not change the inner order of values within each column.
Only supported by SimpleStorage.

Returns:
BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
Expand Down Expand Up @@ -1336,7 +1361,7 @@ def put(
>>> # This will create metadata in "insert" mode internally.
>>> metadata = client.put(data=prompts_repeated_batch, partition_id=current_partition_id)
"""
return self._put(data=data, metadata=metadata, partition_id=partition_id)
return self._put(data=data, metadata=metadata, partition_id=partition_id, data_parser=data_parser)

def get_data(self, metadata: BatchMeta) -> TensorDict:
"""Synchronously fetch data from storage units and organize into TensorDict.
Expand Down
Loading
Loading