
[Feature] Train infer disaggregated#523

Open
jiapingW wants to merge 2 commits into main from train_infer_disaggre

Conversation

@jiapingW (Collaborator) commented Apr 2, 2026

Motivation

Modifications

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a Ray-based distributed architecture for SpecForge, enabling both colocated and disaggregated (train-inference separated) training modes. The changes include new Ray-based worker groups for rollout and training, a centralized orchestrator, and support for NCCL-based GPU-to-GPU data transfer. My feedback highlights performance bottlenecks in the rollout dispatch logic, potential runtime errors in the DataCollator initialization, risks associated with clearing the global device mesh, the need for robust error handling when waiting for distributed workers, and unnecessary synchronization in the data transfer utility.

Comment on lines +173 to +184
for dp_idx in range(dp_size):
data_batch, actual_count = self._fetch_multi_local(
self._rollout_batch_size
)
if data_batch is None:
break
per_dp_count = actual_count

send_ref = self.rollout_group.generate_and_send_single(
tp_idx, data_batch, [sp_leader_ranks[dp_idx]]
)
send_refs.append(send_ref)

Severity: high

The current implementation performs dp_size separate forward passes on the target model per logical training step. Since the target model is typically much larger than the draft model, this creates a significant performance bottleneck, one that grows worse as the number of DP groups increases.

Consider batching all dp_size requests into a single forward pass on the RolloutWorkerGroup (with a total batch size of dp_size * rollout_batch_size), then sharding and sending the results to the respective TrainWorker groups. This would leverage GPU parallelism much more effectively for the target model inference.
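The batching-then-sharding flow suggested above can be sketched as plain helpers. This is an illustrative sketch only: `fetch` stands in for the PR's `self._fetch_multi_local`, and the merge/shard logic assumes batches are simple sequences; the real code would operate on whatever batch structure `generate_and_send_single` expects.

```python
# Hypothetical sketch of the suggested change: collect all dp_size
# micro-batches first, run one large rollout forward pass, then shard
# the results back per DP group. `fetch` mimics _fetch_multi_local.
def collect_batches(fetch, rollout_batch_size, dp_size):
    """Gather up to dp_size micro-batches; return (merged_batch, per_dp_counts)."""
    batches, counts = [], []
    for _ in range(dp_size):
        data_batch, actual_count = fetch(rollout_batch_size)
        if data_batch is None:
            break
        batches.append(data_batch)
        counts.append(actual_count)
    if not batches:
        return None, []
    merged = [x for b in batches for x in b]  # one large batch for a single forward pass
    return merged, counts

def shard_results(results, counts):
    """Split the single forward-pass output back into per-DP shards."""
    shards, offset = [], 0
    for c in counts:
        shards.append(results[offset:offset + c])
        offset += c
    return shards
```

Each shard would then be sent to its corresponding sp_leader rank, so the target model runs one forward pass of size dp_size * rollout_batch_size instead of dp_size smaller ones.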

Comment on lines +37 to +47
def __init__(self, sp_degree=None, ulysses_degree=None):
if sp_degree is not None:
self.sp_degree = sp_degree
else:
self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group())
if ulysses_degree is not None:
self.ulysses_degree = ulysses_degree
else:
self.ulysses_degree = torch.distributed.get_world_size(
get_sp_ulysses_group()
)

Severity: medium

Calling torch.distributed.get_world_size() in the constructor of DataCollatorWithPadding will raise a RuntimeError if the collator is instantiated in a process where torch.distributed is not yet initialized (e.g., the driver process during dataset pre-building or in the orchestrator before workers are launched).

While the current RayOrchestrator passes these values explicitly, other utility functions like prepare_dp_dataloaders use the default constructor, which could lead to crashes if called outside a distributed context. Consider deferring the world size check until the first call to __call__ or providing safe defaults.
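A deferred-lookup version could look roughly like the sketch below. To keep it self-contained, the distributed lookup is injected as a callable (in the real code this would wrap `torch.distributed.get_world_size(get_draft_sp_group())`); the fallback default of 1 is an assumption, not the PR's behavior.

```python
# Illustrative sketch: resolve the SP degree lazily instead of calling
# torch.distributed.get_world_size() in __init__. `get_world_size_fn` is a
# hypothetical injection point standing in for the torch.distributed lookup.
class DataCollatorWithPadding:
    def __init__(self, sp_degree=None, get_world_size_fn=None):
        self._sp_degree = sp_degree          # explicit value wins
        self._get_world_size = get_world_size_fn

    @property
    def sp_degree(self):
        # Resolved on first access, so construction is safe on a
        # non-distributed driver process.
        if self._sp_degree is None:
            if self._get_world_size is not None:
                self._sp_degree = self._get_world_size()
            else:
                self._sp_degree = 1  # safe default outside a distributed context
        return self._sp_degree
```

The same pattern would apply to ulysses_degree; the key point is that no distributed call happens until the collator is actually used.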

_SP_RING_GROUP = PROCESS_GROUP.RING_PG if sp_size > 1 else my_draft_sp_group
_TP_DEVICE_MESH = dist.DeviceMesh.from_group(my_tp_group, device_type="cuda")
_DP_DEVICE_MESH = dist.DeviceMesh.from_group(my_dp_group, device_type="cuda")
_DEVICE_MESH = None # 2D mesh not available in subgroup mode

Severity: medium

Setting _DEVICE_MESH = None in init_distributed_from_subgroup might cause failures in other parts of the codebase that rely on get_device_mesh(). While 1D meshes (_TP_DEVICE_MESH, _DP_DEVICE_MESH) are initialized, some FSDP configurations or monitoring tools in the existing codebase might expect the global 2D mesh to be present.
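One way to make such failures explicit rather than silent is to guard the accessor, as in this minimal sketch (the function name `get_device_mesh` is from the review comment; the error message and fallback suggestion are illustrative):

```python
# Hypothetical guard: fail loudly when the global 2D mesh was never built
# (subgroup mode), instead of returning None to unsuspecting callers.
_DEVICE_MESH = None  # set by init_distributed_*, left None in subgroup mode

def get_device_mesh():
    if _DEVICE_MESH is None:
        raise RuntimeError(
            "Global 2D device mesh is unavailable in subgroup mode; "
            "use the 1D TP/DP meshes instead."
        )
    return _DEVICE_MESH
```

This turns a downstream AttributeError on a None mesh into an immediate, descriptive error at the access site.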

if self._enable_perf:
t3 = time.perf_counter()

metrics = ray.get(train_refs[0])

Severity: medium

ray.get(train_refs[0]) only waits for the first worker (rank 0) to complete. If any other worker in the distributed group encounters an error or is significantly slower, the orchestrator may proceed to the next step prematurely or hang in subsequent collective operations, making debugging difficult.

It is safer to wait for all workers to ensure consistency and catch exceptions occurring on non-zero ranks.

Suggested change:

- metrics = ray.get(train_refs[0])
+ metrics_list = ray.get(train_refs)
+ metrics = metrics_list[0]

position_ids=_to(batch.position_ids),
)
if needs_sync:
torch.cuda.synchronize()

Severity: medium

torch.cuda.synchronize() is a heavy operation that stalls the CPU until all GPU tasks are finished, which can reduce the benefits of using non_blocking=True for overlapping transfers.

Since this is called immediately before the forward pass, you can rely on the default stream's serialization or use CUDA events for more fine-grained synchronization if multiple streams are involved.
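The event-based alternative could be sketched as follows. This assumes PyTorch with CUDA and uses only standard `torch.cuda` stream/event APIs; the function and variable names (`async_h2d_copy`, `copy_stream`) are illustrative, not from the PR.

```python
# Illustrative sketch: replace a blocking torch.cuda.synchronize() with a
# CUDA event. Only the compute stream waits on the copy; the CPU keeps going.
import torch

def async_h2d_copy(tensors, copy_stream):
    """Copy tensors to the GPU on a side stream; return outputs plus an event."""
    done = torch.cuda.Event()
    with torch.cuda.stream(copy_stream):
        out = [t.to("cuda", non_blocking=True) for t in tensors]
        done.record()  # recorded on copy_stream (current inside the context)
    return out, done

def forward_after_copy(model, tensors, copy_stream):
    out, done = async_h2d_copy(tensors, copy_stream)
    # Stream-level dependency: compute waits for the copy, CPU is not stalled.
    torch.cuda.current_stream().wait_event(done)
    return model(*out)
```

If all transfers and the forward pass share the default stream, even this is unnecessary: the stream serializes them, and the explicit synchronize can simply be dropped.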

@FrankLeeeee (Collaborator) commented

Need to add ray to pyproject.toml.
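For reference, the dependency could be declared along these lines; the extras group and version bound are placeholders, not values from this PR:

```toml
# pyproject.toml — illustrative fragment; pin to whatever Ray version the
# worker-group APIs used here actually require.
[project]
dependencies = [
    "ray[default]>=2.9",
]
```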
