Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ This class encapsulates the core interaction logic within the TransferQueue syst

Currently, we support the following storage backends:

- SimpleStorageUnit: A basic CPU memory storage with minimal data format constraints and easy usability.
- SimpleStorage: A basic CPU memory storage with minimal data format constraints and easy usability.
- [Yuanrong](https://gitee.com/openeuler/yuanrong-datasystem) (beta, [#PR107](https://github.com/TransferQueue/TransferQueue/pull/107), [#PR96](https://github.com/TransferQueue/TransferQueue/pull/96)): An Ascend native data system that provides hierarchical storage interfaces including HBM/DRAM/SSD.
- [Mooncake Store](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
- [Ray Direct Transport](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.
- [MooncakeStore](https://github.com/kvcache-ai/Mooncake) (alpha, [#PR162](https://github.com/TransferQueue/TransferQueue/pull/162)): A high-performance, KV-based hierarchical storage that supports RDMA transport between GPU and DRAM.
- [RayRDT](https://docs.ray.io/en/master/ray-core/direct-transport.html) (alpha, [#PR167](https://github.com/TransferQueue/TransferQueue/pull/167)): Ray's new feature that allows Ray to store and pass objects directly between Ray actors.

Among them, `SimpleStorage` serves as our default storage backend, coordinated by the `AsyncSimpleStorageManager` class. Each storage unit can be deployed on a separate node, allowing for distributed data management.
Comment thread
0oshowero0 marked this conversation as resolved.

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,6 @@ yuanrong = [
# This is the rough equivalent of package_data={'': ['version/*']}
[tool.setuptools.package-data]
transfer_queue = [
"version/*",
"version/*",
"*.yaml"
]
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pyzmq
hydra-core
numpy<2.0.0
msgspec
psutil
psutil
omegaconf
15 changes: 6 additions & 9 deletions tests/test_async_simple_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ async def mock_async_storage_manager():
)

config = {
"storage_unit_infos": storage_unit_infos,
"controller_info": controller_info,
"zmq_info": storage_unit_infos,
}

# Mock the handshake process entirely to avoid ZMQ complexity
Expand Down Expand Up @@ -199,8 +198,7 @@ async def test_async_storage_manager_mapping_functions():
)

config = {
"storage_unit_infos": storage_unit_infos,
"controller_info": controller_info,
"zmq_info": storage_unit_infos,
}

# Mock ZMQ operations
Expand Down Expand Up @@ -230,7 +228,7 @@ async def test_async_storage_manager_mapping_functions():
mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize())

# Create manager
manager = AsyncSimpleStorageManager(config)
manager = AsyncSimpleStorageManager(controller_info, config)

# Test round-robin mapping for 3 storage units
# global_index -> storage_unit mapping: 0->storage_0, 1->storage_1, 2->storage_2,
Expand Down Expand Up @@ -266,16 +264,15 @@ async def test_async_storage_manager_error_handling():
}

# Mock controller info
controller_infos = ZMQServerInfo(
controller_info = ZMQServerInfo(
role=TransferQueueRole.CONTROLLER,
id="controller_0",
ip="127.0.0.1",
ports={"handshake_socket": 12346, "data_status_update_socket": 12347},
)

config = {
"storage_unit_infos": storage_unit_infos,
"controller_info": controller_infos,
"zmq_info": storage_unit_infos,
}

# Mock ZMQ operations
Expand Down Expand Up @@ -305,7 +302,7 @@ async def test_async_storage_manager_error_handling():
mock_socket.recv_multipart = Mock(return_value=handshake_response.serialize())

# Create manager
manager = AsyncSimpleStorageManager(config)
manager = AsyncSimpleStorageManager(controller_info, config)

# Mock operations that raise exceptions
manager._put_to_single_storage_unit = AsyncMock(side_effect=RuntimeError("Mock PUT error"))
Expand Down
8 changes: 4 additions & 4 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def client_setup(mock_controller, mock_storage):
):
config = {
"controller_info": mock_controller.zmq_server_info,
"storage_unit_infos": {mock_storage.storage_id: mock_storage.zmq_server_info},
"zmq_info": {mock_storage.storage_id: mock_storage.zmq_server_info},
}
client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
Expand Down Expand Up @@ -411,9 +411,9 @@ def test_single_controller_multiple_storages():
):
config = {
"controller_info": controller.zmq_server_info,
"storage_unit_infos": {s.storage_id: s.zmq_server_info for s in storages},
"zmq_info": {s.storage_id: s.zmq_server_info for s in storages},
}
client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
client.initialize_storage_manager(manager_type="SimpleStorage", config=config)

# Mock all storage manager methods to avoid real ZMQ operations
async def mock_put_data(data, metadata):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_kv_storage_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_merge_tensors_to_tensordict(mock_create, test_data):
mock_client = MagicMock()
mock_create.return_value = mock_client

manager = KVStorageManager(test_data["cfg"])
manager = KVStorageManager(controller_info=MagicMock(), config=test_data["cfg"])
assert manager.storage_client is mock_client
assert manager._multi_threads_executor is None

Expand Down Expand Up @@ -296,9 +296,9 @@ def test_put_data_with_custom_meta_from_storage_client(mock_notify, test_data_fo
mock_storage_client.put.return_value = mock_custom_meta

# Create manager with mocked dependencies
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
config = {"client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)
manager = KVStorageManager(controller_info=MagicMock(), config=config)

# Run put_data
asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_put_data_without_custom_meta(mock_notify, test_data_for_put_data):
# Create manager with mocked dependencies
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)
manager = KVStorageManager(controller_info=MagicMock(), config=config)

# Run put_data
asyncio.run(manager.put_data(test_data_for_put_data["data"], test_data_for_put_data["metadata"]))
Expand All @@ -371,7 +371,7 @@ def test_put_data_custom_meta_length_mismatch_raises_error(test_data_for_put_dat
# Create manager with mocked dependencies
config = {"controller_info": MagicMock(), "client_name": "MockClient"}
with patch(f"{STORAGE_CLIENT_FACTORY_PATH}.create", return_value=mock_storage_client):
manager = KVStorageManager(config)
manager = KVStorageManager(controller_info=MagicMock(), config=config)

# Run put_data and expect ValueError
with pytest.raises(ValueError) as exc_info:
Expand Down
38 changes: 33 additions & 5 deletions transfer_queue/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,50 @@

import os

from .client import (
TransferQueueClient,
process_zmq_server_info,
)
from .client import TransferQueueClient
from .controller import TransferQueueController
from .dataloader import StreamingDataLoader, StreamingDataset
from .interface import (
async_clear_partition,
async_clear_samples,
async_get_data,
async_get_meta,
async_put,
async_set_custom_meta,
clear_partition,
clear_samples,
close,
get_data,
get_meta,
init,
put,
set_custom_meta,
)
from .metadata import BatchMeta
from .sampler import BaseSampler
from .sampler.grpo_group_n_sampler import GRPOGroupNSampler
from .sampler.rank_aware_sampler import RankAwareSampler
from .sampler.sequential_sampler import SequentialSampler
from .storage import SimpleStorageUnit
from .utils.common import get_placement_group
from .utils.zmq_utils import ZMQServerInfo
from .utils.zmq_utils import ZMQServerInfo, process_zmq_server_info

__all__ = [
"init",
"get_meta",
"get_data",
"put",
"set_custom_meta",
"clear_samples",
"clear_partition",
"async_get_meta",
"async_get_data",
"async_put",
"async_set_custom_meta",
"async_clear_samples",
Comment thread
0oshowero0 marked this conversation as resolved.
"async_clear_partition",
"close",
] + [
"TransferQueueClient",
"StreamingDataset",
"StreamingDataLoader",
Expand Down
48 changes: 7 additions & 41 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@
import os
import threading
from functools import wraps
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional
from uuid import uuid4

import ray
import torch
import zmq
import zmq.asyncio
from tensordict import TensorDict
from torch import Tensor

from transfer_queue.controller import TransferQueueController
from transfer_queue.metadata import (
BatchMeta,
)
from transfer_queue.storage import (
SimpleStorageUnit,
TransferQueueStorageManager,
TransferQueueStorageManagerFactory,
)
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
Expand Down Expand Up @@ -95,11 +91,12 @@ def initialize_storage_manager(
AsyncSimpleStorageManager, KVStorageManager (under development), etc.
config: Configuration dictionary for the storage manager.
For AsyncSimpleStorageManager, must contain the following required keys:
- controller_info: ZMQ server information about the controller
- storage_unit_infos: ZMQ server information about the storage units
- zmq_info: ZMQ server information about the storage units

"""
self.storage_manager = TransferQueueStorageManagerFactory.create(manager_type, config)
self.storage_manager = TransferQueueStorageManagerFactory.create(
manager_type, controller_info=self._controller, config=config
)

# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
@staticmethod
Expand Down Expand Up @@ -1041,6 +1038,7 @@ def get_meta(
data_fields: list[str],
batch_size: int,
partition_id: str,
mode: str = "fetch",
task_name: Optional[str] = None,
sampling_config: Optional[dict[str, Any]] = None,
) -> BatchMeta:
Expand Down Expand Up @@ -1098,6 +1096,7 @@ def get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
mode=mode,
task_name=task_name,
sampling_config=sampling_config,
)
Expand Down Expand Up @@ -1304,36 +1303,3 @@ def close(self) -> None:
logger.warning(f"[{self.client_id}]: Error closing event loop: {e}")

super().close()


def process_zmq_server_info(
    handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]]
    | Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"],
):  # noqa: UP007
    """Extract ZMQ server information from handler objects.

    Each handler is expected to expose a remote ``get_zmq_server_info`` method
    (it is invoked as ``handler.get_zmq_server_info.remote()``), and the result
    is fetched with ``ray.get``.

    Args:
        handlers: Dictionary of handler objects (controllers, storage managers, or storage units),
            or a single handler object.

    Returns:
        If handlers is a dictionary: dictionary mapping handler names to their ZMQ server
        information, in the same key order as the input.
        If handlers is a single object: ZMQ server information for that object.

    Examples:
        >>> # Single handler
        >>> controller = TransferQueueController.remote(...)
        >>> info = process_zmq_server_info(controller)
        >>>
        >>> # Multiple handlers
        >>> handlers = {"storage_0": storage_0, "storage_1": storage_1}
        >>> info_dict = process_zmq_server_info(handlers)
    """
    # Single handler: one remote call, one blocking fetch.
    if not isinstance(handlers, dict):
        return ray.get(handlers.get_zmq_server_info.remote())  # type: ignore[union-attr, attr-defined]

    # Dictionary case: submit every remote call first, then resolve all of them with a
    # single ray.get. Calling ray.get inside the loop would block on each handler in
    # turn; a batched ray.get returns the results in the same order as the refs.
    names = list(handlers)
    pending = [
        handlers[name].get_zmq_server_info.remote()  # type: ignore[union-attr, attr-defined]
        for name in names
    ]
    return dict(zip(names, ray.get(pending)))
31 changes: 31 additions & 0 deletions transfer_queue/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# This is the default configuration of TransferQueue. Users may modify the default values
# and use transfer_queue.init(conf) to override the config entries.

controller:
# User-defined sampler. Users can pass a sampler instance to override this string config.
sampler: SequentialSampler
# Whether to return an empty BatchMeta to prevent request blocking when not enough data is available
polling_mode: False
# ZMQ Server IP & Ports (automatically generated during init)
zmq_info: null


backend:
# Pluggable storage/transport backend of TransferQueue. Choose from:
# SimpleStorage, Yuanrong, MooncakeStore, ...
storage_backend: SimpleStorage

# For SimpleStorage:
SimpleStorage:
# Total number of samples
total_storage_size: 100000
# Number of distributed storage units for SimpleStorage backend
num_data_storage_units: 2
# ZMQ Server IP & Ports (automatically generated during init)
zmq_info: null

# For Yuanrong:
# TODO

# For MooncakeStore:
# TODO
34 changes: 34 additions & 0 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import ray
import torch
import zmq
from omegaconf import DictConfig
from ray.util import get_node_ip_address
from torch import Tensor

Expand Down Expand Up @@ -879,6 +880,7 @@ def __init__(

self.controller_id = f"TQ_CONTROLLER_{uuid4().hex[:8]}"
self.polling_mode = polling_mode
self.tq_config = None # global config for TransferQueue system

# Initialize ZMQ sockets for communication
self._init_zmq_socket()
Expand Down Expand Up @@ -1772,3 +1774,35 @@ def _update_data_status(self):
def get_zmq_server_info(self) -> ZMQServerInfo:
"""Get ZMQ server connection information."""
return self.zmq_server_info

def store_config(self, conf: DictConfig) -> None:
    """Remember ``conf`` as the global TransferQueue config held by this controller.

    Args:
        conf: Config object to keep; it replaces whatever was stored before.
    """
    self.tq_config = conf

def get_config(self) -> DictConfig | None:
    """Retrieve the global config of TransferQueue.

    Returns:
        The object last passed to ``store_config``, or ``None`` when nothing has
        been stored yet (``tq_config`` starts out as ``None``).
    """
    return self.tq_config

def register_sampler(
    self,
    sampler: BaseSampler | type[BaseSampler] = SequentialSampler,
) -> None:
    """
    Install the sampler this controller uses, replacing ``self.sampler``.

    Args:
        sampler: Either a ready-made sampler or a sampler class.
            - A BaseSampler instance is stored as-is
            - A BaseSampler subclass is instantiated with no arguments and the
              new instance is stored
            - Defaults to the SequentialSampler class
            - Example: sampler=GRPOGroupNSampler() (instance)
            - Example: sampler=SequentialSampler (class)

    Raises:
        TypeError: If ``sampler`` is neither a BaseSampler instance nor a
            BaseSampler subclass.
    """
    # An already-constructed sampler is adopted directly.
    if isinstance(sampler, BaseSampler):
        self.sampler = sampler
        return

    # A sampler class is instantiated here; the isinstance(type) check keeps
    # issubclass from being called on a non-class object.
    if isinstance(sampler, type) and issubclass(sampler, BaseSampler):
        self.sampler = sampler()
        return

    raise TypeError(
        f"sampler {getattr(sampler, '__name__', repr(sampler))} must be an instance or subclass of BaseSampler"
    )
Loading