Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions recipe/simple_use_case/single_controller_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@ def compute_loss(data1, _data2):

def compute_reward(response_ids: torch.Tensor) -> TensorDict:
"""Simulate a reward model that scores each token position in the response.
Returns a TensorDict with a ``"rm_score"`` field whose shape matches
``response_ids`` (i.e. one scalar per response token).
"""
time.sleep(1)
reward = torch.randn_like(response_ids, dtype=torch.float32)

return TensorDict({"rm_score": reward}, batch_size=response_ids.size(0))

Comment on lines 54 to +63
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring says the returned TensorDict's shape matches response_ids (one scalar per token), but the batch_size is set to response_ids.size(0) (batch-only). This is misleading and can cause confusion for TensorDict operations that rely on batch_size. Either update the docstring to clarify that only the first dimension is treated as batch, or set batch_size to match the intended per-token shape (e.g., include the sequence length dimension when appropriate).

Copilot uses AI. Check for mistakes.

def compute_advantage(rewards: torch.Tensor) -> TensorDict:
"""Simulate the process of computing advantage.

Returns a TensorDict with an ``"advantage"`` field whose shape matches
``response_ids`` (i.e. one scalar per response token).
``rewards`` (i.e. one scalar per reward).
"""
time.sleep(1)
advantage = torch.randn_like(response_ids, dtype=torch.float32)
return TensorDict({"advantage": advantage}, batch_size=response_ids.size(0))
advantage = torch.randn_like(rewards, dtype=torch.float32)
return TensorDict({"advantage": advantage}, batch_size=rewards.size(0))
Comment on lines +65 to +73
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as compute_reward: the docstring implies the output TensorDict shape matches rewards, but batch_size is set to rewards.size(0) (batch-only). Align the docstring with the actual TensorDict batch_size, or adjust batch_size to reflect the full intended shape.

Copilot uses AI. Check for mistakes.


class TrainingWorker:
Expand Down Expand Up @@ -89,7 +100,7 @@ def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
"""Simulate forward-only inference"""
# 1. Pull data from storage
data = tq.kv_batch_get_by_meta(meta=kv_meta)
logger.info(f"compute_log_prob: got data {data}")
logger.info(f"infer_batch: got data {data}")

# 2. Model forward
output = compute_log_prob(data["prompt_ids"], data["response_ids"])
Expand Down Expand Up @@ -494,6 +505,13 @@ def fit(self):
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=reward_output)
logger.info(f"demo reward KVBatchMeta: {meta}")

# ========================= Compute advantage =========================
meta.fields = ["response_ids", "ref_log_prob", "old_log_prob", "rm_score"]
advantage_data = tq.kv_batch_get_by_meta(meta=meta)
advantage_output = compute_advantage(advantage_data["rm_score"])
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=advantage_output)
logger.info(f"demo advantage KVBatchMeta: {meta}")

# ========================= Update actor =========================
meta.fields = [
"input_ids",
Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import asyncio
import logging
import os
import threading
from typing import Any, Callable, Optional
Expand All @@ -32,21 +31,15 @@
TransferQueueStorageManagerFactory,
)
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
with_zmq_socket,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)

TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))

Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import copy
import logging
import os
import time
from collections import defaultdict
Expand All @@ -37,6 +36,7 @@
)
from transfer_queue.sampler import BaseSampler, SequentialSampler
from transfer_queue.utils.enum_utils import TransferQueueRole
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
Expand All @@ -48,14 +48,7 @@
get_node_ip_address_raw,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler (for Ray Actor subprocess)
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)

TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1))
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5))
Expand Down
12 changes: 2 additions & 10 deletions transfer_queue/dataloader/streaming_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Optional

import torch
from tensordict import TensorDict

from transfer_queue.dataloader.streaming_dataset import StreamingDataset
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.logging_utils import get_logger

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)


def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]:
Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/dataloader/streaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
import uuid
Expand All @@ -25,19 +24,13 @@

from transfer_queue.client import TransferQueueClient
from transfer_queue.metadata import BatchMeta
from transfer_queue.utils.logging_utils import get_logger

TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
) # in seconds

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)


class StreamingDataset(IterableDataset):
Expand Down
5 changes: 2 additions & 3 deletions transfer_queue/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import os
import subprocess
Expand All @@ -35,14 +34,14 @@
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.simple_backend import SimpleStorageUnit
from transfer_queue.utils.common import get_placement_group
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.yuanrong_utils import (
cleanup_yuanrong_resources,
initialize_yuanrong_backend,
)
from transfer_queue.utils.zmq_utils import process_zmq_server_info

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
logger = get_logger(__name__)

_TRANSFER_QUEUE_CLIENT: Any = None
_TRANSFER_QUEUE_STORAGE: Any = None
Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import copy
import dataclasses
import itertools
import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from types import MappingProxyType
Expand All @@ -27,14 +25,9 @@
import torch
from tensordict import TensorDict

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
from transfer_queue.utils.logging_utils import get_logger

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)


# ---------------------------------------------------------------------------
Expand Down
6 changes: 2 additions & 4 deletions transfer_queue/storage/clients/mooncake_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pickle
from typing import Any, Optional

Expand All @@ -23,9 +21,9 @@

from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
from transfer_queue.storage.clients.factory import StorageClientFactory
from transfer_queue.utils.logging_utils import get_logger

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
logger = get_logger(__name__)

MOONCAKE_STORE_IMPORTED: bool = True
try:
Expand Down
6 changes: 2 additions & 4 deletions transfer_queue/storage/clients/yuanrong_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import struct
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -25,11 +23,11 @@

from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
from transfer_queue.storage.clients.factory import StorageClientFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.serial_utils import _decoder, _encoder
from transfer_queue.utils.yuanrong_utils import find_reachable_host

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
logger = get_logger(__name__)


YUANRONG_DATASYSTEM_IMPORTED: bool = True
Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/storage/managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import asyncio
import itertools
import logging
import os
import time
import weakref
Expand All @@ -34,16 +33,10 @@

from transfer_queue.metadata import BatchMeta, extract_field_schema
from transfer_queue.storage.clients.factory import StorageClientFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)

# ZMQ timeouts (in seconds) and retry configurations
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))
Expand Down
6 changes: 2 additions & 4 deletions transfer_queue/storage/managers/mooncake_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any

from transfer_queue.storage.managers.base import KVStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import ZMQServerInfo

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
logger = get_logger(__name__)


@TransferQueueStorageManagerFactory.register("MooncakeStore")
Expand Down
11 changes: 2 additions & 9 deletions transfer_queue/storage/managers/simple_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import asyncio
import logging
import os
import warnings
from collections import defaultdict
Expand All @@ -30,21 +29,15 @@
from transfer_queue.metadata import BatchMeta, extract_field_schema
from transfer_queue.storage.managers.base import TransferQueueStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import (
ZMQMessage,
ZMQRequestType,
ZMQServerInfo,
with_zmq_socket,
)

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)

TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT = int(os.environ.get("TQ_SIMPLE_STORAGE_SEND_RECV_TIMEOUT", 200)) # seconds

Expand Down
12 changes: 2 additions & 10 deletions transfer_queue/storage/managers/yuanrong_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any

from transfer_queue.storage.managers.base import KVStorageManager
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.zmq_utils import ZMQServerInfo

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))

# Ensure logger has a handler
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
logger.addHandler(handler)
logger = get_logger(__name__)


@TransferQueueStorageManagerFactory.register("Yuanrong")
Expand Down
Loading
Loading