diff --git a/tests/distributed/test_dcp_a2a.py b/tests/distributed/test_dcp_a2a.py index d80ed36be650..5ab0f3de97b5 100644 --- a/tests/distributed/test_dcp_a2a.py +++ b/tests/distributed/test_dcp_a2a.py @@ -15,6 +15,7 @@ import torch import torch.distributed as dist +import vllm.envs as envs from vllm.config.parallel import ParallelConfig from vllm.utils.network_utils import get_open_port from vllm.utils.system_utils import update_environment_variables @@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None: update_environment_variables(env) local_rank = int(env["LOCAL_RANK"]) torch.accelerator.set_device_index(local_rank) - dist.init_process_group(backend="nccl") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) + else: + dist.init_process_group(backend="nccl") use_workspace = env.get("USE_WORKSPACE") == "1" if use_workspace: from vllm.v1.worker.workspace import init_workspace_manager diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index a1d5355d4466..d7b04f68091d 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -9,6 +9,7 @@ import torch import torch.distributed +import vllm.envs as envs from tests.utils import ensure_current_vllm_config from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator @@ -82,11 +83,18 @@ def test_pynccl(): @worker_fn_wrapper def multiple_allreduce_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - groups = [ - torch.distributed.new_group(ranks=[0, 1], backend="gloo"), - torch.distributed.new_group(ranks=[2, 3], backend="gloo"), - ] - group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + # Eager-init path: parent PG has bound_device_id + a CPU 
backend, + # so split_group is supported. + group = torch.distributed.split_group( + split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl" + ) + else: + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo"), + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) # two groups can communicate independently @@ -339,11 +347,16 @@ def test_pynccl_send_recv(): @worker_fn_wrapper def multiple_send_recv_worker_fn(): device = torch.device(f"cuda:{torch.distributed.get_rank()}") - groups = [ - torch.distributed.new_group(ranks=[0, 2], backend="gloo"), - torch.distributed.new_group(ranks=[1, 3], backend="gloo"), - ] - group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + group = torch.distributed.split_group( + split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl" + ) + else: + groups = [ + torch.distributed.new_group(ranks=[0, 2], backend="gloo"), + torch.distributed.new_group(ranks=[1, 3], backend="gloo"), + ] + group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) if torch.distributed.get_rank() == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) diff --git a/tests/distributed/test_quick_all_reduce.py b/tests/distributed/test_quick_all_reduce.py index a9591f96a78f..86eb82c962e7 100644 --- a/tests/distributed/test_quick_all_reduce.py +++ b/tests/distributed/test_quick_all_reduce.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa from vllm.distributed.device_communicators.quick_all_reduce import ( @@ 
-397,13 +398,27 @@ def qr_variable_input(rank, world_size): ranks = [] for i in range(world_size): ranks.append(i) - dist.init_process_group( - backend="nccl", - init_method="tcp://127.0.0.1:29500", - rank=rank, - world_size=world_size, - ) - cpu_group = torch.distributed.new_group(ranks, backend="nccl") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + device_id=device, + ) + else: + dist.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:29500", + rank=rank, + world_size=world_size, + ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + cpu_group = torch.distributed.split_group( + split_ranks=[ranks], backend="cpu:gloo,cuda:nccl" + ) + else: + cpu_group = torch.distributed.new_group(ranks, backend="nccl") handle = ops.qr_get_handle(_ptr) world_size = dist.get_world_size(group=cpu_group) diff --git a/tests/distributed/test_split_group.py b/tests/distributed/test_split_group.py new file mode 100644 index 000000000000..54586c9e370a --- /dev/null +++ b/tests/distributed/test_split_group.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for split_group in GroupCoordinator. + +These tests verify that: +1. split_group is used for both device and CPU group creation. +2. Multiple subgroups work correctly with split_group. +3. Both GPU and CPU all-reduce work on split groups. +""" + +import os +from typing import Any + +import multiprocess as mp +import pytest +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.distributed.parallel_state import ( + GroupCoordinator, + init_distributed_environment, +) +from vllm.utils.system_utils import update_environment_variables + +# The whole module exercises the split_group code path, which is opt-in +# behind VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1. 
+pytestmark = pytest.mark.skipif( + not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP, + reason=("VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in."), +) + +mp.set_start_method("spawn", force=True) + + +def distributed_run(fn, world_size): + number_of_processes = world_size + processes: list[mp.Process] = [] + for i in range(number_of_processes): + env: dict[str, str] = {} + env["RANK"] = str(i) + env["LOCAL_RANK"] = str(i) + env["WORLD_SIZE"] = str(number_of_processes) + env["LOCAL_WORLD_SIZE"] = str(number_of_processes) + env["MASTER_ADDR"] = "localhost" + env["MASTER_PORT"] = "12346" + # propagate the opt-in flag to the spawned child workers + env["VLLM_DISTRIBUTED_USE_SPLIT_GROUP"] = "1" + p = mp.Process(target=fn, args=(env,)) + processes.append(p) + p.start() + + for p in processes: + p.join() + + for p in processes: + assert p.exitcode == 0 + + +def worker_fn_wrapper(fn): + def wrapped_fn(env): + update_environment_variables(env) + local_rank = os.environ["LOCAL_RANK"] + device = torch.device(f"cuda:{local_rank}") + torch.accelerator.set_device_index(device) + init_distributed_environment() + fn() + + return wrapped_fn + + +def _verify_device_group(coordinator: GroupCoordinator): + """Verify device group works via all-reduce.""" + local_rank = torch.distributed.get_rank() + device = torch.device(f"cuda:{local_rank}") + tensor = torch.ones(16, 16, dtype=torch.float32, device=device) + torch.distributed.all_reduce(tensor, group=coordinator.device_group) + torch.accelerator.synchronize() + expected = coordinator.world_size + assert torch.all(tensor == expected).cpu().item(), ( + f"Device group all-reduce failed: expected {expected}, " + f"got {tensor.flatten()[0].item()}" + ) + + +def _verify_cpu_group(coordinator: GroupCoordinator): + """Verify CPU group works via all-reduce.""" + tensor = torch.ones(16, dtype=torch.float32) + torch.distributed.all_reduce(tensor, group=coordinator.cpu_group) + expected = coordinator.world_size + assert 
torch.all(tensor == expected).cpu().item(), ( + f"CPU group all-reduce failed: expected {expected}, " + f"got {tensor.flatten()[0].item()}" + ) + + +# --------------------------------------------------------------------------- +# Test 1: Basic split_group path with 2 GPUs +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_basic_worker(): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + group_ranks = [list(range(world_size))] + + coordinator = GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_basic", + ) + + _verify_device_group(coordinator) + _verify_cpu_group(coordinator) + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 2, + reason="Need at least 2 GPUs to run the test.", +) +def test_split_group_basic(): + """Test basic GroupCoordinator creation with split_group.""" + distributed_run(split_group_basic_worker, 2) + + +# --------------------------------------------------------------------------- +# Test 2: Multiple subgroups with split_group (4 GPUs) +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_multiple_subgroups_worker(): + rank = torch.distributed.get_rank() + group_ranks = [[0, 1], [2, 3]] + + coordinator = GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_multi", + ) + + assert coordinator.world_size == 2 + + _verify_device_group(coordinator) + _verify_cpu_group(coordinator) + + if rank in [0, 1]: + assert coordinator.ranks == [0, 1] + else: + assert coordinator.ranks == [2, 3] + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 4, + reason="Need at least 4 GPUs to run the test.", +) +def test_split_group_multiple_subgroups(): + """Test 
GroupCoordinator with multiple independent subgroups.""" + distributed_run(split_group_multiple_subgroups_worker, 4) + + +# --------------------------------------------------------------------------- +# Test 3: split_group contract — every parent rank must enter with the same +# ``split_ranks``. NCCL happens to produce +# correct subgroups for disjoint partitions because the wrapper hashes +# ``my_group`` to derive the comm-split color, but the contract violation is +# real and would break under non-partition / non-NCCL backends. This test +# captures the actual ``split_ranks`` argument passed on every rank and +# asserts they match. +# --------------------------------------------------------------------------- +@worker_fn_wrapper +def split_group_contract_worker(): + rank = torch.distributed.get_rank() + group_ranks = [[0, 1], [2, 3]] + + captured: list[list[list[int]]] = [] + original_split_group = torch.distributed.split_group + + def capturing_split_group(*args, split_ranks=None, **kwargs): + captured.append([list(g) for g in split_ranks]) + return original_split_group(*args, split_ranks=split_ranks, **kwargs) + + torch.distributed.split_group = capturing_split_group + try: + GroupCoordinator( + group_ranks=group_ranks, + local_rank=rank, + torch_distributed_backend="nccl", + use_device_communicator=False, + group_name="test_split_contract", + ) + finally: + torch.distributed.split_group = original_split_group + + # GroupCoordinator builds two subgroups (device + cpu) per coordinator, + # so every rank must have made exactly two split_group calls. + if len(captured) != 2: + raise AssertionError( + f"rank {rank} expected 2 split_group calls (device + cpu), " + f"got {len(captured)}: {captured}" + ) + + world_size = torch.distributed.get_world_size() + for call_idx in range(2): + gathered: list[Any] = [None] * world_size + torch.distributed.all_gather_object(gathered, captured[call_idx]) + # Normalize for stable comparison (sort each subgroup and the outer list). 
+ norm = [ + sorted([sorted(sg) for sg in per_rank_args]) for per_rank_args in gathered + ] + reference = norm[0] + for r, args in enumerate(norm): + if args != reference: + raise AssertionError( + f"split_group contract violation on call #{call_idx}: " + f"rank {r} passed split_ranks={gathered[r]}, but rank 0 " + f"passed split_ranks={gathered[0]}. PyTorch requires every " + "parent rank to enter split_group with the same split_ranks." + ) + + +@pytest.mark.skipif( + torch.accelerator.device_count() < 4, + reason="Need at least 4 GPUs to run the test.", +) +def test_split_group_contract_same_split_ranks_on_all_ranks(): + """All parent ranks must call torch.distributed.split_group with the same + ``split_ranks`` argument. This catches the bug where each rank passed + only its own subgroup (``split_ranks=[ranks]``), which NCCL forgives for + disjoint partitions but is a documented contract violation. + """ + distributed_run(split_group_contract_worker, 4) diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index e72f00bc91e0..670df2759b0a 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -5,13 +5,26 @@ import os import random +import torch import torch.distributed as dist +import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import get_world_group -# Let PyTorch choose the WORLD backend for the current device type. -dist.init_process_group() +# By default, let PyTorch choose the WORLD backend for the current device +# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1, +# use the explicit eager-init pattern required by `split_group` (mixed +# cpu:gloo,cuda:nccl backend + device_id binding). 
+if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.accelerator.set_device_index(local_rank) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) +else: + dist.init_process_group() # Create prompts prompts = [ diff --git a/tests/distributed/test_torchrun_example_moe.py b/tests/distributed/test_torchrun_example_moe.py index 969b5e92e3fc..6f0957ed0263 100644 --- a/tests/distributed/test_torchrun_example_moe.py +++ b/tests/distributed/test_torchrun_example_moe.py @@ -5,13 +5,26 @@ import os import random +import torch import torch.distributed as dist +import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import get_tp_group, get_world_group -# Let PyTorch choose the WORLD backend for the current device type. -dist.init_process_group() +# By default, let PyTorch choose the WORLD backend for the current device +# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1, +# use the explicit eager-init pattern required by `split_group` (mixed +# cpu:gloo,cuda:nccl backend + device_id binding). 
+if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + local_rank = int(os.environ["LOCAL_RANK"]) + torch.accelerator.set_device_index(local_rank) + dist.init_process_group( + backend="cpu:gloo,cuda:nccl", + device_id=torch.device(f"cuda:{local_rank}"), + ) +else: + dist.init_process_group() # Create prompts prompts = [ diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py index 07f244451b45..44e2ee836b95 100644 --- a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -10,6 +10,7 @@ from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage] from typing_extensions import ParamSpec +import vllm.envs as envs from vllm.config import VllmConfig, set_current_vllm_config from vllm.distributed import ( cleanup_dist_env_and_memory, @@ -60,7 +61,15 @@ def _set_vllm_config( tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size, pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size, ) - cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo") + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + cpu_group = torch.distributed.split_group( + split_ranks=[list(range(world_size))], + group_desc="moe_test_cpu", + ) + else: + cpu_group = torch.distributed.new_group( + list(range(world_size)), backend="gloo" + ) return cpu_group diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 452bf64ed989..efb1e2f2969d 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup from typing_extensions import ParamSpec +import vllm.envs as envs from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.layers.fused_moe.activation 
import MoEActivation @@ -375,7 +376,13 @@ def _test_deepep_deepgemm_moe( w1_scale = w1_scale.to(device=device) w2_scale = w2_scale.to(device=device) - pg = torch.distributed.new_group(list(range(pgi.world_size))) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + pg = torch.distributed.split_group( + split_ranks=[list(range(pgi.world_size))], + group_desc="deepep_deepgemm_test", + ) + else: + pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, pgi.rank) block_shape = [w1.size(1) // w1_scale.size(1), w1.size(2) // w1_scale.size(2)] diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py index 5e0303c3df72..83cd2f09d1e8 100644 --- a/tests/kernels/moe/test_deepep_moe.py +++ b/tests/kernels/moe/test_deepep_moe.py @@ -10,6 +10,7 @@ import torch.distributed from torch.distributed import ProcessGroup +import vllm.envs as envs from tests.kernels.moe.utils import make_dummy_moe_config from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config @@ -375,7 +376,13 @@ def _deep_ep_moe( w1_scale = w1_scale.to(device=device_idx) w2_scale = w2_scale.to(device=device_idx) - pg = torch.distributed.new_group(list(range(pgi.world_size))) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + pg = torch.distributed.split_group( + split_ranks=[list(range(pgi.world_size))], + group_desc="deepep_test", + ) + else: + pg = torch.distributed.new_group(list(range(pgi.world_size))) test_tensors = TestTensors.make(config, low_latency_mode) with set_current_vllm_config(VllmConfig()): diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 712167c601cf..b642435c7491 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -227,6 +227,67 @@ def patched_fused_scaled_matmul_reduce_scatter_fake( return res +def _platform_device_type() -> str: + """Return the device-type string (e.g. 
``"cuda"``, ``"xpu"``, ``"cpu"``) + for the current platform, in the form expected by + ``torch.distributed.init_process_group(backend=...)``. + """ + from vllm.platforms import current_platform + + if current_platform.is_cuda_alike(): + return "cuda" + elif current_platform.is_xpu(): + return "xpu" + elif current_platform.is_out_of_tree(): + return current_platform.device_name + else: + return "cpu" + + +def _device_backend_str(torch_distributed_backend: str | Backend) -> str: + """Normalize ``torch_distributed_backend`` to the ``"{device_type}:{backend}"`` + format required by ``split_group``'s ``backend`` argument. + + Accepts either a bare backend name (e.g. ``"nccl"``) or an already-prefixed + string (e.g. ``"cuda:nccl"``). + """ + backend_str = str(torch_distributed_backend) + if ":" in backend_str: + return backend_str + return f"{_platform_device_type()}:{backend_str}" + + +def _create_subgroups_split_group( + group_ranks: list[list[int]], + group_name: str, + torch_distributed_backend: str | Backend, +) -> tuple[ProcessGroup, ProcessGroup]: + """Create the device + CPU subgroups for ``GroupCoordinator`` via + ``torch.distributed.split_group``. + + ``split_group`` is collective on the parent group, so every parent rank + must enter with the same ``split_ranks`` definition. Each rank receives + the subgroup it belongs to. + """ + device_backend_str = _device_backend_str(torch_distributed_backend) + self_device_group = torch.distributed.split_group( + split_ranks=group_ranks, + group_desc=f"{group_name}:device", + backend=device_backend_str, + ) + # CPU subgroup: split_group requires the requested backend filter to + # include the parent's default device type (= the device the parent PG + # was bound to via ``device_id``), so a cpu-only filter is rejected. + # Include the device backend in the filter; only the gloo backend is + # actually used for CPU collectives on this group.
+ self_cpu_group = torch.distributed.split_group( + split_ranks=group_ranks, + group_desc=f"{group_name}:cpu", + backend=f"cpu:gloo,{device_backend_str}", + ) + return self_device_group, self_cpu_group + + def patched_fused_scaled_matmul_reduce_scatter( A: torch.Tensor, B: torch.Tensor, @@ -335,26 +396,39 @@ def __init__( self_device_group = None self_cpu_group = None - from vllm.distributed.utils import get_cpu_distributed_timeout_or_none + # VLLM_DISTRIBUTED_USE_SPLIT_GROUP gates the new ``split_group`` + # codepath. Default (False) preserves the legacy ``new_group`` path. + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + self_device_group, self_cpu_group = _create_subgroups_split_group( + group_ranks, group_name, torch_distributed_backend + ) + for ranks in group_ranks: + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + break + else: + from vllm.distributed.utils import get_cpu_distributed_timeout_or_none - timeout = get_cpu_distributed_timeout_or_none() + timeout = get_cpu_distributed_timeout_or_none() - for ranks in group_ranks: - device_group = torch.distributed.new_group( - ranks, backend=torch_distributed_backend - ) - # a group with `gloo` backend, to allow direct coordination between - # processes through the CPU. - with suppress_stdout(): - cpu_group = torch.distributed.new_group( - ranks, backend="gloo", timeout=timeout + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend ) - if self.rank in ranks: - self.ranks = ranks - self.world_size = len(ranks) - self.rank_in_group = ranks.index(self.rank) - self_device_group = device_group - self_cpu_group = cpu_group + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ with suppress_stdout(): + cpu_group = torch.distributed.new_group( + ranks, backend="gloo", timeout=timeout + ) + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self_device_group = device_group + self_cpu_group = cpu_group assert self_cpu_group is not None assert self_device_group is not None @@ -1332,6 +1406,62 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def _init_process_group_for_split_group( + *, + backend: str, + distributed_init_method: str, + world_size: int, + rank: int, + local_rank: int, + timeout: timedelta | None, +) -> None: + """Initialize the default PG with both CPU (gloo) and device (e.g. nccl) + backends and an eager ``device_id`` binding so that subgroups can be + created via ``split_group`` (which requires the parent communicator to + be eagerly initialized). Falls back to ``gloo`` on CPU-only systems. + """ + if torch.accelerator.is_available() and backend != "gloo": + init_backend = "cpu:gloo,cuda:nccl" + device_id: torch.device | None = torch.device(f"cuda:{local_rank}") + else: + init_backend = "gloo" + device_id = None + torch.distributed.init_process_group( + backend=init_backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout, + device_id=device_id, + ) + + +def _validate_default_pg_for_split_group() -> None: + """When an external launcher (e.g. ``torchrun``) initialized the default + PG, ``GroupCoordinator`` cannot patch in additional backends or change + the eager-init behavior — ``split_group`` only selects subsets of an + existing parent. Validate that the parent has both ``device_id`` and a + CPU (gloo) backend, and emit a descriptive error pointing at the exact + init call to update otherwise. 
+ """ + default_pg = torch.distributed.distributed_c10d._get_default_group() + assert default_pg.bound_device_id is not None, ( + "External launcher initialized the default process group " + "without device_id. vLLM requires the default PG to be device-" + "bound for split_group. Pass device_id=torch.device(f'cuda:" + "{local_rank}') to torch.distributed.init_process_group()." + ) + try: + default_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + "External launcher initialized the default process group " + "without a CPU (gloo) backend. vLLM requires both CPU and " + "device backends. Pass backend='cpu:gloo,cuda:nccl' to " + "torch.distributed.init_process_group()." + ) from e + + def _init_elastic_ep_world( config, local_rank: int, backend: str, rank: int, world_size: int ) -> None: @@ -1440,14 +1570,33 @@ def init_distributed_environment( "Fallback Gloo backend is not available." ) backend = "gloo" - # this backend is used for WORLD - torch.distributed.init_process_group( - backend=backend, - init_method=distributed_init_method, - world_size=world_size, - rank=rank, - timeout=timeout, - ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP: + # split_group needs local_rank early to compute device_id for + # the eager init. 
local_rank is not available in torch + # ProcessGroup, see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + local_rank = ( + int(envs.LOCAL_RANK) + if distributed_init_method == "env://" + else rank + ) + _init_process_group_for_split_group( + backend=backend, + distributed_init_method=distributed_init_method, + world_size=world_size, + rank=rank, + local_rank=local_rank, + timeout=timeout, + ) + else: + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + timeout=timeout, + ) if enable_elastic_ep: tp_pp_cpu_group = torch.distributed.new_group( backend="gloo", timeout=timeout @@ -1460,6 +1609,9 @@ def init_distributed_environment( "Elastic EP is not yet supported with multi-node TP/PP" ) + if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP and torch.accelerator.is_available(): + _validate_default_pg_for_split_group() + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 diff --git a/vllm/envs.py b/vllm/envs.py index c12e3cae247f..14f1bb3b4449 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -63,6 +63,7 @@ VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True VLLM_USE_RAY_V2_EXECUTOR_BACKEND: bool = False + VLLM_DISTRIBUTED_USE_SPLIT_GROUP: bool = False VLLM_XLA_USE_SPMD: bool = False VLLM_WORKER_MULTIPROC_METHOD: Literal["fork", "spawn"] = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") @@ -877,6 +878,13 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_RAY_V2_EXECUTOR_BACKEND": lambda: bool( int(os.getenv("VLLM_USE_RAY_V2_EXECUTOR_BACKEND", "1")) ), + # Opt-in. When True, GroupCoordinator constructs its CPU/device subgroups via + # ``torch.distributed.split_group(backend=...)`` + # and ``init_distributed_environment`` initializes the default PG with + # mixed ``cpu:gloo,cuda:nccl`` backend +
eager ``device_id`` binding. + "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": lambda: bool( + int(os.getenv("VLLM_DISTRIBUTED_USE_SPLIT_GROUP", "0")) + ), # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": env_with_choices(