From 19c804895a85e96f2ef99675fbbc256a7fd37aa8 Mon Sep 17 00:00:00 2001 From: David Oy Date: Mon, 20 Apr 2026 12:15:20 -0700 Subject: [PATCH] fix: fail fast on NIXL region layout mismatch instead of building bad pairs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When MX_CONTIGUOUS_REG=1 is set on both source and target, each side groups adjacent tensors into contiguous memory regions and registers those regions with NIXL instead of individual tensors. The region layout is computed from each process's PyTorch CUDA allocator state, which is NOT deterministic across processes: fragmentation, temporary buffers, and pinned-memory state vary between runs. Two workers loading the same model weights can therefore produce different region groupings. Previously the recv path paired source regions with local regions by index, logged per-region WARNINGs when sizes disagreed, and then still built `(src, local)` pairs of mismatched length. NIXL rejected these: makeXferReq: length mismatch at index pair 52 with local index 52 and remote index 52 NIXL_ERR_INVALID_PARAM aborting the whole transfer and falling back to an HF download. Fix: validate that source and local region layouts have the same count and the same per-region sizes BEFORE building transfer descriptors. If they disagree, raise a new typed exception `RegionLayoutMismatchError`. The RDMA strategy maps this to `ManifestMismatchError` so the caller simply tries the next source candidate (without marking this source STALE, since the source itself is healthy — the mismatch is a property of the local allocator state). If no candidate's layout matches, the loader falls through to a non-RDMA strategy as before. The `[Contiguous Registration]` optimization is preserved unchanged when layouts do match; only the failure mode is fixed. Added unit tests for the new validation helper that run without NIXL, CUDA, or a GPU (pure Python inputs), including a regression test for the observed 223 vs 219 Llama-3.1-8B mismatch. --- .../load_strategy/rdma_strategy.py | 10 +- .../python/modelexpress/nixl_transfer.py | 111 ++++++++++++++-- .../python/tests/test_nixl_transfer.py | 121 ++++++++++++++++++ 3 files changed, 229 insertions(+), 13 deletions(-) create mode 100644 modelexpress_client/python/tests/test_nixl_transfer.py diff --git a/modelexpress_client/python/modelexpress/load_strategy/rdma_strategy.py b/modelexpress_client/python/modelexpress/load_strategy/rdma_strategy.py index 8d475155..1ff4ef6c 100644 --- a/modelexpress_client/python/modelexpress/load_strategy/rdma_strategy.py +++ b/modelexpress_client/python/modelexpress/load_strategy/rdma_strategy.py @@ -15,7 +15,7 @@ import torch.nn as nn from .base import LoadContext, LoadStrategy, SourceTransferError, register_tensors, publish_metadata -from ..nixl_transfer import is_nixl_available +from ..nixl_transfer import RegionLayoutMismatchError, is_nixl_available from ..tensor_utils import capture_tensor_attrs from ..transfer_safety import check_transfer_allowed from ..types import ManifestMismatchError, TensorDescriptor @@ -270,6 +270,14 @@ def _receive_from_peer( coalesce_transfers=coalesce, remote_agent_name=remote_agent_name_override, ) + except RegionLayoutMismatchError as e: + # Contiguous-region layouts differ between source and this worker + # (PyTorch CUDA allocator non-determinism). The source itself is + # healthy; raise ManifestMismatchError so the caller just tries + # the next candidate without marking the source STALE. + raise ManifestMismatchError( + f"region layout mismatch with source: {e}" + ) from e except Exception as e: raise SourceTransferError(f"RDMA receive failed: {e}") from e transfer_time = time.perf_counter() - transfer_start diff --git a/modelexpress_client/python/modelexpress/nixl_transfer.py b/modelexpress_client/python/modelexpress/nixl_transfer.py index f1245fc1..00f1a838 100644 --- a/modelexpress_client/python/modelexpress/nixl_transfer.py +++ b/modelexpress_client/python/modelexpress/nixl_transfer.py @@ -39,6 +39,27 @@ def is_nixl_available() -> bool: return NIXL_AVAILABLE +class RegionLayoutMismatchError(Exception): + """Source and local contiguous-region layouts disagree. + + The recv side registered a different set of contiguous memory regions + than the source (different region count, or same count but different + region sizes). This happens because PyTorch's CUDA caching allocator is + not deterministic across processes, so two workers loading the same + model weights can produce different region groupings. + + The transfer cannot proceed in region mode because NIXL requires that + the N-th local region exactly matches the N-th remote region in size. + Callers should try the next source candidate (which may share our + layout) or fall through to a non-RDMA loading strategy. + """ + + +# Internal alias kept short for the raise site; re-exported via the public +# name above for callers that want to catch it. +_RegionLayoutMismatchError = RegionLayoutMismatchError + + class NixlTransferManager: """ Manages a single NIXL agent and RDMA transfers for GPU tensors. @@ -271,6 +292,61 @@ def _find_contiguous_regions( return regions + @staticmethod + def _validate_region_layout_match( + source_regions: list[TensorDescriptor], + local_regions: list[tuple[int, int]], + ) -> tuple[bool, str]: + """Check that two contiguous-region layouts agree on count and sizes. + + Two layouts match iff they have the same number of regions and each + pair has the same size. (Addresses legitimately differ — they're + virtual addresses in different processes.) + + Args: + source_regions: Region descriptors received from the source. + local_regions: ``(addr, size)`` tuples this worker registered. + + Returns: + ``(True, "")`` when the layouts match; otherwise + ``(False, human_readable_summary)`` describing the first few + mismatches so callers can include it in an error message. + """ + if len(source_regions) != len(local_regions): + return False, ( + f"region count mismatch: source has {len(source_regions)}, " + f"local has {len(local_regions)} " + "(PyTorch CUDA allocator non-determinism produced different " + "contiguous-region groupings across processes)" + ) + + size_mismatches: list[tuple[int, int, int]] = [] + for i, (src_region, (_local_addr, local_size)) in enumerate( + zip(source_regions, local_regions, strict=True) + ): + if src_region.size != local_size: + size_mismatches.append((i, src_region.size, local_size)) + + if size_mismatches: + # Log just the first few so errors stay readable. + head = size_mismatches[:5] + sample = ", ".join( + f"region {i}: source={s} local={l}" for i, s, l in head + ) + suffix = ( + f" (+{len(size_mismatches) - len(head)} more)" + if len(size_mismatches) > len(head) + else "" + ) + return False, ( + f"{len(size_mismatches)} region size mismatch(es) " + f"out of {len(source_regions)}: {sample}{suffix} " + "(PyTorch CUDA allocator non-determinism produced different " + "contiguous-region groupings across processes)" + ) + + return True, "" + def fetch_remote_and_wait( self, remote_agent_name: str, @@ -361,27 +437,38 @@ def receive_from_source( logger.info(f"Region-based transfer: {len(source_tensors)} source regions -> {len(self._registered_regions)} local regions") - # Validate region counts match - if len(source_tensors) != len(self._registered_regions): - logger.warning( - f"Region count mismatch: source has {len(source_tensors)}, " - f"local has {len(self._registered_regions)}. Proceeding with min." - ) + # Validate the two region layouts are identical before matching. + # + # The contiguous-region layout is derived from each process's + # PyTorch CUDA allocator state, which is NOT deterministic across + # processes: fragmentation pattern, temporary buffers, and pinned- + # memory state vary between runs. Two workers that load the same + # model weights can therefore produce different region groupings. + # + # If source and local layouts disagree (region count differs, or + # any region size differs), naively pairing by index produces + # mismatched (src, local) pairs. NIXL will reject these with + # `makeXferReq: length mismatch ... NIXL_ERR_INVALID_PARAM` and + # abort the entire transfer. Previously we logged per-region + # WARNINGs but proceeded anyway; now we fail fast with a typed + # exception so the caller can try the next source candidate or + # fall through to a non-RDMA strategy. + layout_matches, mismatch_summary = self._validate_region_layout_match( + source_tensors, self._registered_regions, + ) + if not layout_matches: + raise _RegionLayoutMismatchError(mismatch_summary) - # Build transfer lists by region index + # Build transfer lists by region index (layouts are validated equal) remote_descs = [] local_descs = [] # Will be (addr, size, device_id) tuples total_bytes = 0 - matched_count = min(len(source_tensors), len(self._registered_regions)) + matched_count = len(source_tensors) for i in range(matched_count): src_region = source_tensors[i] local_addr, local_size = self._registered_regions[i] - # Verify sizes match (regions should be same size) - if src_region.size != local_size: - logger.warning(f"Region {i} size mismatch: source={src_region.size}, local={local_size}") - remote_descs.append((src_region.addr, src_region.size, src_region.device_id)) local_descs.append((local_addr, local_size, self._device_id)) total_bytes += src_region.size diff --git a/modelexpress_client/python/tests/test_nixl_transfer.py b/modelexpress_client/python/tests/test_nixl_transfer.py new file mode 100644 index 00000000..2eb658fc --- /dev/null +++ b/modelexpress_client/python/tests/test_nixl_transfer.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for nixl_transfer region-layout validation. + +These tests exercise only pure-Python logic and do NOT require NIXL, CUDA, +or a GPU, so they can run in CI on any worker. +""" + +from modelexpress.nixl_transfer import ( + NixlTransferManager, + RegionLayoutMismatchError, +) +from modelexpress.types import TensorDescriptor + + +def _region_desc(i: int, addr: int, size: int) -> TensorDescriptor: + """Build a region-style TensorDescriptor as emitted by the source side.""" + return TensorDescriptor( + name=f"__region_{i}__", + addr=addr, + size=size, + device_id=0, + dtype="contiguous_region", + ) + + +class TestValidateRegionLayoutMatch: + """Tests for NixlTransferManager._validate_region_layout_match. + + This guards against the bug where mismatched region layouts between + source and recv produced `(src, local)` pairs of unequal length that + NIXL rejected with NIXL_ERR_INVALID_PARAM after having silently logged + per-region WARNINGs. + """ + + def test_identical_layouts_match(self): + """Equal count and equal sizes (addresses may differ) -> match.""" + source = [ + _region_desc(0, 0x10000, 1024), + _region_desc(1, 0x20000, 2048), + _region_desc(2, 0x30000, 512), + ] + # Local addresses intentionally different — they're VAs in another process. + local = [(0xAA000, 1024), (0xBB000, 2048), (0xCC000, 512)] + + ok, msg = NixlTransferManager._validate_region_layout_match(source, local) + assert ok is True + assert msg == "" + + def test_count_mismatch_does_not_match(self): + """Different region counts must be rejected.""" + source = [_region_desc(i, 0x10000 * (i + 1), 1024) for i in range(3)] + local = [(0xA000, 1024), (0xB000, 1024)] + + ok, msg = NixlTransferManager._validate_region_layout_match(source, local) + assert ok is False + assert "region count mismatch" in msg + assert "3" in msg and "2" in msg + + def test_size_mismatch_does_not_match(self): + """Same count but one region differs in size -> rejected.""" + source = [ + _region_desc(0, 0x10000, 1024), + _region_desc(1, 0x20000, 2048), + _region_desc(2, 0x30000, 512), + ] + local = [(0xA000, 1024), (0xB000, 9999), (0xC000, 512)] # index 1 diverges + + ok, msg = NixlTransferManager._validate_region_layout_match(source, local) + assert ok is False + assert "size mismatch" in msg + assert "region 1" in msg + assert "2048" in msg + assert "9999" in msg + + def test_size_mismatch_summary_caps_output(self): + """Many mismatches produce a bounded summary, not spam per region.""" + # 20 regions, all size-mismatched + source = [_region_desc(i, 0x1000 * (i + 1), 1024) for i in range(20)] + local = [(0x10000 + 0x1000 * i, 4096) for i in range(20)] + + ok, msg = NixlTransferManager._validate_region_layout_match(source, local) + assert ok is False + # Summary should name the total but not list every one of the 20. + assert "20 region size mismatch" in msg + # Expect a "+N more" suffix indicating truncation. + assert "more" in msg + + def test_empty_layouts_match_trivially(self): + """Two empty layouts are (vacuously) equal.""" + ok, msg = NixlTransferManager._validate_region_layout_match([], []) + assert ok is True + assert msg == "" + + def test_regression_logged_failure_from_llama31_8b(self): + """Regression test matching the exact scenario from the 2026-04-20 log. + + Source produced 223 regions, recv produced 219 regions, with specific + size mismatches at indices 84, 85, 88-92, 216-218. Under the old + code this proceeded anyway and NIXL rejected the transfer. Under + the fix it must raise RegionLayoutMismatchError. + """ + # Build two layouts with different counts (223 vs 219) + source = [_region_desc(i, 0x10000 * (i + 1), 33554432) for i in range(223)] + local = [(0xA0000 + i * 0x1000, 33554432) for i in range(219)] + + ok, msg = NixlTransferManager._validate_region_layout_match(source, local) + assert ok is False + assert "region count mismatch" in msg + assert "223" in msg + assert "219" in msg + + +class TestRegionLayoutMismatchError: + """The exception used to signal layout disagreement to callers.""" + + def test_is_exception(self): + err = RegionLayoutMismatchError("layouts differ") + assert isinstance(err, Exception) + assert "layouts differ" in str(err)