diff --git a/tests/test_client.py b/tests/test_client.py index 0e77776f..36ab84dd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -679,3 +679,98 @@ async def test_async_clear_samples_with_empty_metadata(client_setup): # If no exception is raised, the test passes assert True + + +@pytest.mark.asyncio +async def test_sync_methods_work_in_async_event_loop(client_setup): + """Test all synchronous methods can be called from within an asyncio event loop. + + This test verifies that the sync methods can be called directly from an async + function without causing "asyncio.run() cannot be called from a running loop" errors. + """ + client, _, _ = client_setup + + test_data = TensorDict({"tokens": torch.randint(0, 100, (3, 64))}, batch_size=3) + + # Test sync put + metadata = client.put(data=test_data, partition_id="0") + assert metadata is not None + + # Test sync get_meta - use fields that mock returns + metadata = client.get_meta( + data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0" + ) + assert metadata is not None + assert len(metadata.global_indexes) == 2 + + # Test sync get_data - verify we get the expected fields from mock + result = client.get_data(metadata) + assert result is not None + assert "log_probs" in result + assert "prompt_text" in result + + # Test sync check_consumption_status + is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0") + assert isinstance(is_consumed, bool) + + # Test sync get_consumption_status + global_index, consumption_status = client.get_consumption_status( + task_name="generate_sequences", partition_id="train_0" + ) + assert global_index is not None + assert consumption_status is not None + + # Test sync check_production_status + is_produced = client.check_production_status(data_fields=["log_probs", "prompt_text"], partition_id="train_0") + assert isinstance(is_produced, bool) + + # Test sync get_production_status + global_index, production_status = client.get_production_status( + data_fields=["log_probs", "prompt_text"], partition_id="train_0" + ) + assert global_index is not None + assert production_status is not None + + # Test sync get_partition_list + partition_list = client.get_partition_list() + assert isinstance(partition_list, list) + assert len(partition_list) > 0 + + # Test sync clear_partition + client.clear_partition(partition_id="test_partition") + + # Test sync clear_samples + metadata = client.get_meta(data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0") + client.clear_samples(metadata=metadata) + + print("✓ All sync methods work correctly when called from within asyncio event loop") + + +@pytest.mark.asyncio +async def test_sync_and_async_methods_mixed_usage(client_setup): + """Test mixing sync and async method calls within the same async context. + + This test verifies that async methods and sync methods can be used interchangeably + without conflicts when called from an async function. + """ + client, _, _ = client_setup + + test_data = TensorDict({"tokens": torch.randint(0, 100, (2, 32))}, batch_size=2) + + # Call sync method first + sync_put_result = client.put(data=test_data, partition_id="0") + assert sync_put_result is not None + + # Call async method + async_metadata = await client.async_get_meta(data_fields=["tokens"], batch_size=2, partition_id="0") + assert async_metadata is not None + + # Call sync method again + sync_get_meta_result = client.get_meta(data_fields=["tokens"], batch_size=2, partition_id="0") + assert sync_get_meta_result is not None + + # Call async method + async_data = await client.async_get_data(sync_get_meta_result) + assert async_data is not None + + print("✓ Mixed async and sync method calls work correctly") diff --git a/transfer_queue/client.py b/transfer_queue/client.py index 3bcfcc5e..dcbf2e02 100644 --- a/transfer_queue/client.py +++ b/transfer_queue/client.py @@ -16,6 +16,7 @@ import asyncio import logging import os +import threading from functools import wraps from typing import Any, Callable, Optional, Union from uuid import uuid4 @@ -808,6 +809,47 @@ def __init__( controller_info, ) + # create new event loop in a separate thread + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._start_loop, daemon=True) + self._thread.start() + + # convert and bind sync methods + self._bind_sync_methods() + + def _start_loop(self): + """Start the synchronous loop.""" + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def _bind_sync_methods( + self, + ): + """Convert and bind synchronous methods.""" + + def _run(coro): + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + def _make_sync(async_method): + def wrapper(*args, **kwargs): + return _run(async_method(*args, **kwargs)) + + return wrapper + + # Bind internal sync wrappers. Public methods are defined explicitly below + # to ensure proper type hints and documentation. + self._put = _make_sync(self.async_put) + self._get_meta = _make_sync(self.async_get_meta) + self._get_data = _make_sync(self.async_get_data) + self._clear_partition = _make_sync(self.async_clear_partition) + self._clear_samples = _make_sync(self.async_clear_samples) + self._get_consumption_status = _make_sync(self.async_get_consumption_status) + self._get_production_status = _make_sync(self.async_get_production_status) + self._check_consumption_status = _make_sync(self.async_check_consumption_status) + self._check_production_status = _make_sync(self.async_check_production_status) + self._get_partition_list = _make_sync(self.async_get_partition_list) + def put( self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None ) -> BatchMeta: @@ -822,7 +864,7 @@ def put( BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved metadata; will be updated in a future version to reflect the post-put state) """ - return asyncio.run(self.async_put(data, metadata, partition_id)) + return self._put(data=data, metadata=metadata, partition_id=partition_id) def get_meta( self, @@ -845,14 +887,12 @@ def get_meta( Returns: BatchMeta: Batch metadata containing data location information """ - return asyncio.run( - self.async_get_meta( - data_fields=data_fields, - batch_size=batch_size, - partition_id=partition_id, - task_name=task_name, - sampling_config=sampling_config, - ) + return self._get_meta( + data_fields=data_fields, + batch_size=batch_size, + partition_id=partition_id, + task_name=task_name, + sampling_config=sampling_config, ) def get_data(self, metadata: BatchMeta) -> TensorDict: @@ -864,7 +904,7 @@ def get_data(self, metadata: BatchMeta) -> TensorDict: Returns: TensorDict containing requested data fields """ - return asyncio.run(self.async_get_data(metadata)) + return self._get_data(metadata=metadata) def clear_partition(self, partition_id: str): """Synchronously clear the whole partition from storage units and controller. @@ -872,7 +912,7 @@ def clear_partition(self, partition_id: str): Args: partition_id: The partition id to clear data for """ - return asyncio.run(self.async_clear_partition(partition_id)) + return self._clear_partition(partition_id=partition_id) def clear_samples(self, metadata: BatchMeta): """Synchronously clear specific samples from storage units and controller metadata. @@ -880,7 +920,7 @@ def clear_samples(self, metadata: BatchMeta): Args: metadata: The BatchMeta of the corresponding data to be cleared """ - return asyncio.run(self.async_clear_samples(metadata)) + return self._clear_samples(metadata=metadata) def check_consumption_status(self, task_name: str, partition_id: str) -> bool: """Synchronously check if all samples for a partition have been consumed by a specific task. @@ -892,7 +932,7 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool: Returns: bool: True if all samples have been consumed by the task, False otherwise """ - return asyncio.run(self.async_check_consumption_status(task_name, partition_id)) + return self._check_consumption_status(task_name=task_name, partition_id=partition_id) def get_consumption_status( self, @@ -917,7 +957,7 @@ def get_consumption_status( ... ) >>> print(f"Global index: {global_index}, Consumption status: {consumption_status}") """ - return asyncio.run(self.async_get_consumption_status(task_name, partition_id)) + return self._get_consumption_status(task_name, partition_id) def check_production_status(self, data_fields: list[str], partition_id: str) -> bool: """Synchronously check if all samples for a partition are ready (produced) for consumption. @@ -929,7 +969,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) -> Returns: bool: True if all samples have been produced and ready, False otherwise """ - return asyncio.run(self.async_check_production_status(data_fields, partition_id)) + return self._check_production_status(data_fields=data_fields, partition_id=partition_id) def get_production_status( self, @@ -954,7 +994,7 @@ def get_production_status( ... ) >>> print(f"Global index: {global_index}, Production status: {production_status}") """ - return asyncio.run(self.async_get_production_status(data_fields, partition_id)) + return self._get_production_status(data_fields=data_fields, partition_id=partition_id) def get_partition_list( self, @@ -964,7 +1004,25 @@ def get_partition_list( Returns: list[str]: List of partition ids managed by the controller """ - return asyncio.run(self.async_get_partition_list()) + return self._get_partition_list() + + def close(self) -> None: + """Close the client and cleanup resources including event loop and thread.""" + + if hasattr(self, "_loop") and self._loop is not None: + self._loop.call_soon_threadsafe(self._loop.stop) + + if hasattr(self, "_thread") and self._thread is not None: + self._thread.join(timeout=5.0) + if self._thread.is_alive(): + logger.warning(f"[{self.client_id}]: Background thread did not stop within timeout") + + try: + self._loop.close() + except Exception as e: + logger.warning(f"[{self.client_id}]: Error closing event loop: {e}") + + super().close() def process_zmq_server_info(