From 940186ea9aa5703694b100b486c71f35c7723c61 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 26 Jan 2026 14:21:01 +0800 Subject: [PATCH 1/5] fix sync client in async functions Signed-off-by: 0oshowero0 --- requirements.txt | 3 ++- transfer_queue/client.py | 52 ++++++++++++++++++++++++++-------------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1da090c8..e4fd9ef9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ pyzmq hydra-core numpy<2.0.0 msgspec -psutil \ No newline at end of file +psutil +asgiref \ No newline at end of file diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 3bcfcc5e..8901f611 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import logging import os from functools import wraps @@ -24,6 +23,7 @@ import torch import zmq import zmq.asyncio +from asgiref.sync import async_to_sync from tensordict import TensorDict from torch import Tensor @@ -808,6 +808,24 @@ def __init__( controller_info, ) + self._bind_sync_methods() + + def _bind_sync_methods( + self, + ): + """Convert and bind synchronous methods.""" + + self._put = async_to_sync(self.async_put) + self._get_meta = async_to_sync(self.async_get_meta) + self._get_data = async_to_sync(self.async_get_data) + self._clear_partition = async_to_sync(self.async_clear_partition) + self._clear_samples = async_to_sync(self.async_clear_samples) + self._get_consumption_status = async_to_sync(self.async_get_consumption_status) + self._get_production_status = async_to_sync(self.async_get_production_status) + self._check_consumption_status = async_to_sync(self.async_check_consumption_status) + self._check_production_status = async_to_sync(self.async_check_production_status) + self._get_partition_list = async_to_sync(self.async_get_partition_list) + def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None ) -> BatchMeta: @@ -822,7 +840,7 @@ def put( BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved metadata; will be updated in a future version to reflect the post-put state) """ - return asyncio.run(self.async_put(data, metadata, partition_id)) + return self._put(data=data, metadata=metadata, partition_id=partition_id) def get_meta( self, @@ -845,14 +863,12 @@ def get_meta( Returns: BatchMeta: Batch metadata containing data location information """ - return asyncio.run( - self.async_get_meta( - data_fields=data_fields, - batch_size=batch_size, - partition_id=partition_id, - task_name=task_name, - sampling_config=sampling_config, - ) + return self._get_meta( + data_fields=data_fields, + batch_size=batch_size, + partition_id=partition_id, + task_name=task_name, + sampling_config=sampling_config, ) def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -864,7 +880,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict: Returns: TensorDict containing requested data fields """ - return asyncio.run(self.async_get_data(metadata)) + return self._get_data(metadata=metadata) def clear_partition(self, partition_id: str): """Synchronously clear the whole partition from storage units and controller. @@ -872,7 +888,7 @@ def clear_partition(self, partition_id: str): Args: partition_id: The partition id to clear data for """ - return asyncio.run(self.async_clear_partition(partition_id)) + return self._clear_partition(partition_id=partition_id) def clear_samples(self, metadata: BatchMeta): """Synchronously clear specific samples from storage units and controller metadata. @@ -880,7 +896,7 @@ def clear_samples(self, metadata: BatchMeta): Args: metadata: The BatchMeta of the corresponding data to be cleared """ - return asyncio.run(self.async_clear_samples(metadata)) + return self._clear_samples(metadata=metadata) def check_consumption_status(self, task_name: str, partition_id: str) -> bool: """Synchronously check if all samples for a partition have been consumed by a specific task. @@ -892,7 +908,7 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool: Returns: bool: True if all samples have been consumed by the task, False otherwise """ - return asyncio.run(self.async_check_consumption_status(task_name, partition_id)) + return self._check_consumption_status(task_name=task_name, partition_id=partition_id) def get_consumption_status( self, @@ -917,7 +933,7 @@ def get_consumption_status( ... ) >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ - return asyncio.run(self.async_get_consumption_status(task_name, partition_id)) + return self._get_consumption_status(task_name, partition_id) def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. @@ -929,7 +945,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> Returns: bool: True if all samples have been produced and ready, False otherwise """ - return asyncio.run(self.async_check_production_status(data_fields, partition_id)) + return self._check_production_status(data_fields=data_fields, partition_id=partition_id) def get_production_status( self, @@ -954,7 +970,7 @@ def get_production_status( ... ) >>> print(f"Global index: {global_index}, Production status: {production_status}") """ - return asyncio.run(self.async_get_production_status(data_fields, partition_id)) + return self._get_production_status(data_fields=data_fields, partition_id=partition_id) def get_partition_list( self, @@ -964,7 +980,7 @@ def get_partition_list( Returns: list[str]: List of partition ids managed by the controller """ - return asyncio.run(self.async_get_partition_list()) + return self._get_partition_list() def process_zmq_server_info( From 08f8385a78f3a662f3935c2feb35f86259e5aa84 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 26 Jan 2026 15:14:31 +0800 Subject: [PATCH 2/5] add CI Signed-off-by: 0oshowero0 --- tests/test_client.py | 95 ++++++++++++++++++++++++++++++++++++++++ transfer_queue/client.py | 49 ++++++++++++++++----- 2 files changed, 133 insertions(+), 11 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 0e77776f..36ab84dd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -679,3 +679,98 @@ async def test_async_clear_samples_with_empty_metadata(client_setup): # If no exception is raised, the test passes assert True + + +@pytest.mark.asyncio +async def test_sync_methods_work_in_async_event_loop(client_setup): + """Test all synchronous methods can be called from within an asyncio event loop. + + This test verifies that the sync methods can be called directly from an async + function without causing "asyncio.run() cannot be called from a running loop" errors. + """ + client, _, _ = client_setup + + test_data = TensorDict({"tokens": torch.randint(0, 100, (3, 64))}, batch_size=3) + + # Test sync put + metadata = client.put(data=test_data, partition_id="0") + assert metadata is not None + + # Test sync get_meta - use fields that mock returns + metadata = client.get_meta( + data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0" + ) + assert metadata is not None + assert len(metadata.global_indexes) == 2 + + # Test sync get_data - verify we get the expected fields from mock + result = client.get_data(metadata) + assert result is not None + assert "log_probs" in result + assert "prompt_text" in result + + # Test sync check_consumption_status + is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0") + assert isinstance(is_consumed, bool) + + # Test sync get_consumption_status + global_index, consumption_status = client.get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + assert global_index is not None + assert consumption_status is not None + + # Test sync check_production_status + is_produced = client.check_production_status(data_fields=["log_probs", "prompt_text"], partition_id="train_0") + assert isinstance(is_produced, bool) + + # Test sync get_production_status + global_index, production_status = client.get_production_status( + data_fields=["log_probs", "prompt_text"], partition_id="train_0" + ) + assert global_index is not None + assert production_status is not None + + # Test sync get_partition_list + partition_list = client.get_partition_list() + assert isinstance(partition_list, list) + assert len(partition_list) > 0 + + # Test sync clear_partition + client.clear_partition(partition_id="test_partition") + + # Test sync clear_samples + metadata = client.get_meta(data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0") + client.clear_samples(metadata=metadata) + + print("✓ All sync methods work correctly when called from within asyncio event loop") + + +@pytest.mark.asyncio +async def test_sync_and_async_methods_mixed_usage(client_setup): + """Test mixing sync and async method calls within the same async context. + + This test verifies that async methods and sync methods can be used interchangeably + without conflicts when called from an async function. + """ + client, _, _ = client_setup + + test_data = TensorDict({"tokens": torch.randint(0, 100, (2, 32))}, batch_size=2) + + # Call sync method first + sync_put_result = client.put(data=test_data, partition_id="0") + assert sync_put_result is not None + + # Call async method + async_metadata = await client.async_get_meta(data_fields=["tokens"], batch_size=2, partition_id="0") + assert async_metadata is not None + + # Call sync method again + sync_get_meta_result = client.get_meta(data_fields=["tokens"], batch_size=2, partition_id="0") + assert sync_get_meta_result is not None + + # Call async method + async_data = await client.async_get_data(sync_get_meta_result) + assert async_data is not None + + print("✓ Mixed async and sync method calls work correctly") diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 8901f611..af2daac0 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import logging import os +import threading from functools import wraps from typing import Any, Callable, Optional, Union from uuid import uuid4 @@ -23,7 +25,6 @@ import torch import zmq import zmq.asyncio -from asgiref.sync import async_to_sync from tensordict import TensorDict from torch import Tensor @@ -808,23 +809,44 @@ def __init__( controller_info, ) + # create new event loop in a separated thread + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + + # convert and bind sync methods self._bind_sync_methods() + def _start_loop(self): + """Start the synchronous loop.""" + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + def _bind_sync_methods( self, ): """Convert and bind synchronous methods.""" - self._put = async_to_sync(self.async_put) - self._get_meta = async_to_sync(self.async_get_meta) - self._get_data = async_to_sync(self.async_get_data) - self._clear_partition = async_to_sync(self.async_clear_partition) - self._clear_samples = async_to_sync(self.async_clear_samples) - self._get_consumption_status = async_to_sync(self.async_get_consumption_status) - self._get_production_status = async_to_sync(self.async_get_production_status) - self._check_consumption_status = async_to_sync(self.async_check_consumption_status) - self._check_production_status = async_to_sync(self.async_check_production_status) - self._get_partition_list = async_to_sync(self.async_get_partition_list) + def _run(coro): + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + def _make_sync(async_method): + def wrapper(*args, **kwargs): + return _run(async_method(*args, **kwargs)) + + return wrapper + + self._put = _make_sync(self.async_put) + self._get_meta = _make_sync(self.async_get_meta) + self._get_data = _make_sync(self.async_get_data) + self._clear_partition = _make_sync(self.async_clear_partition) + self._clear_samples = _make_sync(self.async_clear_samples) + self._get_consumption_status = _make_sync(self.async_get_consumption_status) + self._get_production_status = _make_sync(self.async_get_production_status) + self._check_consumption_status = _make_sync(self.async_check_consumption_status) + self._check_production_status = _make_sync(self.async_check_production_status) + self._get_partition_list = _make_sync(self.async_get_partition_list) def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None @@ -982,6 +1004,11 @@ def get_partition_list( """ return self._get_partition_list() + def close(self): + """Close the client and cleanup resources including the executor.""" + if hasattr(self, "_executor"): + self._executor.shutdown(wait=True) + def process_zmq_server_info( handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] From ae0dcb247c8fac4c274588a44a8e82a2071b9158 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 26 Jan 2026 15:15:16 +0800 Subject: [PATCH 3/5] remove unused package Signed-off-by: 0oshowero0 --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index e4fd9ef9..1da090c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ pyzmq hydra-core numpy<2.0.0 msgspec -psutil -asgiref \ No newline at end of file +psutil \ No newline at end of file From f92e9df6ca46951e6391fa72fed0a9ed78b00629 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 26 Jan 2026 15:23:54 +0800 Subject: [PATCH 4/5] update Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index af2daac0..f2122eb8 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -837,6 +837,8 @@ def wrapper(*args, **kwargs): return wrapper + # Bind internal sync wrappers. Public methods are defined explicitly below + # to ensure proper type hints and documentation. self._put = _make_sync(self.async_put) self._get_meta = _make_sync(self.async_get_meta) self._get_data = _make_sync(self.async_get_data) @@ -1004,11 +1006,6 @@ def get_partition_list( """ return self._get_partition_list() - def close(self): - """Close the client and cleanup resources including the executor.""" - if hasattr(self, "_executor"): - self._executor.shutdown(wait=True) - def process_zmq_server_info( handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]] From 3d45997d8f5a941ff87d4166f6938df7faeb7453 Mon Sep 17 00:00:00 2001 From: 0oshowero0 Date: Mon, 26 Jan 2026 15:50:56 +0800 Subject: [PATCH 5/5] add close() method for sync client Signed-off-by: 0oshowero0 --- transfer_queue/client.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/transfer_queue/client.py b/transfer_queue/client.py index f2122eb8..dcbf2e02 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -809,7 +809,7 @@ def __init__( controller_info, ) - # create new event loop in a separated thread + # create new event loop in a separate thread self._loop = asyncio.new_event_loop() self._thread = threading.Thread(target=self._start_loop, daemon=True) self._thread.start() @@ -1006,6 +1006,24 @@ def get_partition_list( """ return self._get_partition_list() + def close(self) -> None: + """Close the client and cleanup resources including event loop and thread.""" + + if hasattr(self, "_loop") and self._loop is not None: + self._loop.call_soon_threadsafe(self._loop.stop) + + if hasattr(self, "_thread") and self._thread is not None: + self._thread.join(timeout=5.0) + if self._thread.is_alive(): + logger.warning(f"[{self.client_id}]: Background thread did not stop within timeout") + + try: + self._loop.close() + except Exception as e: + logger.warning(f"[{self.client_id}]: Error closing event loop: {e}") + + super().close() + def process_zmq_server_info( handlers: dict[Any, Union["TransferQueueController", "TransferQueueStorageManager", "SimpleStorageUnit"]]