Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion tests/distributed/test_dcp_a2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config.parallel import ParallelConfig
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
Expand Down Expand Up @@ -379,7 +380,13 @@ def _distributed_packed_a2a_worker(env: dict[str, str]) -> None:
update_environment_variables(env)
local_rank = int(env["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(backend="nccl")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group(backend="nccl")
use_workspace = env.get("USE_WORKSPACE") == "1"
if use_workspace:
from vllm.v1.worker.workspace import init_workspace_manager
Expand Down
33 changes: 23 additions & 10 deletions tests/distributed/test_pynccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.distributed

import vllm.envs as envs
from tests.utils import ensure_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
Expand Down Expand Up @@ -82,11 +83,18 @@ def test_pynccl():
@worker_fn_wrapper
def multiple_allreduce_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
# Eager-init path: parent PG has bound_device_id + a CPU backend,
# so split_group is supported.
group = torch.distributed.split_group(
split_ranks=[[0, 1], [2, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 1], backend="gloo"),
torch.distributed.new_group(ranks=[2, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
# two groups can communicate independently
Expand Down Expand Up @@ -339,11 +347,16 @@ def test_pynccl_send_recv():
@worker_fn_wrapper
def multiple_send_recv_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
group = torch.distributed.split_group(
split_ranks=[[0, 2], [1, 3]], backend="cpu:gloo,cuda:nccl"
)
else:
groups = [
torch.distributed.new_group(ranks=[0, 2], backend="gloo"),
torch.distributed.new_group(ranks=[1, 3], backend="gloo"),
]
group = groups[0] if torch.distributed.get_rank() in [0, 2] else groups[1]
pynccl_comm = PyNcclCommunicator(group=group, device=device)
if torch.distributed.get_rank() == 0:
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
Expand Down
29 changes: 22 additions & 7 deletions tests/distributed/test_quick_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce # noqa
from vllm.distributed.device_communicators.quick_all_reduce import (
Expand Down Expand Up @@ -397,13 +398,27 @@ def qr_variable_input(rank, world_size):
ranks = []
for i in range(world_size):
ranks.append(i)
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
cpu_group = torch.distributed.new_group(ranks, backend="nccl")
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
device_id=device,
)
else:
dist.init_process_group(
backend="nccl",
init_method="tcp://127.0.0.1:29500",
rank=rank,
world_size=world_size,
)
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
cpu_group = torch.distributed.split_group(
split_ranks=[ranks], backend="cpu:gloo,cuda:nccl"
)
else:
cpu_group = torch.distributed.new_group(ranks, backend="nccl")

handle = ops.qr_get_handle(_ptr)
world_size = dist.get_world_size(group=cpu_group)
Expand Down
233 changes: 233 additions & 0 deletions tests/distributed/test_split_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for split_group in GroupCoordinator.

These tests verify that:
1. split_group is used for both device and CPU group creation.
2. Multiple subgroups work correctly with split_group.
3. Both GPU and CPU all-reduce work on split groups.
"""

import os
from typing import Any

import multiprocess as mp
import pytest
import torch
import torch.distributed

import vllm.envs as envs
from vllm.distributed.parallel_state import (
GroupCoordinator,
init_distributed_environment,
)
from vllm.utils.system_utils import update_environment_variables

# Every test in this module targets the split_group code path, which is
# opt-in: skip the whole module unless VLLM_DISTRIBUTED_USE_SPLIT_GROUP is set.
pytestmark = pytest.mark.skipif(
    not envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP,
    reason="VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1 not set; split_group path is opt-in.",
)

# Start worker processes with the "spawn" method; force=True replaces any
# start method that was configured earlier.
mp.set_start_method("spawn", force=True)


def distributed_run(fn, world_size):
    """Run ``fn`` in ``world_size`` separate processes and require success.

    Each process receives one environment dict as its only argument. The
    dict describes a single-node job: ``RANK`` and ``LOCAL_RANK`` are the
    process index, ``WORLD_SIZE`` and ``LOCAL_WORLD_SIZE`` are
    ``world_size``, and the rendezvous address is ``localhost:12346``.

    Args:
        fn: Callable invoked as ``fn(env)`` in every child process.
        world_size: Number of processes (ranks) to start.

    Raises:
        AssertionError: If any child process exits with a non-zero exit
            code. The message lists each failing rank and its exit code.
    """
    shared_env: dict[str, str] = {
        "WORLD_SIZE": str(world_size),
        "LOCAL_WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12346",
        # propagate the opt-in flag to the spawned child workers
        "VLLM_DISTRIBUTED_USE_SPLIT_GROUP": "1",
    }

    processes = []
    for rank in range(world_size):
        env = {"RANK": str(rank), "LOCAL_RANK": str(rank), **shared_env}
        p = mp.Process(target=fn, args=(env,))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Report every failing rank at once instead of stopping at the first one,
    # so a multi-rank failure can be diagnosed from a single test run.
    failures = {
        rank: p.exitcode for rank, p in enumerate(processes) if p.exitcode != 0
    }
    assert not failures, (
        f"worker processes failed (rank -> exit code): {failures}"
    )


def worker_fn_wrapper(fn):
    """Turn a zero-argument worker into a process entry point taking ``env``.

    The returned function passes ``env`` to ``update_environment_variables``,
    selects the device ``cuda:<LOCAL_RANK>`` (read from ``os.environ``),
    calls ``init_distributed_environment()`` and finally runs ``fn()``.
    """

    def wrapped_fn(env):
        update_environment_variables(env)
        # LOCAL_RANK is read from the process environment, after the update.
        torch.accelerator.set_device_index(
            torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        )
        init_distributed_environment()
        fn()

    return wrapped_fn


def _verify_device_group(coordinator: GroupCoordinator):
    """Check the coordinator's device group with an all-reduce of ones.

    A 16x16 float32 tensor of ones is created on ``cuda:<global rank>`` and
    all-reduced over ``coordinator.device_group``; afterwards every element
    must equal ``coordinator.world_size``.
    """
    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
    ones = torch.ones(16, 16, dtype=torch.float32, device=device)
    torch.distributed.all_reduce(ones, group=coordinator.device_group)
    torch.accelerator.synchronize()

    expected = coordinator.world_size
    all_match = torch.all(ones == expected).cpu().item()
    assert all_match, (
        f"Device group all-reduce failed: expected {expected}, "
        f"got {ones.flatten()[0].item()}"
    )


def _verify_cpu_group(coordinator: GroupCoordinator):
    """Check the coordinator's CPU group with an all-reduce of ones.

    A 16-element float32 tensor of ones (default device) is all-reduced over
    ``coordinator.cpu_group``; afterwards every element must equal
    ``coordinator.world_size``.
    """
    ones = torch.ones(16, dtype=torch.float32)
    torch.distributed.all_reduce(ones, group=coordinator.cpu_group)

    expected = coordinator.world_size
    all_match = torch.all(ones == expected).cpu().item()
    assert all_match, (
        f"CPU group all-reduce failed: expected {expected}, "
        f"got {ones.flatten()[0].item()}"
    )


# ---------------------------------------------------------------------------
# Test 1: Basic split_group path with 2 GPUs
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_basic_worker():
    """Build one coordinator spanning every rank and check both its groups."""
    every_rank = list(range(torch.distributed.get_world_size()))

    coordinator = GroupCoordinator(
        group_ranks=[every_rank],
        local_rank=torch.distributed.get_rank(),
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_basic",
    )

    # Exercise the device group first, then the CPU group.
    _verify_device_group(coordinator)
    _verify_cpu_group(coordinator)


@pytest.mark.skipif(
    torch.accelerator.device_count() < 2,
    reason="Need at least 2 GPUs to run the test.",
)
def test_split_group_basic():
    """Run the basic GroupCoordinator split_group worker on two ranks."""
    distributed_run(fn=split_group_basic_worker, world_size=2)


# ---------------------------------------------------------------------------
# Test 2: Multiple subgroups with split_group (4 GPUs)
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_multiple_subgroups_worker():
    """Split the ranks into [0, 1] and [2, 3] and check this rank's subgroup."""
    rank = torch.distributed.get_rank()

    coordinator = GroupCoordinator(
        group_ranks=[[0, 1], [2, 3]],
        local_rank=rank,
        torch_distributed_backend="nccl",
        use_device_communicator=False,
        group_name="test_split_multi",
    )

    # Each subgroup holds two ranks.
    assert coordinator.world_size == 2

    _verify_device_group(coordinator)
    _verify_cpu_group(coordinator)

    # The coordinator must report the subgroup this rank belongs to.
    expected_ranks = [0, 1] if rank in (0, 1) else [2, 3]
    assert coordinator.ranks == expected_ranks


@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_multiple_subgroups():
    """Run the two-subgroup GroupCoordinator worker on four ranks."""
    distributed_run(fn=split_group_multiple_subgroups_worker, world_size=4)


# ---------------------------------------------------------------------------
# Test 3: split_group contract — every parent rank must enter split_group
# with the same ``split_ranks``. If each rank passed only its own subgroup,
# NCCL would still happen to produce correct subgroups for disjoint
# partitions, because the wrapper hashes ``my_group`` to derive the
# comm-split color; it would nonetheless be a contract violation that breaks
# under non-partition / non-NCCL backends. This test captures the actual
# ``split_ranks`` argument passed on every rank and asserts they match.
# ---------------------------------------------------------------------------
@worker_fn_wrapper
def split_group_contract_worker():
    """Check that all ranks pass identical ``split_ranks`` to split_group.

    ``torch.distributed.split_group`` is temporarily replaced by a wrapper
    that records the ``split_ranks`` keyword argument of every call made
    while a GroupCoordinator is constructed. The recordings of all ranks are
    then exchanged with ``all_gather_object`` and compared.

    Raises:
        AssertionError: If any rank did not make exactly two split_group
            calls, or if the ranks disagree on the ``split_ranks`` of a call.
    """
    rank = torch.distributed.get_rank()
    group_ranks = [[0, 1], [2, 3]]

    captured: list[list[list[int]]] = []
    original_split_group = torch.distributed.split_group

    def capturing_split_group(*args, split_ranks=None, **kwargs):
        # Record a copy so later mutation by the caller cannot alter it.
        captured.append([list(g) for g in split_ranks])
        return original_split_group(*args, split_ranks=split_ranks, **kwargs)

    torch.distributed.split_group = capturing_split_group
    try:
        GroupCoordinator(
            group_ranks=group_ranks,
            local_rank=rank,
            torch_distributed_backend="nccl",
            use_device_communicator=False,
            group_name="test_split_contract",
        )
    finally:
        # Always undo the patch, even when construction fails.
        torch.distributed.split_group = original_split_group

    # Exchange the recordings BEFORE validating anything. Raising on only
    # some ranks ahead of a collective would leave the remaining ranks
    # blocked inside it; after the gather every rank sees the same data and
    # therefore reaches the same verdict.
    world_size = torch.distributed.get_world_size()
    gathered: list[Any] = [None] * world_size
    torch.distributed.all_gather_object(gathered, captured)

    # GroupCoordinator builds two subgroups (device + cpu) per coordinator,
    # so every rank must have made exactly two split_group calls.
    for r, calls in enumerate(gathered):
        if len(calls) != 2:
            raise AssertionError(
                f"rank {r} expected 2 split_group calls (device + cpu), "
                f"got {len(calls)}: {calls}"
            )

    for call_idx in range(2):
        per_rank = [calls[call_idx] for calls in gathered]
        # Normalize for stable comparison (sort each subgroup and the outer
        # list).
        norm = [
            sorted(sorted(sg) for sg in split_ranks) for split_ranks in per_rank
        ]
        reference = norm[0]
        for r, args in enumerate(norm):
            if args != reference:
                raise AssertionError(
                    f"split_group contract violation on call #{call_idx}: "
                    f"rank {r} passed split_ranks={per_rank[r]}, but rank 0 "
                    f"passed split_ranks={per_rank[0]}. PyTorch requires every "
                    "parent rank to enter split_group with the same split_ranks."
                )


@pytest.mark.skipif(
    torch.accelerator.device_count() < 4,
    reason="Need at least 4 GPUs to run the test.",
)
def test_split_group_contract_same_split_ranks_on_all_ranks():
    """Every parent rank must pass the same ``split_ranks`` to
    torch.distributed.split_group.

    Guards against the bug where each rank passed only its own subgroup
    (``split_ranks=[ranks]``): NCCL forgives that for disjoint partitions,
    but it is a documented contract violation.
    """
    distributed_run(fn=split_group_contract_worker, world_size=4)
17 changes: 15 additions & 2 deletions tests/distributed/test_torchrun_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@
import os
import random

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import get_world_group

# Let PyTorch choose the WORLD backend for the current device type.
dist.init_process_group()
# By default, let PyTorch choose the WORLD backend for the current device
# type (legacy lazy-init path). When VLLM_DISTRIBUTED_USE_SPLIT_GROUP=1,
# use the explicit eager-init pattern required by `split_group` (mixed
# cpu:gloo,cuda:nccl backend + device_id binding).
if envs.VLLM_DISTRIBUTED_USE_SPLIT_GROUP:
local_rank = int(os.environ["LOCAL_RANK"])
torch.accelerator.set_device_index(local_rank)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
device_id=torch.device(f"cuda:{local_rank}"),
)
else:
dist.init_process_group()

# Create prompts
prompts = [
Expand Down
Loading
Loading