Conversation
Code Review
This pull request introduces a Ray-based distributed architecture for SpecForge, enabling both colocated and disaggregated (training/inference separation) training modes. The changes include new Ray-based worker groups for rollout and training, a centralized orchestrator, and support for NCCL-based GPU-to-GPU data transfer. My feedback highlights performance bottlenecks in the rollout dispatch logic, potential runtime errors in the `DataCollator` initialization, risks associated with clearing the global device mesh, the need for robust error handling when waiting for distributed workers, and unnecessary synchronization in the data transfer utility.
```python
for dp_idx in range(dp_size):
    data_batch, actual_count = self._fetch_multi_local(
        self._rollout_batch_size
    )
    if data_batch is None:
        break
    per_dp_count = actual_count

    send_ref = self.rollout_group.generate_and_send_single(
        tp_idx, data_batch, [sp_leader_ranks[dp_idx]]
    )
    send_refs.append(send_ref)
```
The current implementation performs `dp_size` separate forward passes on the target model per logical training step. Since the target model is typically much larger than the draft model, this creates a significant performance bottleneck, especially as the number of DP groups grows.
Consider batching all `dp_size` requests into a single forward pass on the `RolloutWorkerGroup` (with a total batch size of `dp_size * rollout_batch_size`), then sharding and sending the results to the respective `TrainWorker` groups. This would leverage GPU parallelism much more effectively for the target model inference.
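One possible shape for this, as a sketch: only the sharding helper is concrete; `generate_batched` and `_send` are hypothetical names shown as comments, not part of the current API.

```python
def shard_rollout_output(samples, dp_size):
    """Split one large rollout batch into dp_size contiguous per-DP shards.

    With a single target-model forward pass of size
    dp_size * rollout_batch_size, each DP training group then receives
    its own contiguous slice of the outputs.
    """
    per_dp = len(samples) // dp_size
    return [samples[i * per_dp:(i + 1) * per_dp] for i in range(dp_size)]

# Hypothetical dispatch (generate_batched is an assumed batched API):
# big_batch, _ = self._fetch_multi_local(dp_size * self._rollout_batch_size)
# outputs = self.rollout_group.generate_batched(tp_idx, big_batch)
# for dp_idx, shard in enumerate(shard_rollout_output(outputs, dp_size)):
#     send_refs.append(self._send(shard, [sp_leader_ranks[dp_idx]]))
```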
```python
def __init__(self, sp_degree=None, ulysses_degree=None):
    if sp_degree is not None:
        self.sp_degree = sp_degree
    else:
        self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group())
    if ulysses_degree is not None:
        self.ulysses_degree = ulysses_degree
    else:
        self.ulysses_degree = torch.distributed.get_world_size(
            get_sp_ulysses_group()
        )
```
Calling `torch.distributed.get_world_size()` in the constructor of `DataCollatorWithPadding` will raise a `RuntimeError` if the collator is instantiated in a process where `torch.distributed` is not yet initialized (e.g., the driver process during dataset pre-building, or the orchestrator before workers are launched).
While the current `RayOrchestrator` passes these values explicitly, other utility functions such as `prepare_dp_dataloaders` use the default constructor, which could crash if called outside a distributed context. Consider deferring the world-size check until the first call to `__call__`, or providing safe defaults.
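A sketch of the deferred lookup: the group getters `get_draft_sp_group`/`get_sp_ulysses_group` are assumed to come from the existing codebase, and the fallback default of 1 outside a distributed context is an assumption, not the current behavior.

```python
class DataCollatorWithPadding:
    """Collator whose parallelism degrees are resolved lazily.

    Distributed lookups are deferred until first access, so the collator
    can be constructed on a driver process before
    torch.distributed.init_process_group() has run.
    """

    def __init__(self, sp_degree=None, ulysses_degree=None):
        self._sp_degree = sp_degree
        self._ulysses_degree = ulysses_degree

    @staticmethod
    def _resolve(group_getter):
        import torch.distributed as dist  # deferred import
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(group_getter())
        return 1  # assumed safe default outside a distributed context

    @property
    def sp_degree(self):
        if self._sp_degree is None:
            self._sp_degree = self._resolve(get_draft_sp_group)
        return self._sp_degree

    @property
    def ulysses_degree(self):
        if self._ulysses_degree is None:
            self._ulysses_degree = self._resolve(get_sp_ulysses_group)
        return self._ulysses_degree
```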
```python
_SP_RING_GROUP = PROCESS_GROUP.RING_PG if sp_size > 1 else my_draft_sp_group
_TP_DEVICE_MESH = dist.DeviceMesh.from_group(my_tp_group, device_type="cuda")
_DP_DEVICE_MESH = dist.DeviceMesh.from_group(my_dp_group, device_type="cuda")
_DEVICE_MESH = None  # 2D mesh not available in subgroup mode
```
Setting `_DEVICE_MESH = None` in `init_distributed_from_subgroup` might cause failures in other parts of the codebase that rely on `get_device_mesh()`. While the 1D meshes (`_TP_DEVICE_MESH`, `_DP_DEVICE_MESH`) are initialized, some FSDP configurations or monitoring tools in the existing codebase might expect the global 2D mesh to be present.
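One defensive option, sketched below: have `get_device_mesh()` fail loudly with an actionable message instead of silently returning `None`. The error text and accessor shape here are assumptions for illustration.

```python
_DEVICE_MESH = None  # left unset by init_distributed_from_subgroup

def get_device_mesh():
    """Return the global 2D device mesh, failing loudly when absent."""
    if _DEVICE_MESH is None:
        raise RuntimeError(
            "Global 2D device mesh is unavailable (subgroup mode); "
            "use the 1D TP/DP meshes instead."
        )
    return _DEVICE_MESH
```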
```python
if self._enable_perf:
    t3 = time.perf_counter()

metrics = ray.get(train_refs[0])
```
`ray.get(train_refs[0])` waits only for the first worker (rank 0) to complete. If any other worker in the distributed group encounters an error or is significantly slower, the orchestrator may proceed to the next step prematurely, or hang in subsequent collective operations, making debugging difficult.
It is safer to wait for all workers, both to ensure consistency and to surface exceptions raised on non-zero ranks.
```diff
-metrics = ray.get(train_refs[0])
+metrics_list = ray.get(train_refs)
+metrics = metrics_list[0]
```
```python
    position_ids=_to(batch.position_ids),
)
if needs_sync:
    torch.cuda.synchronize()
```
`torch.cuda.synchronize()` is a heavy operation that stalls the host thread until all queued GPU work has finished, which undermines the benefit of the `non_blocking=True` transfers it follows.
Since this runs immediately before the forward pass, you can rely on the default stream's ordering guarantees, or use CUDA events for finer-grained synchronization when multiple streams are involved.
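A sketch of the event-based alternative: the helper name `wait_for_copy` is hypothetical, and the side-stream usage is shown as comments.

```python
def wait_for_copy(event=None):
    """Replace a blanket torch.cuda.synchronize() after H2D copies.

    Copies issued on the default stream need no explicit sync: later
    kernels on that stream already serialize behind them. If the copies
    used a side stream, record a torch.cuda.Event on it and make the
    compute stream wait on the event (a GPU-side wait; the host thread
    does not block).
    """
    if event is None:
        return  # default-stream copies: stream ordering already suffices
    import torch  # deferred so the no-event path has no torch dependency
    torch.cuda.current_stream().wait_event(event)

# Side-stream usage (hypothetical names):
# with torch.cuda.stream(copy_stream):
#     gpu_batch = cpu_batch.to("cuda", non_blocking=True)
#     event = torch.cuda.Event()
#     event.record(copy_stream)
# wait_for_copy(event)  # just before the forward pass
```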
Need to add the PR description sections:

- Motivation
- Modifications
- Related Issues
- Accuracy Test
- Benchmark & Profiling
- Checklist