From 50c0bc3f459bb87b2b1909db04bcb87cd6160c9f Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 7 May 2026 20:19:20 +0800 Subject: [PATCH 01/22] [recipe] Add recipe demo to use StreamingDataset & StreamingDataLoader Signed-off-by: jxixi <916099156@qq.com> --- .../streaming_dataLoader_demo.py | 428 ++++++++++++++++++ 1 file changed, 428 insertions(+) create mode 100644 recipe/simple_use_case/streaming_dataLoader_demo.py diff --git a/recipe/simple_use_case/streaming_dataLoader_demo.py b/recipe/simple_use_case/streaming_dataLoader_demo.py new file mode 100644 index 0000000..19cf255 --- /dev/null +++ b/recipe/simple_use_case/streaming_dataLoader_demo.py @@ -0,0 +1,428 @@ +import argparse +import logging +import os +import time +from dataclasses import dataclass +from importlib import resources + +import ray +import torch +from omegaconf import OmegaConf +from tensordict import TensorDict + +import transfer_queue as tq +from transfer_queue import RankAwareSampler, StreamingDataLoader, StreamingDataset +from transfer_queue.metadata import BatchMeta + +logging.basicConfig(level=logging.INFO, format="%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s", datefmt="%H:%M:%S") +logger = logging.getLogger(__name__) + +STAGE_NAMES = ["rollout", "ref", "actor", "reward", "update"] + + +def make_prompt_batch(step: int, config: "DemoConfig") -> TensorDict: + start_id = step * config.global_batch_size + generator = torch.Generator().manual_seed(config.seed + step) + sample_ids = torch.arange(start_id, start_id + config.global_batch_size, dtype=torch.long) + prompt_ids = torch.randint( + 0, + config.vocab_size, + (config.global_batch_size, config.prompt_length), + generator=generator, + dtype=torch.long, + ) + return TensorDict({"sample_id": sample_ids.unsqueeze(-1), "prompt_ids": prompt_ids}, batch_size=config.global_batch_size) + + +def generate_sequences(prompt_ids: torch.Tensor, config: "DemoConfig") -> TensorDict: + batch_size = prompt_ids.size(0) + generator = torch.Generator().manual_seed(config.seed + int(prompt_ids.sum().item())) + response_ids = torch.randint( + 0, + config.vocab_size, + (batch_size, config.response_length), + generator=generator, + dtype=torch.long, + ) + return TensorDict( + { + "input_ids": torch.cat([prompt_ids, response_ids], dim=1), + "response_ids": response_ids, + "response_mask": torch.ones_like(response_ids), + }, + batch_size=batch_size, + ) + + +def compute_log_prob(prompt_ids: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor: + return (response_ids.float().mean(dim=1, keepdim=True) - prompt_ids.float().mean(dim=1, keepdim=True)) / 1000.0 + + +def compute_reward(response_ids: torch.Tensor) -> torch.Tensor: + return response_ids.float().mean(dim=1, keepdim=True) / 1000.0 + + +def compute_loss(old_log_prob: torch.Tensor, ref_log_prob: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor: + return (old_log_prob - ref_log_prob - advantage).abs() + + +@ray.remote +class ProgressTracker: + def __init__(self, stage_names: list[str], num_steps: int): + self.counts = {stage: {step: 0 for step in range(num_steps)} for stage in stage_names} + self.done_workers = {stage: {step: 0 for step in range(num_steps)} for stage in stage_names} + + def record(self, stage: str, step: int, batch_size: int) -> int: + self.counts[stage][step] = self.counts.get(stage, {}).get(step, 0) + batch_size + return self.counts[stage][step] + + def record_done(self, stage: str, step: int) -> int: + self.done_workers[stage][step] = self.done_workers.get(stage, {}).get(step, 0) + 1 + return self.done_workers[stage][step] + + def get_counts(self, step: int) -> dict: + return {stage: self.counts[stage].get(step, 0) for stage in self.counts} + + def get_done_workers(self, step: int) -> dict: + return {stage: self.done_workers[stage].get(step, 0) for stage in self.done_workers} + + +class BaseStageWorker: + stage_name = "base" + + def __init__(self, tq_config, tracker, worker_id: int, config: "DemoConfig"): + tq.init(tq_config) + self.tq_client = tq.get_client() + controller = ray.get_actor("TransferQueueController") + self.cfg = ray.get(controller.get_config.remote()) + self.tracker = tracker + self.worker_id = worker_id + self.cfg_demo = config + self.worker_name = f"{self.stage_name}-{worker_id}" + + def start(self, iteration: int, train_iters: int) -> dict: + logger.info(f"[{self.worker_name}] start (iteration={iteration})") + for step in range(iteration, train_iters): + self._run_step(step) + logger.info(f"[{self.worker_name}] done") + return {"worker": self.worker_name, "stage": self.stage_name} + + def _run_step(self, step: int) -> None: + partition_id = f"{self.cfg_demo.partition_prefix}_{step}" + dataloader = self._build_dataloader(partition_id) + + for batch, batch_meta in dataloader: + sample_ids = batch["sample_id"].view(-1).tolist() + logger.info(f"[{self.worker_name}] step={step} consumed sample_ids={sample_ids}") + + output, written_fields = self.compute(batch, batch_meta) + self.tq_client.put(output, metadata=batch_meta) + + count = ray.get(self.tracker.record.remote(self.stage_name, step, len(sample_ids))) + logger.info( + f"[{self.worker_name}] step={step} done -> written_fields={written_fields}, " + f"{self.stage_name}_count={count}/{self.cfg_demo.global_batch_size}" + ) + + ray.get(self.tracker.record_done.remote(self.stage_name, step)) + logger.info(f"[{self.worker_name}] step={step} worker_done recorded") + + def _build_dataloader(self, partition_id: str) -> StreamingDataLoader: + dataset = StreamingDataset( + config=self.cfg, + batch_size=self.cfg_demo.micro_batch_size, + micro_batch_size=self.cfg_demo.micro_batch_size, + data_fields=self.input_fields(), + partition_id=partition_id, + task_name=f"{self.cfg_demo.task_name_prefix}_{self.stage_name}", + dp_rank=self.worker_id, + should_check_consumption_status=True, + ) + return StreamingDataLoader(dataset=dataset, num_workers=0, prefetch_factor=None) + + def input_fields(self) -> list[str]: + raise NotImplementedError + + def base_sleep_seconds(self) -> float: + return self.cfg_demo.stage_sleep_seconds + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + raise NotImplementedError + + def sleep_with_jitter(self, batch_meta: BatchMeta) -> None: + jitter_seed = self.worker_id * 7 + int(batch_meta.global_indexes[0]) + jitter = 0.05 * (jitter_seed % 5) + time.sleep(max(0.0, self.base_sleep_seconds() + jitter)) + + +@ray.remote(num_cpus=0.1) +class RolloutWorker(BaseStageWorker): + stage_name = "rollout" + + def input_fields(self) -> list[str]: + return ["sample_id", "prompt_ids"] + + def base_sleep_seconds(self) -> float: + return self.cfg_demo.rollout_sleep_seconds + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + self.sleep_with_jitter(batch_meta) + output = generate_sequences(batch["prompt_ids"], self.cfg_demo) + return output, ["input_ids", "response_ids", "response_mask"] + + +@ray.remote(num_cpus=0.1) +class RefWorker(BaseStageWorker): + stage_name = "ref" + + def input_fields(self) -> list[str]: + return ["sample_id", "prompt_ids", "response_ids"] + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + self.sleep_with_jitter(batch_meta) + log_prob = compute_log_prob(batch["prompt_ids"], batch["response_ids"]) + return TensorDict({"ref_log_prob": log_prob}, batch_size=log_prob.size(0)), ["ref_log_prob"] + + +@ray.remote(num_cpus=0.1) +class ActorWorker(BaseStageWorker): + stage_name = "actor" + + def input_fields(self) -> list[str]: + return ["sample_id", "prompt_ids", "response_ids"] + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + self.sleep_with_jitter(batch_meta) + log_prob = compute_log_prob(batch["prompt_ids"], batch["response_ids"]) + return TensorDict({"old_log_prob": log_prob}, batch_size=log_prob.size(0)), ["old_log_prob"] + + +@ray.remote(num_cpus=0.1) +class RewardWorker(BaseStageWorker): + stage_name = "reward" + + def input_fields(self) -> list[str]: + return ["sample_id", "response_ids"] + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + self.sleep_with_jitter(batch_meta) + advantage = compute_reward(batch["response_ids"]) + return TensorDict({"advantage": advantage}, batch_size=advantage.size(0)), ["advantage"] + + +@ray.remote(num_cpus=0.1) +class UpdateWorker(BaseStageWorker): + stage_name = "update" + + def input_fields(self) -> list[str]: + return ["sample_id", "old_log_prob", "ref_log_prob", "advantage"] + + def compute(self, batch: TensorDict, batch_meta: BatchMeta): + self.sleep_with_jitter(batch_meta) + loss = compute_loss(batch["old_log_prob"], batch["ref_log_prob"], batch["advantage"]) + return TensorDict({"loss": loss}, batch_size=loss.size(0)), ["loss"] + + +@ray.remote(num_cpus=0.1) +def sync_weights(step: int, sleep_s: float) -> dict: + logger.info(f"[weight-sync] step={step} start") + time.sleep(sleep_s) + logger.info(f"[weight-sync] step={step} done") + return {"step": step} + + +@dataclass(frozen=True) +class DemoConfig: + partition_prefix: str + task_name_prefix: str + num_steps: int + pipeline_depth: int + global_batch_size: int + micro_batch_size: int + prompt_length: int + response_length: int + vocab_size: int + num_rollout_workers: int + num_ref_workers: int + num_actor_workers: int + num_reward_workers: int + num_update_workers: int + rollout_sleep_seconds: float + stage_sleep_seconds: float + weight_sync_seconds: float + empty_poll_log_interval: int + num_data_storage_units: int + seed: int + + def validate(self) -> None: + for name, value in [ + ("num_steps", self.num_steps), + ("global_batch_size", self.global_batch_size), + ("micro_batch_size", self.micro_batch_size), + ("prompt_length", self.prompt_length), + ("response_length", self.response_length), + ("vocab_size", self.vocab_size), + ("num_rollout_workers", self.num_rollout_workers), + ("num_ref_workers", self.num_ref_workers), + ("num_actor_workers", self.num_actor_workers), + ("num_reward_workers", self.num_reward_workers), + ("num_update_workers", self.num_update_workers), + ]: + if value <= 0: + raise ValueError(f"{name} must be > 0, got {value}") + if self.global_batch_size % self.micro_batch_size != 0: + raise ValueError("global_batch_size % micro_batch_size != 0") + + +def build_tq_config(config: DemoConfig): + base = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") + override = OmegaConf.create( + { + "controller": {"sampler": RankAwareSampler, "polling_mode": True}, + "backend": { + "storage_backend": "SimpleStorage", + "SimpleStorage": {"num_data_storage_units": config.num_data_storage_units}, + }, + }, + flags={"allow_objects": True}, + ) + return OmegaConf.merge(base, override) + + +class DecentralizedInheritedWorkerPipelineDemo: + def __init__(self, config: DemoConfig, tq_config): + self.config = config + tq.init(tq_config) + self.tq_client = tq.get_client() + self.tracker = ProgressTracker.remote(STAGE_NAMES, config.num_steps) + + self.rollout_workers = [RolloutWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_rollout_workers)] + self.ref_workers = [RefWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_ref_workers)] + self.actor_workers = [ActorWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_actor_workers)] + self.reward_workers = [RewardWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_reward_workers)] + self.update_workers = [UpdateWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_update_workers)] + + def _put_prompt(self, step: int) -> None: + partition_id = f"{self.config.partition_prefix}_{step}" + batch = make_prompt_batch(step, self.config) + sample_ids = batch["sample_id"].view(-1).tolist() + meta = self.tq_client.put(batch, partition_id=partition_id) + logger.info(f"[driver] step={step} prompt put -> partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}") + + def _wait_complete(self, step: int) -> None: + while True: + counts = ray.get(self.tracker.get_counts.remote(step)) + done_workers = ray.get(self.tracker.get_done_workers.remote(step)) + + active_counts = {stage: count for stage, count in counts.items() if count > 0} + logger.info(f"[driver] step={step} status -> counts={active_counts}, done_workers={done_workers}") + + all_workers_done = ( + done_workers.get("rollout", 0) >= self.config.num_rollout_workers + and done_workers.get("ref", 0) >= self.config.num_ref_workers + and done_workers.get("actor", 0) >= self.config.num_actor_workers + and done_workers.get("reward", 0) >= self.config.num_reward_workers + and done_workers.get("update", 0) >= self.config.num_update_workers + ) + if all_workers_done: + return + time.sleep(0.2) + + def _start_worker_group(self, workers: list) -> list: + return [worker.start.remote(0, self.config.num_steps) for worker in workers] + + def fit(self) -> list[dict]: + logger.info("=" * 72) + logger.info("TransferQueue StreamingDataLoader Decentralized Inherited Worker Pipeline Demo") + logger.info("=" * 72) + logger.info( + f"workers: rollout={self.config.num_rollout_workers}, " + f"ref={self.config.num_ref_workers}, actor={self.config.num_actor_workers}, " + f"reward={self.config.num_reward_workers}, update={self.config.num_update_workers}" + ) + logger.info( + f"pipeline: num_steps={self.config.num_steps}, " + f"global_batch_size={self.config.global_batch_size}, " + f"micro_batch_size={self.config.micro_batch_size}" + ) + + refs = [] + refs.extend(self._start_worker_group(self.rollout_workers)) + refs.extend(self._start_worker_group(self.ref_workers)) + refs.extend(self._start_worker_group(self.actor_workers)) + refs.extend(self._start_worker_group(self.reward_workers)) + refs.extend(self._start_worker_group(self.update_workers)) + + for step in range(self.config.num_steps): + self._put_prompt(step) + self._wait_complete(step) + ray.get(sync_weights.remote(step, self.config.weight_sync_seconds)) + self.tq_client.clear_partition(f"{self.config.partition_prefix}_{step}") + logger.info(f"[driver] step={step} cleared partition={self.config.partition_prefix}_{step}") + + ray.get(refs) + logger.info("demo done!") + return [] + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--partition-prefix", type=str, default="autonomous_demo") + parser.add_argument("--task-name-prefix", type=str, default="autonomous") + parser.add_argument("--num-steps", type=int, default=3) + parser.add_argument("--pipeline-depth", type=int, default=2) + parser.add_argument("--global-batch-size", type=int, default=8) + parser.add_argument("--micro-batch-size", type=int, default=2) + parser.add_argument("--prompt-length", type=int, default=24) + parser.add_argument("--response-length", type=int, default=32) + parser.add_argument("--vocab-size", type=int, default=32000) + parser.add_argument("--num-rollout-workers", type=int, default=2) + parser.add_argument("--num-ref-workers", type=int, default=2) + parser.add_argument("--num-actor-workers", type=int, default=2) + parser.add_argument("--num-reward-workers", type=int, default=2) + parser.add_argument("--num-update-workers", type=int, default=1) + parser.add_argument("--rollout-sleep-seconds", type=float, default=0.30) + parser.add_argument("--stage-sleep-seconds", type=float, default=0.15) + parser.add_argument("--weight-sync-seconds", type=float, default=0.20) + parser.add_argument("--empty-poll-log-interval", type=int, default=20) + parser.add_argument("--num-data-storage-units", type=int, default=2) + parser.add_argument("--seed", type=int, default=20260410) + args = parser.parse_args() + + cfg = DemoConfig( + partition_prefix=args.partition_prefix, + task_name_prefix=args.task_name_prefix, + num_steps=args.num_steps, + pipeline_depth=args.pipeline_depth, + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + prompt_length=args.prompt_length, + response_length=args.response_length, + vocab_size=args.vocab_size, + num_rollout_workers=args.num_rollout_workers, + num_ref_workers=args.num_ref_workers, + num_actor_workers=args.num_actor_workers, + num_reward_workers=args.num_reward_workers, + num_update_workers=args.num_update_workers, + rollout_sleep_seconds=args.rollout_sleep_seconds, + stage_sleep_seconds=args.stage_sleep_seconds, + weight_sync_seconds=args.weight_sync_seconds, + empty_poll_log_interval=args.empty_poll_log_interval, + num_data_storage_units=args.num_data_storage_units, + seed=args.seed, + ) + cfg.validate() + + os.environ["TQ_PRE_ALLOC_SAMPLE_NUM"] = str(cfg.global_batch_size) + + ray.init() + try: + demo = DecentralizedInheritedWorkerPipelineDemo(cfg, build_tq_config(cfg)) + demo.fit() + finally: + tq.close() + ray.shutdown() + + +if __name__ == "__main__": + main() From 14642aae98b51204840b6866a4162d24c6b80cd9 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 11:22:22 +0800 Subject: [PATCH 02/22] [chore] correct filename for streamingDataloader demo Signed-off-by: jxixi <916099156@qq.com> --- ...{streaming_dataLoader_demo.py => streaming_dataloader_demo.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename recipe/simple_use_case/{streaming_dataLoader_demo.py => streaming_dataloader_demo.py} (100%) diff --git a/recipe/simple_use_case/streaming_dataLoader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py similarity index 100% rename from recipe/simple_use_case/streaming_dataLoader_demo.py rename to recipe/simple_use_case/streaming_dataloader_demo.py From 7deeabff76af3586f14e740a4b9b66d0c299bcf3 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 15:07:57 +0800 Subject: [PATCH 03/22] [chore] Add license header for recipe\simple_use_case\streaming_dataloader_demo.py Signed-off-by: jxixi <916099156@qq.com> --- .../simple_use_case/streaming_dataloader_demo.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 19cf255..9020974 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -1,3 +1,18 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import logging import os From 7ce161246e3aa4443b46635df58996051fd92b5f Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 15:13:54 +0800 Subject: [PATCH 04/22] [refactor] remove unused parameter pipeline_depth Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 9020974..1552531 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -251,7 +251,6 @@ class DemoConfig: partition_prefix: str task_name_prefix: str num_steps: int - pipeline_depth: int global_batch_size: int micro_batch_size: int prompt_length: int @@ -385,7 +384,6 @@ def main() -> None: parser.add_argument("--partition-prefix", type=str, default="autonomous_demo") parser.add_argument("--task-name-prefix", type=str, default="autonomous") parser.add_argument("--num-steps", type=int, default=3) - parser.add_argument("--pipeline-depth", type=int, default=2) parser.add_argument("--global-batch-size", type=int, default=8) parser.add_argument("--micro-batch-size", type=int, default=2) parser.add_argument("--prompt-length", type=int, default=24) @@ -408,7 +406,6 @@ def main() -> None: partition_prefix=args.partition_prefix, task_name_prefix=args.task_name_prefix, num_steps=args.num_steps, - pipeline_depth=args.pipeline_depth, global_batch_size=args.global_batch_size, micro_batch_size=args.micro_batch_size, prompt_length=args.prompt_length, From 766b4a9a5c0b28bb8b652421f4d2c9e0d26def82 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 15:34:40 +0800 Subject: [PATCH 05/22] [refactor] remove unused parameter empty_poll_log_interval Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 1552531..5dc8196 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -264,7 +264,6 @@ class DemoConfig: rollout_sleep_seconds: float stage_sleep_seconds: float weight_sync_seconds: float - empty_poll_log_interval: int num_data_storage_units: int seed: int @@ -397,7 +396,6 @@ def main() -> None: parser.add_argument("--rollout-sleep-seconds", type=float, default=0.30) parser.add_argument("--stage-sleep-seconds", type=float, default=0.15) parser.add_argument("--weight-sync-seconds", type=float, default=0.20) - parser.add_argument("--empty-poll-log-interval", type=int, default=20) parser.add_argument("--num-data-storage-units", type=int, default=2) parser.add_argument("--seed", type=int, default=20260410) args = parser.parse_args() @@ -419,7 +417,6 @@ def main() -> None: rollout_sleep_seconds=args.rollout_sleep_seconds, stage_sleep_seconds=args.stage_sleep_seconds, weight_sync_seconds=args.weight_sync_seconds, - empty_poll_log_interval=args.empty_poll_log_interval, num_data_storage_units=args.num_data_storage_units, seed=args.seed, ) From 2d315e267115d280c10b9992cbf2232a5e5dcff5 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 18:35:42 +0800 Subject: [PATCH 06/22] [refactor] Replace logging with transfer_queue.utils.logging_utils.get_logger Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 5dc8196..8838805 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -14,7 +14,6 @@ # limitations under the License. import argparse -import logging import os import time from dataclasses import dataclass @@ -28,9 +27,9 @@ import transfer_queue as tq from transfer_queue import RankAwareSampler, StreamingDataLoader, StreamingDataset from transfer_queue.metadata import BatchMeta +from transfer_queue.utils.logging_utils import get_logger -logging.basicConfig(level=logging.INFO, format="%(asctime)s.%(msecs)03d - %(levelname)s - %(message)s", datefmt="%H:%M:%S") -logger = logging.getLogger(__name__) +logger = get_logger(__name__, default_level="INFO") STAGE_NAMES = ["rollout", "ref", "actor", "reward", "update"] From 1818d6e6168d33f541516b413843053560298c9c Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 18:42:23 +0800 Subject: [PATCH 07/22] [chore] Update default values for --partition-prefix and --task-name-prefix Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 8838805..1eb9fe2 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -379,8 +379,8 @@ def fit(self) -> list[dict]: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--partition-prefix", type=str, default="autonomous_demo") - parser.add_argument("--task-name-prefix", type=str, default="autonomous") + parser.add_argument("--partition-prefix", type=str, default="decentralized_demo") + parser.add_argument("--task-name-prefix", type=str, default="decentralized") parser.add_argument("--num-steps", type=int, default=3) parser.add_argument("--global-batch-size", type=int, default=8) parser.add_argument("--micro-batch-size", type=int, default=2) From 44ef8aa2c76f9f27b8b453da9f9efb7091f15008 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 18:49:23 +0800 Subject: [PATCH 08/22] [refactor] Rename DecentralizedInheritedWorkerPipelineDemo to DataCentricWorkerPipelineDemo Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 1eb9fe2..1e2d185 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -301,7 +301,7 @@ def build_tq_config(config: DemoConfig): return OmegaConf.merge(base, override) -class DecentralizedInheritedWorkerPipelineDemo: +class DataCentricWorkerPipelineDemo: def __init__(self, config: DemoConfig, tq_config): self.config = config tq.init(tq_config) @@ -425,7 +425,7 @@ def main() -> None: ray.init() try: - demo = DecentralizedInheritedWorkerPipelineDemo(cfg, build_tq_config(cfg)) + demo = DataCentricWorkerPipelineDemo(cfg, build_tq_config(cfg)) demo.fit() finally: tq.close() From 1c4de16746b785fb07f3dd25b666adaeb9128a67 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 19:02:02 +0800 Subject: [PATCH 09/22] [refactor] Remove "[driver]" prefix from main control logs Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 1e2d185..b8d284d 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -319,7 +319,7 @@ def _put_prompt(self, step: int) -> None: batch = make_prompt_batch(step, self.config) sample_ids = batch["sample_id"].view(-1).tolist() meta = self.tq_client.put(batch, partition_id=partition_id) - logger.info(f"[driver] step={step} prompt put -> partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}") + logger.info(f"step={step} prompt put -> partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}") def _wait_complete(self, step: int) -> None: while True: @@ -327,7 +327,7 @@ def _wait_complete(self, step: int) -> None: done_workers = ray.get(self.tracker.get_done_workers.remote(step)) active_counts = {stage: count for stage, count in counts.items() if count > 0} - logger.info(f"[driver] step={step} status -> counts={active_counts}, done_workers={done_workers}") + logger.info(f"step={step} status -> counts={active_counts}, done_workers={done_workers}") all_workers_done = ( done_workers.get("rollout", 0) >= self.config.num_rollout_workers @@ -370,7 +370,7 @@ def fit(self) -> list[dict]: self._wait_complete(step) ray.get(sync_weights.remote(step, self.config.weight_sync_seconds)) self.tq_client.clear_partition(f"{self.config.partition_prefix}_{step}") - logger.info(f"[driver] step={step} cleared partition={self.config.partition_prefix}_{step}") + logger.info(f"step={step} cleared partition={self.config.partition_prefix}_{step}") ray.get(refs) logger.info("demo done!") From 3fd8d1dff602f9a2223c74b952a3332519f7ffee Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 8 May 2026 20:29:42 +0800 Subject: [PATCH 10/22] [refactor] Optimize log display; change worker logging to print() Signed-off-by: jxixi <916099156@qq.com> --- .../streaming_dataloader_demo.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index b8d284d..5ff098a 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -34,6 +34,10 @@ STAGE_NAMES = ["rollout", "ref", "actor", "reward", "update"] +def emit_worker_log(message: str) -> None: + print(message, flush=True) + + def make_prompt_batch(step: int, config: "DemoConfig") -> TensorDict: start_id = step * config.global_batch_size generator = torch.Generator().manual_seed(config.seed + step) @@ -115,10 +119,10 @@ def __init__(self, tq_config, tracker, worker_id: int, config: "DemoConfig"): self.worker_name = f"{self.stage_name}-{worker_id}" def start(self, iteration: int, train_iters: int) -> dict: - logger.info(f"[{self.worker_name}] start (iteration={iteration})") + emit_worker_log(f"[{self.worker_name}] start (iteration={iteration})") for step in range(iteration, train_iters): self._run_step(step) - logger.info(f"[{self.worker_name}] done") + emit_worker_log(f"[{self.worker_name}] done") return {"worker": self.worker_name, "stage": self.stage_name} def _run_step(self, step: int) -> None: @@ -127,19 +131,19 @@ def _run_step(self, step: int) -> None: for batch, batch_meta in dataloader: sample_ids = batch["sample_id"].view(-1).tolist() - logger.info(f"[{self.worker_name}] step={step} consumed sample_ids={sample_ids}") + emit_worker_log(f"[{self.worker_name}] step={step} consumed sample_ids={sample_ids}") output, written_fields = self.compute(batch, batch_meta) self.tq_client.put(output, metadata=batch_meta) count = ray.get(self.tracker.record.remote(self.stage_name, step, len(sample_ids))) - logger.info( + emit_worker_log( f"[{self.worker_name}] step={step} done -> written_fields={written_fields}, " f"{self.stage_name}_count={count}/{self.cfg_demo.global_batch_size}" ) ray.get(self.tracker.record_done.remote(self.stage_name, step)) - logger.info(f"[{self.worker_name}] step={step} worker_done recorded") + emit_worker_log(f"[{self.worker_name}] step={step} worker_done recorded") def _build_dataloader(self, partition_id: str) -> StreamingDataLoader: dataset = StreamingDataset( @@ -239,9 +243,9 @@ def compute(self, batch: TensorDict, batch_meta: BatchMeta): @ray.remote(num_cpus=0.1) def sync_weights(step: int, sleep_s: float) -> dict: - logger.info(f"[weight-sync] step={step} start") + emit_worker_log(f"[weight-sync] step={step} start") time.sleep(sleep_s) - logger.info(f"[weight-sync] step={step} done") + emit_worker_log(f"[weight-sync] step={step} done") return {"step": step} @@ -319,7 +323,10 @@ def _put_prompt(self, step: int) -> None: batch = make_prompt_batch(step, self.config) sample_ids = batch["sample_id"].view(-1).tolist() meta = self.tq_client.put(batch, partition_id=partition_id) - logger.info(f"step={step} prompt put -> partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}") + logger.info( + f"MAIN | step={step} put prompts: " + f"partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}" + ) def _wait_complete(self, step: int) -> None: while True: @@ -327,7 +334,7 @@ def _wait_complete(self, step: int) -> None: done_workers = ray.get(self.tracker.get_done_workers.remote(step)) active_counts = {stage: count for stage, count in counts.items() if count > 0} - logger.info(f"step={step} status -> counts={active_counts}, done_workers={done_workers}") + logger.info(f"MAIN | step={step} progress: counts={active_counts}, done_workers={done_workers}") all_workers_done = ( done_workers.get("rollout", 0) >= self.config.num_rollout_workers @@ -345,15 +352,15 @@ def _start_worker_group(self, workers: list) -> list: def fit(self) -> list[dict]: logger.info("=" * 72) - logger.info("TransferQueue StreamingDataLoader Decentralized Inherited Worker Pipeline Demo") + logger.info("MAIN | TransferQueue StreamingDataLoader Data-Centric Worker Pipeline Demo") logger.info("=" * 72) logger.info( - f"workers: rollout={self.config.num_rollout_workers}, " + f"MAIN | workers: rollout={self.config.num_rollout_workers}, " f"ref={self.config.num_ref_workers}, actor={self.config.num_actor_workers}, " f"reward={self.config.num_reward_workers}, update={self.config.num_update_workers}" ) logger.info( - f"pipeline: num_steps={self.config.num_steps}, " + f"MAIN | pipeline: num_steps={self.config.num_steps}, " f"global_batch_size={self.config.global_batch_size}, " f"micro_batch_size={self.config.micro_batch_size}" ) @@ -370,10 +377,10 @@ def fit(self) -> list[dict]: self._wait_complete(step) ray.get(sync_weights.remote(step, self.config.weight_sync_seconds)) self.tq_client.clear_partition(f"{self.config.partition_prefix}_{step}") - logger.info(f"step={step} cleared partition={self.config.partition_prefix}_{step}") + logger.info(f"MAIN | step={step} clear partition: {self.config.partition_prefix}_{step}") ray.get(refs) - logger.info("demo done!") + logger.info("MAIN | demo done!") return [] From 60752fac1b9bf761bc40a98288d067a1d9c74aab Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 16:50:42 +0800 Subject: [PATCH 11/22] [refactor] Optimize log display Signed-off-by: jxixi <916099156@qq.com> --- .../streaming_dataloader_demo.py | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 5ff098a..6bad79e 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -19,6 +19,10 @@ from dataclasses import dataclass from importlib import resources +# Disable Ray's cross-worker log deduplication before importing Ray itself, +# otherwise many worker-side prints will be folded into "[repeated Nx across cluster]". +os.environ["RAY_DEDUP_LOGS"] = "0" + import ray import torch from omegaconf import OmegaConf @@ -29,13 +33,14 @@ from transfer_queue.metadata import BatchMeta from transfer_queue.utils.logging_utils import get_logger -logger = get_logger(__name__, default_level="INFO") +logger = get_logger("MAIN", default_level="INFO") STAGE_NAMES = ["rollout", "ref", "actor", "reward", "update"] -def emit_worker_log(message: str) -> None: - print(message, flush=True) +def emit_worker_log(message: str, enabled: bool) -> None: + if enabled: + print(message, flush=True) def make_prompt_batch(step: int, config: "DemoConfig") -> TensorDict: @@ -119,10 +124,8 @@ def __init__(self, tq_config, tracker, worker_id: int, config: "DemoConfig"): self.worker_name = f"{self.stage_name}-{worker_id}" def start(self, iteration: int, train_iters: int) -> dict: - emit_worker_log(f"[{self.worker_name}] start (iteration={iteration})") for step in range(iteration, train_iters): self._run_step(step) - emit_worker_log(f"[{self.worker_name}] done") return {"worker": self.worker_name, "stage": self.stage_name} def _run_step(self, step: int) -> None: @@ -131,19 +134,22 @@ def _run_step(self, step: int) -> None: for batch, batch_meta in dataloader: sample_ids = batch["sample_id"].view(-1).tolist() - emit_worker_log(f"[{self.worker_name}] step={step} consumed sample_ids={sample_ids}") + emit_worker_log( + f"[{self.worker_name}] step={step} consume: sample_ids={sample_ids}", + self.cfg_demo.enable_worker_logs, + ) output, written_fields = self.compute(batch, batch_meta) self.tq_client.put(output, metadata=batch_meta) count = ray.get(self.tracker.record.remote(self.stage_name, step, len(sample_ids))) emit_worker_log( - f"[{self.worker_name}] step={step} done -> written_fields={written_fields}, " - f"{self.stage_name}_count={count}/{self.cfg_demo.global_batch_size}" + f"[{self.worker_name}] step={step} produce: " + f"fields={written_fields}, count={count}/{self.cfg_demo.global_batch_size}", + self.cfg_demo.enable_worker_logs, ) ray.get(self.tracker.record_done.remote(self.stage_name, step)) - emit_worker_log(f"[{self.worker_name}] step={step} worker_done recorded") def _build_dataloader(self, partition_id: str) -> StreamingDataLoader: dataset = StreamingDataset( @@ -243,9 +249,7 @@ def compute(self, batch: TensorDict, batch_meta: BatchMeta): @ray.remote(num_cpus=0.1) def sync_weights(step: int, sleep_s: float) -> dict: - emit_worker_log(f"[weight-sync] step={step} start") time.sleep(sleep_s) - emit_worker_log(f"[weight-sync] step={step} done") return {"step": step} @@ -269,6 +273,7 @@ class DemoConfig: weight_sync_seconds: float num_data_storage_units: int seed: int + enable_worker_logs: bool def validate(self) -> None: for name, value in [ @@ -324,7 +329,7 @@ def _put_prompt(self, step: int) -> None: sample_ids = batch["sample_id"].view(-1).tolist() meta = self.tq_client.put(batch, partition_id=partition_id) logger.info( - f"MAIN | step={step} put prompts: " + f"step={step} | put prompts: " f"partition={partition_id}, sample_ids={sample_ids}, fields={list(meta.field_names)}" ) @@ -334,7 +339,7 @@ def _wait_complete(self, step: int) -> None: done_workers = ray.get(self.tracker.get_done_workers.remote(step)) active_counts = {stage: count for stage, count in counts.items() if count > 0} - logger.info(f"MAIN | step={step} progress: counts={active_counts}, done_workers={done_workers}") + logger.info(f"step={step} | progress: counts={active_counts}, done_workers={done_workers}") all_workers_done = ( done_workers.get("rollout", 0) >= self.config.num_rollout_workers @@ -352,15 +357,15 @@ def _start_worker_group(self, workers: list) -> list: def fit(self) -> list[dict]: logger.info("=" * 72) - logger.info("MAIN | TransferQueue StreamingDataLoader Data-Centric Worker Pipeline Demo") + logger.info("TransferQueue StreamingDataLoader Data-Centric Worker Pipeline Demo") logger.info("=" * 72) logger.info( - f"MAIN | workers: rollout={self.config.num_rollout_workers}, " + f"workers | rollout={self.config.num_rollout_workers}, " f"ref={self.config.num_ref_workers}, actor={self.config.num_actor_workers}, " f"reward={self.config.num_reward_workers}, update={self.config.num_update_workers}" ) logger.info( - f"MAIN | pipeline: num_steps={self.config.num_steps}, " + f"pipeline | num_steps={self.config.num_steps}, " f"global_batch_size={self.config.global_batch_size}, " f"micro_batch_size={self.config.micro_batch_size}" ) @@ -373,14 +378,18 @@ def fit(self) -> list[dict]: refs.extend(self._start_worker_group(self.update_workers)) for step in range(self.config.num_steps): + logger.info("=" * 72) + logger.info(f"STEP {step}") + logger.info("=" * 72) self._put_prompt(step) self._wait_complete(step) + logger.info(f"step={step} | weight sync: start") ray.get(sync_weights.remote(step, self.config.weight_sync_seconds)) + logger.info(f"step={step} | weight sync: done") self.tq_client.clear_partition(f"{self.config.partition_prefix}_{step}") - logger.info(f"MAIN | step={step} clear partition: {self.config.partition_prefix}_{step}") + logger.info(f"step={step} | clear partition: {self.config.partition_prefix}_{step}") ray.get(refs) - logger.info("MAIN | demo done!") return [] @@ -404,6 +413,7 @@ def main() -> None: parser.add_argument("--weight-sync-seconds", type=float, default=0.20) parser.add_argument("--num-data-storage-units", type=int, default=2) parser.add_argument("--seed", type=int, default=20260410) + parser.add_argument("--enable-worker-logs", action="store_true") args = parser.parse_args() cfg = DemoConfig( @@ -425,19 +435,25 @@ def main() -> None: weight_sync_seconds=args.weight_sync_seconds, num_data_storage_units=args.num_data_storage_units, seed=args.seed, + enable_worker_logs=args.enable_worker_logs, ) cfg.validate() os.environ["TQ_PRE_ALLOC_SAMPLE_NUM"] = str(cfg.global_batch_size) + completed = False ray.init() try: demo = DataCentricWorkerPipelineDemo(cfg, build_tq_config(cfg)) demo.fit() + completed = True finally: tq.close() ray.shutdown() + if completed: + logger.info("demo done!") + if __name__ == "__main__": main() From e8fbb92d37a2f9b91429b1ade7a2c315b179b6df Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 17:14:54 +0800 Subject: [PATCH 12/22] [refactor] Add annotation with reference to Relax Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/streaming_dataloader_demo.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/streaming_dataloader_demo.py index 6bad79e..e394cdd 100644 --- a/recipe/simple_use_case/streaming_dataloader_demo.py +++ b/recipe/simple_use_case/streaming_dataloader_demo.py @@ -13,6 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +A simplified data-centric RL workflow demo built with StreamingDataset and +StreamingDataLoader. + +The implementation structure and asynchronous dataflow are inspired by the +Relax project, while keeping the example intentionally lightweight and focused +on educational readability. Reference: https://github.com/redai-infra/Relax +""" + import argparse import os import time From 382c77e39459c4816122ce32791251fb929f2336 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 17:16:29 +0800 Subject: [PATCH 13/22] [refactor] Rename demo file to relax_demo.py Signed-off-by: jxixi <916099156@qq.com> --- .../{streaming_dataloader_demo.py => relax_demo.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename recipe/simple_use_case/{streaming_dataloader_demo.py => relax_demo.py} (100%) diff --git a/recipe/simple_use_case/streaming_dataloader_demo.py b/recipe/simple_use_case/relax_demo.py similarity index 100% rename from recipe/simple_use_case/streaming_dataloader_demo.py rename to recipe/simple_use_case/relax_demo.py From 9d85f6a997d2235f5170e0b9c8cb02014cdaa212 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 17:28:42 +0800 Subject: [PATCH 14/22] [refactor] Rename demo main class Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/relax_demo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/recipe/simple_use_case/relax_demo.py b/recipe/simple_use_case/relax_demo.py index e394cdd..1aacd9a 100644 --- a/recipe/simple_use_case/relax_demo.py +++ b/recipe/simple_use_case/relax_demo.py @@ -319,7 +319,7 @@ def build_tq_config(config: DemoConfig): return OmegaConf.merge(base, override) -class DataCentricWorkerPipelineDemo: +class DataCentricPipelineDemo: def __init__(self, config: DemoConfig, tq_config): self.config = config tq.init(tq_config) @@ -366,7 +366,7 @@ def _start_worker_group(self, workers: list) -> list: def fit(self) -> list[dict]: logger.info("=" * 72) - logger.info("TransferQueue StreamingDataLoader Data-Centric Worker Pipeline Demo") + logger.info("TransferQueue StreamingDataLoader Data-Centric Pipeline Demo (Relax-inspired)") logger.info("=" * 72) logger.info( f"workers | rollout={self.config.num_rollout_workers}, " @@ -453,7 +453,7 @@ def main() -> None: completed = False ray.init() try: - demo = DataCentricWorkerPipelineDemo(cfg, build_tq_config(cfg)) + demo = DataCentricPipelineDemo(cfg, build_tq_config(cfg)) demo.fit() completed = True finally: From 810dfd14ab00056dfd937307626ec9d016abeb2e Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 17:38:35 +0800 Subject: [PATCH 15/22] [ci] Add new relax demo to recipe-check.yml workflow Signed-off-by: jxixi <916099156@qq.com> --- .github/workflows/recipe-check.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/recipe-check.yml b/.github/workflows/recipe-check.yml index 228ff67..fb26057 100644 --- a/.github/workflows/recipe-check.yml +++ b/.github/workflows/recipe-check.yml @@ -30,7 +30,11 @@ jobs: python -m pip install --upgrade pip pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -e ".[yuanrong]" - - name: Run recipes + - name: Run single controller demo run: | export RAY_DEDUP_LOGS=0 python3 recipe/simple_use_case/single_controller_demo.py --num-samples 8 --global-batch-size 4 --rollout-agent-num-workers 1 + - name: Run data-centric pipeline demo + run: | + export RAY_DEDUP_LOGS=0 + python3 recipe/simple_use_case/streaming_decentralized_inherited_worker_pipeline_demo.py From 7c2815a6aedb8b2d3bf0af2a87ea23f4455c7c7d Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 19:21:09 +0800 Subject: [PATCH 16/22] [ci] Fix relax_demo filename in recipe-check.yml Signed-off-by: jxixi <916099156@qq.com> --- .github/workflows/recipe-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/recipe-check.yml b/.github/workflows/recipe-check.yml index fb26057..2df1499 100644 --- a/.github/workflows/recipe-check.yml +++ b/.github/workflows/recipe-check.yml @@ -37,4 +37,4 @@ jobs: - name: Run data-centric pipeline demo run: | export RAY_DEDUP_LOGS=0 - python3 recipe/simple_use_case/streaming_decentralized_inherited_worker_pipeline_demo.py + python3 recipe/simple_use_case/relax_demo.py From cfd680aa2b908ed5efa36b814c40185d3c9ae1cf Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 19:54:53 +0800 Subject: [PATCH 17/22] [fix] Reuse StreamingDataLoader across steps to avoid resource leaking Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/relax_demo.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/recipe/simple_use_case/relax_demo.py b/recipe/simple_use_case/relax_demo.py index 1aacd9a..0df6ca7 100644 --- a/recipe/simple_use_case/relax_demo.py +++ b/recipe/simple_use_case/relax_demo.py @@ -131,6 +131,7 @@ def __init__(self, tq_config, tracker, worker_id: int, config: "DemoConfig"): self.worker_id = worker_id self.cfg_demo = config self.worker_name = f"{self.stage_name}-{worker_id}" + self._dataloader: StreamingDataLoader | None = None def start(self, iteration: int, train_iters: int) -> dict: for step in range(iteration, train_iters): @@ -139,7 +140,7 @@ def start(self, iteration: int, train_iters: int) -> dict: def _run_step(self, step: int) -> None: partition_id = f"{self.cfg_demo.partition_prefix}_{step}" - dataloader = self._build_dataloader(partition_id) + dataloader = self._get_dataloader(partition_id) for batch, batch_meta in dataloader: sample_ids = batch["sample_id"].view(-1).tolist() @@ -160,6 +161,13 @@ def _run_step(self, step: int) -> None: ray.get(self.tracker.record_done.remote(self.stage_name, step)) + def _get_dataloader(self, partition_id: str) -> StreamingDataLoader: + if self._dataloader is None: + self._dataloader = self._build_dataloader(partition_id) + else: + self._dataloader.step(partition_id) + return self._dataloader + def _build_dataloader(self, partition_id: str) -> StreamingDataLoader: dataset = StreamingDataset( config=self.cfg, From 61ecf1bfd46b6461086cacca5b52a83c7394b6f4 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 19:56:05 +0800 Subject: [PATCH 18/22] [refactor] Update default values for --partition-prefix and --task-name-prefix Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/relax_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/recipe/simple_use_case/relax_demo.py b/recipe/simple_use_case/relax_demo.py index 0df6ca7..b1197a7 100644 --- a/recipe/simple_use_case/relax_demo.py +++ b/recipe/simple_use_case/relax_demo.py @@ -412,8 +412,8 @@ def fit(self) -> list[dict]: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--partition-prefix", type=str, default="decentralized_demo") - parser.add_argument("--task-name-prefix", type=str, default="decentralized") + parser.add_argument("--partition-prefix", type=str, default="relax_demo") + parser.add_argument("--task-name-prefix", type=str, default="relax") parser.add_argument("--num-steps", type=int, default=3) parser.add_argument("--global-batch-size", type=int, default=8) parser.add_argument("--micro-batch-size", type=int, default=2) From 303707cf50ccf3553162dd5ea614dca14d0e9c97 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 20:46:00 +0800 Subject: [PATCH 19/22] [ci] Update runtime parameters for relax_demo in recipe-check.yml Signed-off-by: jxixi <916099156@qq.com> --- .github/workflows/recipe-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/recipe-check.yml b/.github/workflows/recipe-check.yml index 2df1499..96d111f 100644 --- a/.github/workflows/recipe-check.yml +++ b/.github/workflows/recipe-check.yml @@ -37,4 +37,4 @@ jobs: - name: Run data-centric pipeline demo run: | export RAY_DEDUP_LOGS=0 - python3 recipe/simple_use_case/relax_demo.py + python3 recipe/simple_use_case/relax_demo.py --num-steps 1 --global-batch-size 2 --micro-batch-size 1 From f3210a2e48b8266cfb945b4b665e2918681bcc1a Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Thu, 14 May 2026 21:08:14 +0800 Subject: [PATCH 20/22] [fix] Fix line endings for pre-commit compliance Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/relax_demo.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/recipe/simple_use_case/relax_demo.py b/recipe/simple_use_case/relax_demo.py index b1197a7..d69740a 100644 --- a/recipe/simple_use_case/relax_demo.py +++ b/recipe/simple_use_case/relax_demo.py @@ -63,7 +63,9 @@ def make_prompt_batch(step: int, config: "DemoConfig") -> TensorDict: generator=generator, dtype=torch.long, ) - return TensorDict({"sample_id": sample_ids.unsqueeze(-1), "prompt_ids": prompt_ids}, batch_size=config.global_batch_size) + return TensorDict( + {"sample_id": sample_ids.unsqueeze(-1), "prompt_ids": prompt_ids}, batch_size=config.global_batch_size + ) def generate_sequences(prompt_ids: torch.Tensor, config: "DemoConfig") -> TensorDict: @@ -334,11 +336,19 @@ def __init__(self, config: DemoConfig, tq_config): self.tq_client = tq.get_client() self.tracker = ProgressTracker.remote(STAGE_NAMES, config.num_steps) - self.rollout_workers = [RolloutWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_rollout_workers)] + self.rollout_workers = [ + RolloutWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_rollout_workers) + ] self.ref_workers = [RefWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_ref_workers)] - self.actor_workers = [ActorWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_actor_workers)] - self.reward_workers = [RewardWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_reward_workers)] - self.update_workers = [UpdateWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_update_workers)] + self.actor_workers = [ + ActorWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_actor_workers) + ] + self.reward_workers = [ + RewardWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_reward_workers) + ] + self.update_workers = [ + UpdateWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_update_workers) + ] def _put_prompt(self, step: int) -> None: partition_id = f"{self.config.partition_prefix}_{step}" From 79dfaa916987458fd18cfc6c09da6c297071668e Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 15 May 2026 10:58:49 +0800 Subject: [PATCH 21/22] [fix] Change demo data type to nested tensor from dense tensor for resolving conflict with PR #92, and add worker exception handling Signed-off-by: jxixi <916099156@qq.com> --- recipe/simple_use_case/relax_demo.py | 57 ++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/recipe/simple_use_case/relax_demo.py b/recipe/simple_use_case/relax_demo.py index d69740a..1bdfd1e 100644 --- a/recipe/simple_use_case/relax_demo.py +++ b/recipe/simple_use_case/relax_demo.py @@ -69,35 +69,39 @@ def make_prompt_batch(step: int, config: "DemoConfig") -> TensorDict: def generate_sequences(prompt_ids: torch.Tensor, config: "DemoConfig") -> TensorDict: - batch_size = prompt_ids.size(0) - generator = torch.Generator().manual_seed(config.seed + int(prompt_ids.sum().item())) - response_ids = torch.randint( - 0, - config.vocab_size, - (batch_size, config.response_length), - generator=generator, - dtype=torch.long, - ) + # This demo focuses on dataflow, so rollout emits placeholder tensors with + # the right schema instead of deriving values from the prompt contents. + batch_size = len(prompt_ids.unbind()) return TensorDict( { - "input_ids": torch.cat([prompt_ids, response_ids], dim=1), - "response_ids": response_ids, - "response_mask": torch.ones_like(response_ids), + "input_ids": torch.zeros( + (batch_size, config.prompt_length + config.response_length), + dtype=torch.long, + ), + "response_ids": torch.zeros((batch_size, config.response_length), dtype=torch.long), + "response_mask": torch.ones((batch_size, config.response_length), dtype=torch.long), }, batch_size=batch_size, ) def compute_log_prob(prompt_ids: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor: - return (response_ids.float().mean(dim=1, keepdim=True) - prompt_ids.float().mean(dim=1, keepdim=True)) / 1000.0 + # Return a stable placeholder score per sample; downstream stages only need + # the field shape and dtype to demonstrate the pipeline. + batch_size = len(prompt_ids.unbind()) + return torch.zeros((batch_size, 1), dtype=torch.float32) def compute_reward(response_ids: torch.Tensor) -> torch.Tensor: - return response_ids.float().mean(dim=1, keepdim=True) / 1000.0 + # Reward is also mocked out to keep the example independent from model math. + batch_size = len(response_ids.unbind()) + return torch.zeros((batch_size, 1), dtype=torch.float32) def compute_loss(old_log_prob: torch.Tensor, ref_log_prob: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor: - return (old_log_prob - ref_log_prob - advantage).abs() + # Update consumes the upstream fields but emits a placeholder loss tensor. + batch_size = len(old_log_prob.unbind()) + return torch.zeros((batch_size, 1), dtype=torch.float32) @ray.remote @@ -145,7 +149,7 @@ def _run_step(self, step: int) -> None: dataloader = self._get_dataloader(partition_id) for batch, batch_meta in dataloader: - sample_ids = batch["sample_id"].view(-1).tolist() + sample_ids = [int(sample_id.reshape(-1)[0].item()) for sample_id in batch["sample_id"].unbind()] emit_worker_log( f"[{self.worker_name}] step={step} consume: sample_ids={sample_ids}", self.cfg_demo.enable_worker_logs, @@ -167,6 +171,7 @@ def _get_dataloader(self, partition_id: str) -> StreamingDataLoader: if self._dataloader is None: self._dataloader = self._build_dataloader(partition_id) else: + # Reuse the same dataloader across steps and only advance its partition. self._dataloader.step(partition_id) return self._dataloader @@ -335,6 +340,7 @@ def __init__(self, config: DemoConfig, tq_config): tq.init(tq_config) self.tq_client = tq.get_client() self.tracker = ProgressTracker.remote(STAGE_NAMES, config.num_steps) + self._worker_refs: list[ray.ObjectRef] = [] self.rollout_workers = [ RolloutWorker.remote(tq_config, self.tracker, i, config) for i in range(config.num_rollout_workers) @@ -362,6 +368,7 @@ def _put_prompt(self, step: int) -> None: def _wait_complete(self, step: int) -> None: while True: + self._raise_if_worker_failed() counts = ray.get(self.tracker.get_counts.remote(step)) done_workers = ray.get(self.tracker.get_done_workers.remote(step)) @@ -379,6 +386,21 @@ def _wait_complete(self, step: int) -> None: return time.sleep(0.2) + def _raise_if_worker_failed(self) -> None: + if not self._worker_refs: + return + + # Worker exceptions stay attached to their ObjectRefs. The main loop only + # sees them once it explicitly ray.get()s a finished ref, so we probe for + # ready workers here instead of waiting until the very end of fit(). + ready_refs, _ = ray.wait(self._worker_refs, num_returns=1, timeout=0) + if not ready_refs: + return + + ready_ref = ready_refs[0] + ray.get(ready_ref) + self._worker_refs = [ref for ref in self._worker_refs if ref != ready_ref] + def _start_worker_group(self, workers: list) -> list: return [worker.start.remote(0, self.config.num_steps) for worker in workers] @@ -403,7 +425,10 @@ def fit(self) -> list[dict]: refs.extend(self._start_worker_group(self.actor_workers)) refs.extend(self._start_worker_group(self.reward_workers)) refs.extend(self._start_worker_group(self.update_workers)) + self._worker_refs = list(refs) + # Workers stay alive across the whole run; each training step is modeled + # as a fresh partition that carries one batch through the pipeline. for step in range(self.config.num_steps): logger.info("=" * 72) logger.info(f"STEP {step}") From 0ffd5e8634ba9a2425e5618c4acbbb9f7d040ff3 Mon Sep 17 00:00:00 2001 From: jxixi <916099156@qq.com> Date: Fri, 15 May 2026 11:01:16 +0800 Subject: [PATCH 22/22] [ci] Update demo run parameters in recipe-check.yml Signed-off-by: jxixi <916099156@qq.com> --- .github/workflows/recipe-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/recipe-check.yml b/.github/workflows/recipe-check.yml index 96d111f..f54d354 100644 --- a/.github/workflows/recipe-check.yml +++ b/.github/workflows/recipe-check.yml @@ -37,4 +37,4 @@ jobs: - name: Run data-centric pipeline demo run: | export RAY_DEDUP_LOGS=0 - python3 recipe/simple_use_case/relax_demo.py --num-steps 1 --global-batch-size 2 --micro-batch-size 1 + python3 recipe/simple_use_case/relax_demo.py --num-steps 1 --global-batch-size 1 --micro-batch-size 1 --num-rollout-workers 1 --num-ref-workers 1 --num-actor-workers 1 --num-reward-workers 1 --rollout-sleep-seconds 0.01 --stage-sleep-seconds 0.01 --weight-sync-seconds 0.01