From 7fc0725e2810109c3bcefa6fe892288b2289a393 Mon Sep 17 00:00:00 2001 From: cenzhiyao Date: Mon, 27 Apr 2026 13:21:25 +0800 Subject: [PATCH] [Fix] Scope deferred_runtime_asserts per piecewise sub-graph to prevent NameError MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After piecewise splitting, sub-graphs share the same ShapeEnv whose deferred_runtime_asserts may reference backed SymInts (e.g. s90) that only exist as placeholders in another sub-graph. Inductor blindly emits these as Python runtime assertions, causing NameError at runtime. Add _scope_deferred_runtime_asserts context manager that temporarily narrows deferred_runtime_asserts to only reachable backed symbols before each standalone_compile call, then restores the original dict afterwards. Includes pytest that reproduces the bug (patch fix away → NameError) and verifies the fix (with fix → passes). --- .../magi_backend/piecewise_compiler.py | 81 ++++++++- .../test_piecewise_deferred_assert_scope.py | 172 ++++++++++++++++++ 2 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 tests/feature_tests/test_piecewise_deferred_assert_scope.py diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index e5e1252..d7136ce 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import re from abc import abstractmethod from collections.abc import Callable @@ -39,6 +40,78 @@ def _summarize_compile_input(value: Any) -> str: return type(value).__name__ +@contextlib.contextmanager +def _scope_deferred_runtime_asserts(example_inputs: list[Any]): + """Temporarily narrow ``deferred_runtime_asserts`` to only those whose + backed symbols are reachable from the current piecewise sub-graph's inputs. + + Inductor's ``GraphLowering`` copies the *entire* + ``shape_env.deferred_runtime_asserts`` dict and later emits each assert as + a Python statement referencing the symbol name. When ``standalone_compile`` + is called with ``from_tracing_context`` on a piecewise sub-graph, the + shared ``ShapeEnv`` may contain asserts that reference backed SymInts + absent from this sub-graph's placeholders (e.g. ``s90``), causing a + ``NameError`` at runtime. + + This context manager scopes the dict for the duration of + ``standalone_compile`` and restores the original afterwards so subsequent + sub-graphs see the full set. + """ + import sympy + from torch._subclasses.fake_tensor import FakeTensor + from torch.fx.experimental.symbolic_shapes import SymTypes, free_unbacked_symbols + + context = torch._guards.TracingContext.try_get() + if context is None or context.fake_mode is None: + yield + return + shape_env = context.fake_mode.shape_env + if not shape_env.deferred_runtime_asserts: + yield + return + + # Collect all backed symbols reachable from this sub-graph's inputs. + # e.g. sub-graph inputs: x(s77, 64), u0, u1, u2 → reachable = {s77} + reachable: set[sympy.Symbol] = set() + for inp in example_inputs: + if isinstance(inp, SymTypes): + # SymInt input, e.g. unbacked u0 → {u0} + reachable |= inp.node.expr.free_symbols + elif isinstance(inp, torch.Tensor): + if isinstance(inp, FakeTensor): + for s in inp.shape: + if isinstance(s, torch.SymInt): + # Tensor dim, e.g. x.shape[0] = s77 → {s77} + reachable |= s.node.expr.free_symbols + + # Filter: keep only asserts whose backed symbols all exist in this sub-graph. + # e.g. Eq(u0+u1+u2, s77): backed={s77} ⊆ {s77} → keep + # Eq(u0+u1+u2, s90): backed={s90} ⊄ {s77} → drop (s90 is in another sub-graph) + original: dict[sympy.Symbol, list[sympy.Expr]] = shape_env.deferred_runtime_asserts + filtered: dict[sympy.Symbol, list[sympy.Expr]] = {} + for sym, ras in original.items(): + kept: list[sympy.Expr] = [] + for ra in ras: + # all_symbols - unbacked_symbols = backed symbols only + backed_in_expr = ra.expr.free_symbols - free_unbacked_symbols(ra.expr) + if backed_in_expr <= reachable: + kept.append(ra) + else: + magi_logger.debug( + "Filtering unreachable deferred assert for sub-graph: " "sym=%s expr=%s missing=%s", + sym, + ra.expr, + backed_in_expr - reachable, + ) + if kept: + filtered[sym] = kept + shape_env.deferred_runtime_asserts = filtered + try: + yield + finally: + shape_env.deferred_runtime_asserts = original + + def _read_generated_code_expected_arity(path: str) -> int | None: try: py_files = list(Path(path).rglob("*.py")) @@ -199,8 +272,14 @@ def compile( import torch._functorch.config as functorch_config from torch._inductor import standalone_compile + scope_asserts = ( + _scope_deferred_runtime_asserts(example_inputs) + if dynamic_shapes == "from_tracing_context" + else contextlib.nullcontext() + ) + try: - with functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True): + with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True): compiled_graph = standalone_compile( graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config} ) diff --git a/tests/feature_tests/test_piecewise_deferred_assert_scope.py b/tests/feature_tests/test_piecewise_deferred_assert_scope.py new file mode 100644 index 0000000..ed15dfd --- /dev/null +++ b/tests/feature_tests/test_piecewise_deferred_assert_scope.py @@ -0,0 +1,172 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _scope_deferred_runtime_asserts in piecewise compilation. + +Problem: after piecewise splitting, each sub-graph is compiled via +standalone_compile independently, but they share the same ShapeEnv. +ShapeEnv.deferred_runtime_asserts may contain Eq constraints referencing +backed SymInts (e.g. s90) that only belong to *another* sub-graph's +placeholders. Inductor's GraphLowering blindly emits these as Python +runtime assertions, producing ``NameError: name 'sXX' is not defined``. + +Fix: ``_scope_deferred_runtime_asserts`` (in piecewise_compiler.py) narrows +deferred_runtime_asserts to only reachable symbols before each +standalone_compile call, then restores the original dict afterwards. + +test_without_fix: patches the fix away (nullcontext) → NameError. +test_with_fix: uses the real fix → runs correctly. +""" + +from contextlib import nullcontext +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from magi_compiler import magi_compile, magi_register_custom_op + +HIDDEN = 64 +NUM_MOD = 3 + + +class _Dispatcher: + """Mimics ModalityDispatcher: permute + unbacked group sizes.""" + + def __init__(self, modality_mapping, num_modalities): + self.num_modalities = num_modalities + self.permute_mapping = torch.argsort(modality_mapping) + self.inv_permute_mapping = torch.argsort(self.permute_mapping) + permuted = modality_mapping[self.permute_mapping] + + gs = torch.bincount(permuted, minlength=num_modalities).to(torch.int32) + gs_cpu = [int(v) for v in gs.to("cpu").tolist()] + + self._carrier = torch.empty(*gs_cpu) + if not torch.compiler.is_compiling(): + for i in range(num_modalities): + torch._dynamo.decorators.mark_unbacked(self._carrier, i) + + @property + def group_sizes(self): + return [self._carrier.shape[i] for i in range(self.num_modalities)] + + def dispatch(self, x): + return list(torch.split(x, self.group_sizes, dim=0)) + + def undispatch(self, *groups): + return torch.cat(groups, dim=0) + + +def _identity_meta(x): + return torch.empty_like(x) + + +@magi_register_custom_op("test_scope::identity", infer_output_meta_fn=_identity_meta, is_subgraph_boundary=True) +def identity_op(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + +class _InnerBlock(nn.Module): + def __init__(self): + super().__init__() + self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN, bias=False) for _ in range(NUM_MOD)]) + + def forward(self, x, permute_mapping, inv_permute_mapping, dispatcher): + # ── Sub-graph 1 (before boundary) ── + # Inputs : x (shape[0] → backed s77), dispatcher (group_sizes → unbacked u0, u1, u2) + # NOT here: permute_mapping (s90), inv_permute_mapping (s92) + chunks = dispatcher.dispatch(x) + outs = [self.linears[i](c) for i, c in enumerate(chunks)] + out = dispatcher.undispatch(*outs) + + # identity_op is registered with is_subgraph_boundary=True, which + # forces magi_compile to split here into two piecewise sub-graphs. + # + # ShapeEnv (shared) holds two deferred_runtime_asserts: + # Eq(u0 + u1 + u2, s77) ← s77 only in sub-graph 1 + # Eq(u0 + u1 + u2, s90) ← s90 only in sub-graph 2 + # + # Without _scope_deferred_runtime_asserts, sub-graph 2 inherits + # Eq(u0+u1+u2, s77) and Inductor emits `if not (... == s77):`, + # but s77 is not a placeholder in sub-graph 2 → NameError. + out = identity_op(out) + + # ── Sub-graph 2 (after boundary) ── + # Inputs : permute_mapping (shape[0] → backed s90), inv_permute_mapping (s92) + # NOT here: x (s77), u0, u1, u2 + out = out[inv_permute_mapping] + out = out[permute_mapping] + return out + + +class _OuterModel(nn.Module): + def __init__(self): + super().__init__() + self.block = _InnerBlock() + + def forward(self, x, modality_mapping): + dispatcher = _Dispatcher(modality_mapping, NUM_MOD) + x = x[dispatcher.permute_mapping] + out = self.block(x, dispatcher.permute_mapping, dispatcher.inv_permute_mapping, dispatcher) + return out[dispatcher.inv_permute_mapping] + + +def _make_inputs(sizes, device="cuda"): + seq = sum(sizes) + x = torch.randn(seq, HIDDEN, device=device) + parts = [] + for i, n in enumerate(sizes): + parts.extend([i] * n) + mm = torch.tensor(parts, dtype=torch.long, device=device) + return x, mm + + +def _build_compiled_model(device="cuda"): + torch._dynamo.reset() + model = _OuterModel().to(device).eval() + model.block = magi_compile(model.block, dynamic_arg_dims={"x": 0, "permute_mapping": 0, "inv_permute_mapping": 0}) + return torch.compile(model, dynamic=True) + + +def _run_two_shapes(compiled, device="cuda"): + """Run two different shapes to exercise the compiled model.""" + x1, mm1 = _make_inputs((32, 16, 16), device) + with torch.no_grad(): + out1 = compiled(x1, mm1) + assert out1.shape == (64, HIDDEN) + + x2, mm2 = _make_inputs((24, 12, 12), device) + with torch.no_grad(): + out2 = compiled(x2, mm2) + assert out2.shape == (48, HIDDEN) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_without_fix_raises_nameerror(): + """Without _scope_deferred_runtime_asserts, Inductor generates code + referencing a backed SymInt not present in the sub-graph → NameError.""" + compiled = _build_compiled_model() + + with patch("magi_compiler.magi_backend.piecewise_compiler._scope_deferred_runtime_asserts", return_value=nullcontext()): + with pytest.raises(NameError, match=r"name 's\d+' is not defined"): + _run_two_shapes(compiled) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_with_fix_passes(): + """With _scope_deferred_runtime_asserts active, all shapes run correctly.""" + compiled = _build_compiled_model() + _run_two_shapes(compiled)