magi_compiler/magi_backend/piecewise_compiler.py (80 additions, 1 deletion)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import re
from abc import abstractmethod
from collections.abc import Callable
@@ -39,6 +40,78 @@ def _summarize_compile_input(value: Any) -> str:
    return type(value).__name__


@contextlib.contextmanager
def _scope_deferred_runtime_asserts(example_inputs: list[Any]):
"""Temporarily narrow ``deferred_runtime_asserts`` to only those whose
backed symbols are reachable from the current piecewise sub-graph's inputs.

Inductor's ``GraphLowering`` copies the *entire*
``shape_env.deferred_runtime_asserts`` dict and later emits each assert as
a Python statement referencing the symbol name. When ``standalone_compile``
is called with ``from_tracing_context`` on a piecewise sub-graph, the
shared ``ShapeEnv`` may contain asserts that reference backed SymInts
absent from this sub-graph's placeholders (e.g. ``s90``), causing a
``NameError`` at runtime.

This context manager scopes the dict for the duration of
``standalone_compile`` and restores the original afterwards so subsequent
sub-graphs see the full set.
"""
import sympy
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.symbolic_shapes import SymTypes, free_unbacked_symbols

context = torch._guards.TracingContext.try_get()
if context is None or context.fake_mode is None:
yield
return
shape_env = context.fake_mode.shape_env
if not shape_env.deferred_runtime_asserts:
yield
return

# Collect all backed symbols reachable from this sub-graph's inputs.
# e.g. sub-graph inputs: x(s77, 64), u0, u1, u2 β†’ reachable = {s77}
reachable: set[sympy.Symbol] = set()
for inp in example_inputs:
if isinstance(inp, SymTypes):
# SymInt input, e.g. unbacked u0 β†’ {u0}
reachable |= inp.node.expr.free_symbols
elif isinstance(inp, torch.Tensor):
if isinstance(inp, FakeTensor):
for s in inp.shape:
if isinstance(s, torch.SymInt):
# Tensor dim, e.g. x.shape[0] = s77 β†’ {s77}
reachable |= s.node.expr.free_symbols

# Filter: keep only asserts whose backed symbols all exist in this sub-graph.
# e.g. Eq(u0+u1+u2, s77): backed={s77} βŠ† {s77} β†’ keep
# Eq(u0+u1+u2, s90): backed={s90} βŠ„ {s77} β†’ drop (s90 is in another sub-graph)
original: dict[sympy.Symbol, list[sympy.Expr]] = shape_env.deferred_runtime_asserts
filtered: dict[sympy.Symbol, list[sympy.Expr]] = {}
for sym, ras in original.items():
kept: list[sympy.Expr] = []
for ra in ras:
# all_symbols - unbacked_symbols = backed symbols only
backed_in_expr = ra.expr.free_symbols - free_unbacked_symbols(ra.expr)
if backed_in_expr <= reachable:
kept.append(ra)
else:
magi_logger.debug(
"Filtering unreachable deferred assert for sub-graph: " "sym=%s expr=%s missing=%s",
sym,
ra.expr,
backed_in_expr - reachable,
)
if kept:
filtered[sym] = kept
shape_env.deferred_runtime_asserts = filtered
try:
yield
finally:
shape_env.deferred_runtime_asserts = original


def _read_generated_code_expected_arity(path: str) -> int | None:
    try:
        py_files = list(Path(path).rglob("*.py"))
@@ -199,8 +272,14 @@ def compile(
    import torch._functorch.config as functorch_config
    from torch._inductor import standalone_compile

    scope_asserts = (
        _scope_deferred_runtime_asserts(example_inputs)
        if dynamic_shapes == "from_tracing_context"
        else contextlib.nullcontext()
    )

    try:
        with functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True):
        with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True):
            compiled_graph = standalone_compile(
                graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}
            )
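For intuition, here is a minimal, self-contained sketch of the subset test the
filter above performs, using plain sympy symbols in place of real ShapeEnv
runtime-assert objects. The symbol names (s77, s90, u0-u2) mirror the examples
in the code comments and are illustrative, not values produced by this snippet:

    import sympy

    s77, s90, u0, u1, u2 = sympy.symbols("s77 s90 u0 u1 u2")
    reachable = {s77}        # backed symbols visible to sub-graph 1's inputs
    unbacked = {u0, u1, u2}  # unbacked symbols are always allowed through

    for expr in (sympy.Eq(u0 + u1 + u2, s77), sympy.Eq(u0 + u1 + u2, s90)):
        backed_in_expr = expr.free_symbols - unbacked
        print(expr, "->", "keep" if backed_in_expr <= reachable else "drop")
    # Eq(u0 + u1 + u2, s77) -> keep
    # Eq(u0 + u1 + u2, s90) -> drop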
tests/feature_tests/test_piecewise_deferred_assert_scope.py (172 additions, 0 deletions)
@@ -0,0 +1,172 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for _scope_deferred_runtime_asserts in piecewise compilation.

Problem: after piecewise splitting, each sub-graph is compiled via
standalone_compile independently, but they share the same ShapeEnv.
ShapeEnv.deferred_runtime_asserts may contain Eq constraints referencing
backed SymInts (e.g. s90) that only belong to *another* sub-graph's
placeholders. Inductor's GraphLowering blindly emits these as Python
runtime assertions, producing ``NameError: name 'sXX' is not defined``.

Fix: ``_scope_deferred_runtime_asserts`` (in piecewise_compiler.py) narrows
deferred_runtime_asserts to only reachable symbols before each
standalone_compile call, then restores the original dict afterwards.

test_without_fix: patches the fix away (nullcontext) β†’ NameError.
test_with_fix: uses the real fix β†’ runs correctly.
"""

from contextlib import nullcontext
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from magi_compiler import magi_compile, magi_register_custom_op

HIDDEN = 64
NUM_MOD = 3


class _Dispatcher:
"""Mimics ModalityDispatcher: permute + unbacked group sizes."""

def __init__(self, modality_mapping, num_modalities):
self.num_modalities = num_modalities
self.permute_mapping = torch.argsort(modality_mapping)
self.inv_permute_mapping = torch.argsort(self.permute_mapping)
permuted = modality_mapping[self.permute_mapping]

gs = torch.bincount(permuted, minlength=num_modalities).to(torch.int32)
gs_cpu = [int(v) for v in gs.to("cpu").tolist()]

self._carrier = torch.empty(*gs_cpu)
if not torch.compiler.is_compiling():
for i in range(num_modalities):
torch._dynamo.decorators.mark_unbacked(self._carrier, i)

@property
def group_sizes(self):
return [self._carrier.shape[i] for i in range(self.num_modalities)]

def dispatch(self, x):
return list(torch.split(x, self.group_sizes, dim=0))

def undispatch(self, *groups):
return torch.cat(groups, dim=0)


def _identity_meta(x):
    return torch.empty_like(x)


@magi_register_custom_op("test_scope::identity", infer_output_meta_fn=_identity_meta, is_subgraph_boundary=True)
def identity_op(x: torch.Tensor) -> torch.Tensor:
    return x.clone()


class _InnerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(HIDDEN, HIDDEN, bias=False) for _ in range(NUM_MOD)])

    def forward(self, x, permute_mapping, inv_permute_mapping, dispatcher):
        # ── Sub-graph 1 (before boundary) ──
        # Inputs  : x (shape[0] → backed s77), dispatcher (group_sizes → unbacked u0, u1, u2)
        # NOT here: permute_mapping (s90), inv_permute_mapping (s92)
        chunks = dispatcher.dispatch(x)
        outs = [self.linears[i](c) for i, c in enumerate(chunks)]
        out = dispatcher.undispatch(*outs)

        # identity_op is registered with is_subgraph_boundary=True, which
        # forces magi_compile to split here into two piecewise sub-graphs.
        #
        # ShapeEnv (shared) holds two deferred_runtime_asserts:
        #   Eq(u0 + u1 + u2, s77)  ← s77 only in sub-graph 1
        #   Eq(u0 + u1 + u2, s90)  ← s90 only in sub-graph 2
        #
        # Without _scope_deferred_runtime_asserts, sub-graph 2 inherits
        # Eq(u0+u1+u2, s77) and Inductor emits `if not (... == s77):`,
        # but s77 is not a placeholder in sub-graph 2 → NameError.
        out = identity_op(out)

        # ── Sub-graph 2 (after boundary) ──
        # Inputs  : permute_mapping (shape[0] → backed s90), inv_permute_mapping (s92)
        # NOT here: x (s77), u0, u1, u2
        out = out[inv_permute_mapping]
        out = out[permute_mapping]
        return out


class _OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = _InnerBlock()

    def forward(self, x, modality_mapping):
        dispatcher = _Dispatcher(modality_mapping, NUM_MOD)
        x = x[dispatcher.permute_mapping]
        out = self.block(x, dispatcher.permute_mapping, dispatcher.inv_permute_mapping, dispatcher)
        return out[dispatcher.inv_permute_mapping]


def _make_inputs(sizes, device="cuda"):
    seq = sum(sizes)
    x = torch.randn(seq, HIDDEN, device=device)
    parts = []
    for i, n in enumerate(sizes):
        parts.extend([i] * n)
    mm = torch.tensor(parts, dtype=torch.long, device=device)
    return x, mm


def _build_compiled_model(device="cuda"):
    torch._dynamo.reset()
    model = _OuterModel().to(device).eval()
    model.block = magi_compile(model.block, dynamic_arg_dims={"x": 0, "permute_mapping": 0, "inv_permute_mapping": 0})
    return torch.compile(model, dynamic=True)


def _run_two_shapes(compiled, device="cuda"):
"""Run two different shapes to exercise the compiled model."""
x1, mm1 = _make_inputs((32, 16, 16), device)
with torch.no_grad():
out1 = compiled(x1, mm1)
assert out1.shape == (64, HIDDEN)

x2, mm2 = _make_inputs((24, 12, 12), device)
with torch.no_grad():
out2 = compiled(x2, mm2)
assert out2.shape == (48, HIDDEN)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_without_fix_raises_nameerror():
"""Without _scope_deferred_runtime_asserts, Inductor generates code
referencing a backed SymInt not present in the sub-graph β†’ NameError."""
compiled = _build_compiled_model()

with patch("magi_compiler.magi_backend.piecewise_compiler._scope_deferred_runtime_asserts", return_value=nullcontext()):
with pytest.raises(NameError, match=r"name 's\d+' is not defined"):
_run_two_shapes(compiled)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_with_fix_passes():
"""With _scope_deferred_runtime_asserts active, all shapes run correctly."""
compiled = _build_compiled_model()
_run_two_shapes(compiled)
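
Assuming the repository's standard pytest setup and an available CUDA device
(both tests skip otherwise), the new tests can be run directly with the path
from the diff above:

    pytest tests/feature_tests/test_piecewise_deferred_assert_scope.py -v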