Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,98 @@ async def test_async_clear_samples_with_empty_metadata(client_setup):

# If no exception is raised, the test passes
assert True


@pytest.mark.asyncio
async def test_sync_methods_work_in_async_event_loop(client_setup):
"""Test all synchronous methods can be called from within an asyncio event loop.

This test verifies that the sync methods can be called directly from an async
function without causing "asyncio.run() cannot be called from a running loop" errors.
"""
client, _, _ = client_setup

test_data = TensorDict({"tokens": torch.randint(0, 100, (3, 64))}, batch_size=3)

# Test sync put
metadata = client.put(data=test_data, partition_id="0")
assert metadata is not None

# Test sync get_meta - use fields that mock returns
metadata = client.get_meta(
data_fields=["log_probs", "variable_length_sequences", "prompt_text"], batch_size=2, partition_id="0"
)
assert metadata is not None
assert len(metadata.global_indexes) == 2

# Test sync get_data - verify we get the expected fields from mock
result = client.get_data(metadata)
assert result is not None
assert "log_probs" in result
assert "prompt_text" in result

# Test sync check_consumption_status
is_consumed = client.check_consumption_status(task_name="generate_sequences", partition_id="train_0")
assert isinstance(is_consumed, bool)

# Test sync get_consumption_status
global_index, consumption_status = client.get_consumption_status(
task_name="generate_sequences", partition_id="train_0"
)
assert global_index is not None
assert consumption_status is not None

# Test sync check_production_status
is_produced = client.check_production_status(data_fields=["log_probs", "prompt_text"], partition_id="train_0")
assert isinstance(is_produced, bool)

# Test sync get_production_status
global_index, production_status = client.get_production_status(
data_fields=["log_probs", "prompt_text"], partition_id="train_0"
)
assert global_index is not None
assert production_status is not None

# Test sync get_partition_list
partition_list = client.get_partition_list()
assert isinstance(partition_list, list)
assert len(partition_list) > 0

# Test sync clear_partition
client.clear_partition(partition_id="test_partition")

# Test sync clear_samples
metadata = client.get_meta(data_fields=["log_probs", "prompt_text"], batch_size=2, partition_id="0")
client.clear_samples(metadata=metadata)

print("✓ All sync methods work correctly when called from within asyncio event loop")


@pytest.mark.asyncio
async def test_sync_and_async_methods_mixed_usage(client_setup):
"""Test mixing sync and async method calls within the same async context.

This test verifies that async methods and sync methods can be used interchangeably
without conflicts when called from an async function.
"""
client, _, _ = client_setup

test_data = TensorDict({"tokens": torch.randint(0, 100, (2, 32))}, batch_size=2)

# Call sync method first
sync_put_result = client.put(data=test_data, partition_id="0")
assert sync_put_result is not None

# Call async method
async_metadata = await client.async_get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
assert async_metadata is not None

# Call sync method again
sync_get_meta_result = client.get_meta(data_fields=["tokens"], batch_size=2, partition_id="0")
assert sync_get_meta_result is not None

# Call async method
async_data = await client.async_get_data(sync_get_meta_result)
assert async_data is not None

print("✓ Mixed async and sync method calls work correctly")
92 changes: 75 additions & 17 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import logging
import os
import threading
from functools import wraps
from typing import Any, Callable, Optional, Union
from uuid import uuid4
Expand Down Expand Up @@ -808,6 +809,47 @@ def __init__(
controller_info,
)

# create new event loop in a separate thread
self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._start_loop, daemon=True)
self._thread.start()

# convert and bind sync methods
self._bind_sync_methods()

def _start_loop(self):
"""Start the synchronous loop."""
asyncio.set_event_loop(self._loop)
self._loop.run_forever()

def _bind_sync_methods(
self,
):
"""Convert and bind synchronous methods."""

def _run(coro):
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()

def _make_sync(async_method):
def wrapper(*args, **kwargs):
return _run(async_method(*args, **kwargs))

return wrapper

# Bind internal sync wrappers. Public methods are defined explicitly below
# to ensure proper type hints and documentation.
self._put = _make_sync(self.async_put)
self._get_meta = _make_sync(self.async_get_meta)
self._get_data = _make_sync(self.async_get_data)
self._clear_partition = _make_sync(self.async_clear_partition)
self._clear_samples = _make_sync(self.async_clear_samples)
self._get_consumption_status = _make_sync(self.async_get_consumption_status)
self._get_production_status = _make_sync(self.async_get_production_status)
self._check_consumption_status = _make_sync(self.async_check_consumption_status)
self._check_production_status = _make_sync(self.async_check_production_status)
self._get_partition_list = _make_sync(self.async_get_partition_list)

def put(
self, data: TensorDict, metadata: Optional[BatchMeta] = None, partition_id: Optional[str] = None
) -> BatchMeta:
Expand All @@ -822,7 +864,7 @@ def put(
BatchMeta: The metadata used for the put operation (currently returns the input metadata or auto-retrieved
metadata; will be updated in a future version to reflect the post-put state)
"""
return asyncio.run(self.async_put(data, metadata, partition_id))
return self._put(data=data, metadata=metadata, partition_id=partition_id)

def get_meta(
self,
Expand All @@ -845,14 +887,12 @@ def get_meta(
Returns:
BatchMeta: Batch metadata containing data location information
"""
return asyncio.run(
self.async_get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=sampling_config,
)
return self._get_meta(
data_fields=data_fields,
batch_size=batch_size,
partition_id=partition_id,
task_name=task_name,
sampling_config=sampling_config,
)

def get_data(self, metadata: BatchMeta) -> TensorDict:
Expand All @@ -864,23 +904,23 @@ def get_data(self, metadata: BatchMeta) -> TensorDict:
Returns:
TensorDict containing requested data fields
"""
return asyncio.run(self.async_get_data(metadata))
return self._get_data(metadata=metadata)

def clear_partition(self, partition_id: str):
"""Synchronously clear the whole partition from storage units and controller.

Args:
partition_id: The partition id to clear data for
"""
return asyncio.run(self.async_clear_partition(partition_id))
return self._clear_partition(partition_id=partition_id)

def clear_samples(self, metadata: BatchMeta):
"""Synchronously clear specific samples from storage units and controller metadata.

Args:
metadata: The BatchMeta of the corresponding data to be cleared
"""
return asyncio.run(self.async_clear_samples(metadata))
return self._clear_samples(metadata=metadata)

def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
"""Synchronously check if all samples for a partition have been consumed by a specific task.
Expand All @@ -892,7 +932,7 @@ def check_consumption_status(self, task_name: str, partition_id: str) -> bool:
Returns:
bool: True if all samples have been consumed by the task, False otherwise
"""
return asyncio.run(self.async_check_consumption_status(task_name, partition_id))
return self._check_consumption_status(task_name=task_name, partition_id=partition_id)

def get_consumption_status(
self,
Expand All @@ -917,7 +957,7 @@ def get_consumption_status(
... )
>>> print(f"Global index: {global_index}, Consumption status: {consumption_status}")
"""
return asyncio.run(self.async_get_consumption_status(task_name, partition_id))
return self._get_consumption_status(task_name, partition_id)

def check_production_status(self, data_fields: list[str], partition_id: str) -> bool:
"""Synchronously check if all samples for a partition are ready (produced) for consumption.
Expand All @@ -929,7 +969,7 @@ def check_production_status(self, data_fields: list[str], partition_id: str) ->
Returns:
bool: True if all samples have been produced and ready, False otherwise
"""
return asyncio.run(self.async_check_production_status(data_fields, partition_id))
return self._check_production_status(data_fields=data_fields, partition_id=partition_id)

def get_production_status(
self,
Expand All @@ -954,7 +994,7 @@ def get_production_status(
... )
>>> print(f"Global index: {global_index}, Production status: {production_status}")
"""
return asyncio.run(self.async_get_production_status(data_fields, partition_id))
return self._get_production_status(data_fields=data_fields, partition_id=partition_id)

def get_partition_list(
self,
Expand All @@ -964,7 +1004,25 @@ def get_partition_list(
Returns:
list[str]: List of partition ids managed by the controller
"""
return asyncio.run(self.async_get_partition_list())
return self._get_partition_list()

def close(self) -> None:
"""Close the client and cleanup resources including event loop and thread."""

if hasattr(self, "_loop") and self._loop is not None:
self._loop.call_soon_threadsafe(self._loop.stop)

if hasattr(self, "_thread") and self._thread is not None:
self._thread.join(timeout=5.0)
if self._thread.is_alive():
logger.warning(f"[{self.client_id}]: Background thread did not stop within timeout")

try:
self._loop.close()
except Exception as e:
logger.warning(f"[{self.client_id}]: Error closing event loop: {e}")

super().close()


def process_zmq_server_info(
Expand Down