From 606630c7da721ab2422e7e91a5579917f2a85acd Mon Sep 17 00:00:00 2001 From: cenzhiyao Date: Mon, 27 Apr 2026 23:22:24 +0800 Subject: [PATCH] [Fix] Capture and align Inductor output strides across piecewise sub-graphs Inductor may change output memory layout (e.g. mm padding, kernel fusion) during standalone_compile. When FakeTensor strides from sub-graph N flow into sub-graph N+1's compilation, mismatched strides cause assert_size_stride failures at runtime. - Add _intercept_inductor_output_strides to capture strides Inductor reports via set_tracing_context_output_strides before the TracingContext is destroyed. - Add _restride_outputs to update FakeTensor strides using as_strided (zero-copy view) so downstream sub-graphs compile with correct layouts. - Add test_stride_mismatch.py for non-contiguous view across piecewise boundary regression. - Add test_unbacked_symbol_guard.py for GuardOnDataDependentSymNode regression with mark_unbacked + view(-1). --- magi_compiler/magi_backend/magi_backend.py | 51 ++++++ .../magi_backend/piecewise_compiler.py | 46 +++++- tests/feature_tests/test_stride_mismatch.py | 96 ++++++++++++ .../test_unbacked_symbol_guard.py | 146 ++++++++++++++++++ 4 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 tests/feature_tests/test_stride_mismatch.py create mode 100644 tests/feature_tests/test_unbacked_symbol_guard.py diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index c00bad2..0d010e3 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -368,6 +368,54 @@ def _build_fake_args(self, args: tuple) -> list: fake_args.append(t) return fake_args + @staticmethod + def _restride_outputs(target: str, output: Any, output_strides: list | None) -> Any: + """Update FakeTensor output strides to match what Inductor will produce. + + ``standalone_compile`` may change the memory layout of a subgraph's + outputs (e.g. mm output padding, kernel fusion). The downstream + subgraph will be compiled with the FakeTensor strides that flow out of + this method, so they **must** reflect Inductor's actual output layout. + + ``output_strides`` comes from Inductor's ``set_tracing_context_output_strides`` + which evaluates symbolic stride expressions to concrete ints. When the + FakeTensor already has a symbolic stride (e.g. ``5120*s93``), replacing it + with a concrete value (e.g. ``20244480``) would specialize that dimension + and break dynamic-shape compilation for other sequence lengths. We only + apply restride for dimensions where *both* sides are statically known + (concrete ints) and differ. + """ + if not output_strides: + return output + + outputs: list[Any] + is_tuple = isinstance(output, (tuple, list)) + outputs = list(output) if is_tuple else [output] + + for i, strides in enumerate(output_strides): + if strides is None or i >= len(outputs): + continue + t = outputs[i] + if not isinstance(t, torch.Tensor) or t.dim() == 0: + continue + + old_strides = t.stride() + new_strides = list(old_strides) + changed = False + for d, (old_s, new_s) in enumerate(zip(old_strides, strides)): + if isinstance(old_s, torch.SymInt) or isinstance(new_s, torch.SymInt): + continue + if old_s != new_s: + new_strides[d] = new_s + changed = True + + if not changed: + continue + magi_logger.info("Restriding output %d of '%s': %s -> %s", i, target, tuple(old_strides), tuple(new_strides)) + outputs[i] = t.as_strided(t.shape, new_strides) + + return type(output)(outputs) if is_tuple else outputs[0] + def call_module( self, target: torch.fx.node.Target, args: tuple[torch.fx.node.Argument, ...], kwargs: dict[str, Any] ) -> Any: @@ -390,6 +438,9 @@ def call_module( runtime_shape=None, ) + output_strides = getattr(self.compiler_manager.compiler, "_last_output_strides", None) + output = self._restride_outputs(target, output, output_strides) + piecewise_backend = PiecewiseBackend( submod, compiled_graph_for_dynamic_shape, diff --git a/magi_compiler/magi_backend/piecewise_compiler.py b/magi_compiler/magi_backend/piecewise_compiler.py index d7136ce..3e4aa31 100644 --- a/magi_compiler/magi_backend/piecewise_compiler.py +++ b/magi_compiler/magi_backend/piecewise_compiler.py @@ -127,6 +127,46 @@ def _read_generated_code_expected_arity(path: str) -> int | None: return None +def _intercept_inductor_output_strides() -> tuple[contextlib.AbstractContextManager, list[tuple[int, ...] | None]]: + """Create a context manager that captures the output strides Inductor + reports during ``standalone_compile``. + + ``standalone_compile`` creates its own ``TracingContext`` that is destroyed + when its ``with`` block exits, so the strides it writes are lost. We + intercept ``set_tracing_context_output_strides`` to capture them. + + Returns ``(context_manager, captured_list)``; the caller should enter the + context manager around the ``standalone_compile`` call, then read + ``captured_list`` afterwards. + """ + import torch._inductor.output_code as output_code_mod + import torch._inductor.utils as inductor_utils + + original_fn = inductor_utils.set_tracing_context_output_strides + captured: list[tuple[int, ...] | None] = [] + + def _hook(example_inputs, compiled_graph): + original_fn(example_inputs, compiled_graph) + ctx = torch._guards.TracingContext.try_get() + if ctx is not None and ctx.output_strides is not None: + captured.extend(ctx.output_strides) + + @contextlib.contextmanager + def _ctx(): + inductor_utils.set_tracing_context_output_strides = _hook + original_ref = getattr(output_code_mod, "set_tracing_context_output_strides", None) + if original_ref is not None: + output_code_mod.set_tracing_context_output_strides = _hook + try: + yield + finally: + inductor_utils.set_tracing_context_output_strides = original_fn + if original_ref is not None: + output_code_mod.set_tracing_context_output_strides = original_fn + + return _ctx(), captured + + class CompilerInterface: """ The interface for a compiler that can be used by MagiCompiler. @@ -223,6 +263,7 @@ def __init__(self, compile_config): self.compile_config: CompileConfig = compile_config self._restart_analysis_counts: dict[str, int] = {} + self._last_output_strides: list[tuple[int, ...] | None] | None = None @property def hash(self) -> str: @@ -277,9 +318,10 @@ def compile( if dynamic_shapes == "from_tracing_context" else contextlib.nullcontext() ) + stride_ctx, captured_strides = _intercept_inductor_output_strides() try: - with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True): + with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True), stride_ctx: compiled_graph = standalone_compile( graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config} ) @@ -298,6 +340,8 @@ def compile( ) raise + self._last_output_strides = captured_strides if captured_strides else None + # Step3: Save the compiled artifact # autograd_cache_allow_custom_autograd_functions=True is required above so that # autograd_function_apply (a HigherOrderOperator) does not bypass AOTAutograd cache diff --git a/tests/feature_tests/test_stride_mismatch.py b/tests/feature_tests/test_stride_mismatch.py new file mode 100644 index 0000000..2a26b16 --- /dev/null +++ b/tests/feature_tests/test_stride_mismatch.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test: non-contiguous view crossing a piecewise boundary. + +When ``split`` + ``unsqueeze`` produces a non-contiguous view and that tensor +crosses a piecewise boundary (custom op registered as subgraph boundary), +Inductor may change its stride (e.g. mm padding, kernel fusion). + +The framework fix in ``PiecewiseCompileInterpreter._restride_outputs`` +captures Inductor's actual output strides via +``TracingContext.report_output_strides`` and updates the FakeTensor +metadata before it flows into the next subgraph's compilation. +This ensures ``assert_size_stride`` in the downstream subgraph matches +the runtime stride. +""" + +import pytest +import torch +import torch.nn as nn + +from magi_compiler import magi_compile, magi_register_custom_op +from magi_compiler.config import get_compile_config + + +@magi_register_custom_op(name="test_stride::boundary_op", is_subgraph_boundary=True) +def boundary_op(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + +class SplitGateModel(nn.Module): + """Linear -> split -> unsqueeze(non-contiguous gate) -> boundary_op + -> use gate after boundary.""" + + def __init__(self, hidden: int, main_dim: int, gate_dim: int): + super().__init__() + self.main_dim = main_dim + self.gate_dim = gate_dim + self.proj = nn.Linear(hidden, main_dim + gate_dim, bias=False) + self.out_proj = nn.Linear(main_dim, hidden, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + projected = self.proj(x) + main, gate = projected.split([self.main_dim, self.gate_dim], dim=1) + gate = gate.unsqueeze(-1) # non-contiguous view, stride=(total, 1, 1) + + main = main.reshape(x.shape[0], self.gate_dim, self.main_dim // self.gate_dim) + main = torch.ops.test_stride.boundary_op(main) + + out = main * torch.sigmoid(gate) + return self.out_proj(out.reshape(x.shape[0], self.main_dim)) + + +def _run(): + torch._dynamo.reset() + get_compile_config().splitting_ops.clear() + get_compile_config().splitting_ops.append("test_stride::boundary_op") + + device = "cuda" + dtype = torch.bfloat16 + hidden, main_dim, gate_dim = 5120, 5120, 40 + + model = SplitGateModel(hidden, main_dim, gate_dim).to(device, dtype).eval() + compiled = magi_compile(model, dynamic_arg_dims={"x": 0}) + + for seq_len in [32, 64, 17]: + x = torch.randn(seq_len, hidden, device=device, dtype=dtype) + with torch.no_grad(): + ref = model(x) + out = compiled(x) + torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2) + print(f" seq_len={seq_len}: PASS") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_non_contiguous_view_across_piecewise_boundary(): + """Non-contiguous gate view should work without .contiguous() thanks to + _restride_outputs aligning FakeTensor strides with Inductor output.""" + _run() + + +if __name__ == "__main__": + _run() + print("PASS: non-contiguous view across piecewise boundary handled correctly") diff --git a/tests/feature_tests/test_unbacked_symbol_guard.py b/tests/feature_tests/test_unbacked_symbol_guard.py new file mode 100644 index 0000000..48ae1a3 --- /dev/null +++ b/tests/feature_tests/test_unbacked_symbol_guard.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test for GuardOnDataDependentSymNode triggered by mark_unbacked + view with -1. + +When a dispatcher marks per-group token counts as unbacked symbols +(u0, u1, u2) via mark_unbacked, downstream reshape operations that infer +a dimension via -1 force Inductor to guard on expressions like +`u0+u1+u2 > 0`, which is forbidden for unbacked symbolic integers. +""" + +pass + +import pytest +import torch +import torch._dynamo.decorators + +HIDDEN_SIZE = 64 +NUM_HEADS = 4 +HEAD_DIM = 16 + + +class ModalityDispatcherMinimal: + """Stripped-down ModalityDispatcher that reproduces the mark_unbacked pattern.""" + + def __init__(self, modality_mapping: torch.Tensor, num_modalities: int): + self.num_modalities = num_modalities + permute_mapping = torch.argsort(modality_mapping) + permuted = modality_mapping[permute_mapping] + group_size = torch.bincount(permuted, minlength=num_modalities).to(torch.int32) + group_size_cpu = [int(x) for x in group_size.to("cpu").tolist()] + + self._size_carrier = torch.empty(*group_size_cpu) + if not torch.compiler.is_compiling(): + for i in range(num_modalities): + torch._dynamo.decorators.mark_unbacked(self._size_carrier, i) + + @property + def group_size_cpu(self) -> list[int]: + return [self._size_carrier.shape[i] for i in range(self.num_modalities)] + + def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]: + return list(torch.split(x, self.group_size_cpu, dim=0)) + + def undispatch(self, *parts: torch.Tensor) -> torch.Tensor: + return torch.cat(parts, dim=0) + + +class BuggyModel(torch.nn.Module): + """Reproduces the original bug: view(k.shape[0], num_heads, -1).""" + + def __init__(self): + super().__init__() + qkv_out = NUM_HEADS * HEAD_DIM * 3 + NUM_HEADS + self.linear = torch.nn.Linear(HIDDEN_SIZE, qkv_out, bias=False, dtype=torch.bfloat16) + + def forward(self, x: torch.Tensor, modality_mapping: torch.Tensor): + md = ModalityDispatcherMinimal(modality_mapping, 3) + parts = md.dispatch(x) + x = md.undispatch(*parts) + + qkv = self.linear(x) + q_size = NUM_HEADS * HEAD_DIM + kv_size = NUM_HEADS * HEAD_DIM + _, k, _, g = torch.split(qkv, [q_size, kv_size, kv_size, NUM_HEADS], dim=1) + k = k.view(-1, NUM_HEADS, HEAD_DIM) + # BUG: view with -1 on unbacked-symbol seq_len triggers guard error + g = g.view(k.shape[0], NUM_HEADS, -1) + return g.sum() + + +class FixedModel(torch.nn.Module): + """Fixed version: unsqueeze(-1) avoids the problematic -1 inference.""" + + def __init__(self): + super().__init__() + qkv_out = NUM_HEADS * HEAD_DIM * 3 + NUM_HEADS + self.linear = torch.nn.Linear(HIDDEN_SIZE, qkv_out, bias=False, dtype=torch.bfloat16) + + def forward(self, x: torch.Tensor, modality_mapping: torch.Tensor): + md = ModalityDispatcherMinimal(modality_mapping, 3) + parts = md.dispatch(x) + x = md.undispatch(*parts) + + qkv = self.linear(x) + q_size = NUM_HEADS * HEAD_DIM + kv_size = NUM_HEADS * HEAD_DIM + _, k, _, g = torch.split(qkv, [q_size, kv_size, kv_size, NUM_HEADS], dim=1) + k = k.view(-1, NUM_HEADS, HEAD_DIM) + # FIX: unsqueeze avoids -1 dimension inference on unbacked symbols + g = g.unsqueeze(-1) + return g.sum() + + +def _make_inputs(seq_len: int, modality_sizes: list[int], device: str): + assert sum(modality_sizes) == seq_len + parts = [] + for mod_id, size in enumerate(modality_sizes): + parts.append(torch.full((size,), mod_id, dtype=torch.long, device=device)) + modality_mapping = torch.cat(parts) + x = torch.randn(seq_len, HIDDEN_SIZE, dtype=torch.bfloat16, device=device) + return x, modality_mapping + + +def test_unbacked_symbol_guard_error(): + """The original view(k.shape[0], NUM_HEADS, -1) MUST raise GuardOnDataDependentSymNode.""" + torch._dynamo.reset() + device = "cuda" if torch.cuda.is_available() else "cpu" + model = BuggyModel().to(device) + compiled = torch.compile(model, dynamic=True, fullgraph=False) + + x, mm = _make_inputs(150, [100, 30, 20], device) + with pytest.raises(torch._inductor.exc.InductorError, match="GuardOnDataDependentSymNode"): + compiled(x, mm) + + +def test_unbacked_symbol_guard_fixed(): + """The fixed unsqueeze(-1) version must succeed for multiple dynamic shapes.""" + torch._dynamo.reset() + device = "cuda" if torch.cuda.is_available() else "cpu" + model = FixedModel().to(device) + compiled = torch.compile(model, dynamic=True, fullgraph=False) + + x1, mm1 = _make_inputs(150, [100, 30, 20], device) + out1 = compiled(x1, mm1) + assert out1.shape == () + + x2, mm2 = _make_inputs(200, [120, 50, 30], device) + out2 = compiled(x2, mm2) + assert out2.shape == () + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])