Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions magi_compiler/magi_backend/magi_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,54 @@ def _build_fake_args(self, args: tuple) -> list:
fake_args.append(t)
return fake_args

@staticmethod
def _restride_outputs(target: str, output: Any, output_strides: list | None) -> Any:
"""Update FakeTensor output strides to match what Inductor will produce.

``standalone_compile`` may change the memory layout of a subgraph's
outputs (e.g. mm output padding, kernel fusion). The downstream
subgraph will be compiled with the FakeTensor strides that flow out of
this method, so they **must** reflect Inductor's actual output layout.

``output_strides`` comes from Inductor's ``set_tracing_context_output_strides``
which evaluates symbolic stride expressions to concrete ints. When the
FakeTensor already has a symbolic stride (e.g. ``5120*s93``), replacing it
with a concrete value (e.g. ``20244480``) would specialize that dimension
and break dynamic-shape compilation for other sequence lengths. We only
apply restride for dimensions where *both* sides are statically known
(concrete ints) and differ.
"""
if not output_strides:
return output

outputs: list[Any]
is_tuple = isinstance(output, (tuple, list))
outputs = list(output) if is_tuple else [output]

for i, strides in enumerate(output_strides):
if strides is None or i >= len(outputs):
continue
t = outputs[i]
if not isinstance(t, torch.Tensor) or t.dim() == 0:
continue

old_strides = t.stride()
new_strides = list(old_strides)
changed = False
for d, (old_s, new_s) in enumerate(zip(old_strides, strides)):
if isinstance(old_s, torch.SymInt) or isinstance(new_s, torch.SymInt):
continue
if old_s != new_s:
new_strides[d] = new_s
changed = True

if not changed:
continue
magi_logger.info("Restriding output %d of '%s': %s -> %s", i, target, tuple(old_strides), tuple(new_strides))
outputs[i] = t.as_strided(t.shape, new_strides)

return type(output)(outputs) if is_tuple else outputs[0]

def call_module(
self, target: torch.fx.node.Target, args: tuple[torch.fx.node.Argument, ...], kwargs: dict[str, Any]
) -> Any:
Expand All @@ -390,6 +438,9 @@ def call_module(
runtime_shape=None,
)

output_strides = getattr(self.compiler_manager.compiler, "_last_output_strides", None)
output = self._restride_outputs(target, output, output_strides)

piecewise_backend = PiecewiseBackend(
submod,
compiled_graph_for_dynamic_shape,
Expand Down
46 changes: 45 additions & 1 deletion magi_compiler/magi_backend/piecewise_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,46 @@ def _read_generated_code_expected_arity(path: str) -> int | None:
return None


def _intercept_inductor_output_strides() -> tuple[contextlib.AbstractContextManager, list[tuple[int, ...] | None]]:
"""Create a context manager that captures the output strides Inductor
reports during ``standalone_compile``.

``standalone_compile`` creates its own ``TracingContext`` that is destroyed
when its ``with`` block exits, so the strides it writes are lost. We
intercept ``set_tracing_context_output_strides`` to capture them.

Returns ``(context_manager, captured_list)``; the caller should enter the
context manager around the ``standalone_compile`` call, then read
``captured_list`` afterwards.
"""
import torch._inductor.output_code as output_code_mod
import torch._inductor.utils as inductor_utils

original_fn = inductor_utils.set_tracing_context_output_strides
captured: list[tuple[int, ...] | None] = []

def _hook(example_inputs, compiled_graph):
original_fn(example_inputs, compiled_graph)
ctx = torch._guards.TracingContext.try_get()
if ctx is not None and ctx.output_strides is not None:
captured.extend(ctx.output_strides)

@contextlib.contextmanager
def _ctx():
inductor_utils.set_tracing_context_output_strides = _hook
original_ref = getattr(output_code_mod, "set_tracing_context_output_strides", None)
if original_ref is not None:
output_code_mod.set_tracing_context_output_strides = _hook
try:
yield
finally:
inductor_utils.set_tracing_context_output_strides = original_fn
if original_ref is not None:
output_code_mod.set_tracing_context_output_strides = original_fn

return _ctx(), captured


class CompilerInterface:
"""
The interface for a compiler that can be used by MagiCompiler.
Expand Down Expand Up @@ -223,6 +263,7 @@ def __init__(self, compile_config):

self.compile_config: CompileConfig = compile_config
self._restart_analysis_counts: dict[str, int] = {}
self._last_output_strides: list[tuple[int, ...] | None] | None = None

@property
def hash(self) -> str:
Expand Down Expand Up @@ -277,9 +318,10 @@ def compile(
if dynamic_shapes == "from_tracing_context"
else contextlib.nullcontext()
)
stride_ctx, captured_strides = _intercept_inductor_output_strides()

try:
with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True):
with scope_asserts, functorch_config.patch(autograd_cache_allow_custom_autograd_functions=True), stride_ctx:
compiled_graph = standalone_compile(
graph, example_inputs, dynamic_shapes=dynamic_shapes, options={"config_patches": current_config}
)
Expand All @@ -298,6 +340,8 @@ def compile(
)
raise

self._last_output_strides = captured_strides if captured_strides else None

# Step3: Save the compiled artifact
# autograd_cache_allow_custom_autograd_functions=True is required above so that
# autograd_function_apply (a HigherOrderOperator) does not bypass AOTAutograd cache
Expand Down
96 changes: 96 additions & 0 deletions tests/feature_tests/test_stride_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test: non-contiguous view crossing a piecewise boundary.

When ``split`` + ``unsqueeze`` produces a non-contiguous view and that tensor
crosses a piecewise boundary (custom op registered as subgraph boundary),
Inductor may change its stride (e.g. mm padding, kernel fusion).

The framework fix in ``PiecewiseCompileInterpreter._restride_outputs``
takes the output strides that Inductor reports during
``standalone_compile`` (captured by intercepting
``set_tracing_context_output_strides``) and updates the FakeTensor
metadata before it flows into the next subgraph's compilation.
This ensures ``assert_size_stride`` in the downstream subgraph matches
the runtime stride.
"""

import pytest
import torch
import torch.nn as nn

from magi_compiler import magi_compile, magi_register_custom_op
from magi_compiler.config import get_compile_config


@magi_register_custom_op(name="test_stride::boundary_op", is_subgraph_boundary=True)
def boundary_op(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x``.

    Registered as the custom op ``test_stride::boundary_op`` with
    ``is_subgraph_boundary=True``.
    """
    return torch.clone(x)


class SplitGateModel(nn.Module):
    """Linear -> split -> unsqueeze(non-contiguous gate) -> boundary_op
    -> use gate after boundary.

    The projection output is split into a ``main`` part and a ``gate`` part.
    ``main`` is reshaped and passed through ``test_stride::boundary_op``;
    ``gate`` is consumed only after that call.
    """

    def __init__(self, hidden: int, main_dim: int, gate_dim: int):
        super().__init__()
        self.main_dim, self.gate_dim = main_dim, gate_dim
        self.proj = nn.Linear(hidden, main_dim + gate_dim, bias=False)
        self.out_proj = nn.Linear(main_dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rows = x.shape[0]
        main, gate = self.proj(x).split([self.main_dim, self.gate_dim], dim=1)
        # non-contiguous view, stride=(total, 1, 1)
        gate = gate.unsqueeze(-1)

        grouped = main.reshape(rows, self.gate_dim, self.main_dim // self.gate_dim)
        grouped = torch.ops.test_stride.boundary_op(grouped)

        gated = grouped * torch.sigmoid(gate)
        return self.out_proj(gated.reshape(rows, self.main_dim))


def _run():
    """Compile ``SplitGateModel`` with ``magi_compile`` and compare it to eager.

    The model is compiled with dimension 0 of ``x`` dynamic and evaluated for
    several sequence lengths on CUDA in bfloat16. The global
    ``splitting_ops`` configuration is pointed at ``test_stride::boundary_op``
    for the duration of the run and restored afterwards so that other tests
    in the same process are not affected.

    Raises:
        AssertionError: If a compiled output is not close to the eager output.
    """
    torch._dynamo.reset()
    splitting_ops = get_compile_config().splitting_ops
    previous_splitting_ops = list(splitting_ops)
    splitting_ops.clear()
    splitting_ops.append("test_stride::boundary_op")

    try:
        device = "cuda"
        dtype = torch.bfloat16
        hidden, main_dim, gate_dim = 5120, 5120, 40

        model = SplitGateModel(hidden, main_dim, gate_dim).to(device, dtype).eval()
        compiled = magi_compile(model, dynamic_arg_dims={"x": 0})

        for seq_len in [32, 64, 17]:
            x = torch.randn(seq_len, hidden, device=device, dtype=dtype)
            with torch.no_grad():
                ref = model(x)
                out = compiled(x)
            torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
            print(f" seq_len={seq_len}: PASS")
    finally:
        # Put the shared config back exactly as it was found.
        splitting_ops.clear()
        for op_name in previous_splitting_ops:
            splitting_ops.append(op_name)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_non_contiguous_view_across_piecewise_boundary():
    """Check that the non-contiguous gate view works across the boundary.

    No explicit ``.contiguous()`` call is made in the model; the test relies
    on ``_restride_outputs`` aligning FakeTensor strides with Inductor's
    output. Skipped when CUDA is unavailable.
    """
    _run()


# Allow running this file directly as a script, without pytest.
if __name__ == "__main__":
    _run()
    print("PASS: non-contiguous view across piecewise boundary handled correctly")
146 changes: 146 additions & 0 deletions tests/feature_tests/test_unbacked_symbol_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test for GuardOnDataDependentSymNode triggered by mark_unbacked + view with -1.

When a dispatcher marks per-group token counts as unbacked symbols
(u0, u1, u2) via mark_unbacked, downstream reshape operations that infer
a dimension via -1 force Inductor to guard on expressions like
`u0+u1+u2 > 0`, which is forbidden for unbacked symbolic integers.
"""

pass

import pytest
import torch
import torch._dynamo.decorators

HIDDEN_SIZE = 64
NUM_HEADS = 4
HEAD_DIM = 16


class ModalityDispatcherMinimal:
"""Stripped-down ModalityDispatcher that reproduces the mark_unbacked pattern."""

def __init__(self, modality_mapping: torch.Tensor, num_modalities: int):
self.num_modalities = num_modalities
permute_mapping = torch.argsort(modality_mapping)
permuted = modality_mapping[permute_mapping]
group_size = torch.bincount(permuted, minlength=num_modalities).to(torch.int32)
group_size_cpu = [int(x) for x in group_size.to("cpu").tolist()]

self._size_carrier = torch.empty(*group_size_cpu)
if not torch.compiler.is_compiling():
for i in range(num_modalities):
torch._dynamo.decorators.mark_unbacked(self._size_carrier, i)

@property
def group_size_cpu(self) -> list[int]:
return [self._size_carrier.shape[i] for i in range(self.num_modalities)]

def dispatch(self, x: torch.Tensor) -> list[torch.Tensor]:
return list(torch.split(x, self.group_size_cpu, dim=0))

def undispatch(self, *parts: torch.Tensor) -> torch.Tensor:
return torch.cat(parts, dim=0)


class BuggyModel(torch.nn.Module):
    """Reproduces the original bug: view(k.shape[0], num_heads, -1).

    The operation sequence in ``forward`` is intentional; it exists to
    trigger the failure and must not be "fixed" here.
    """

    def __init__(self):
        super().__init__()
        out_features = NUM_HEADS * HEAD_DIM * 3 + NUM_HEADS
        self.linear = torch.nn.Linear(HIDDEN_SIZE, out_features, bias=False, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor, modality_mapping: torch.Tensor):
        dispatcher = ModalityDispatcherMinimal(modality_mapping, 3)
        chunks = dispatcher.dispatch(x)
        x = dispatcher.undispatch(*chunks)

        qkv = self.linear(x)
        q_size = NUM_HEADS * HEAD_DIM
        kv_size = NUM_HEADS * HEAD_DIM
        split_sizes = [q_size, kv_size, kv_size, NUM_HEADS]
        _, k, _, g = torch.split(qkv, split_sizes, dim=1)
        k = k.view(-1, NUM_HEADS, HEAD_DIM)
        # BUG: view with -1 on unbacked-symbol seq_len triggers guard error
        g = g.view(k.shape[0], NUM_HEADS, -1)
        return g.sum()


class FixedModel(torch.nn.Module):
    """Fixed version: unsqueeze(-1) avoids the problematic -1 inference.

    Mirrors ``BuggyModel`` step for step; only the final reshape of ``g``
    differs.
    """

    def __init__(self):
        super().__init__()
        out_features = NUM_HEADS * HEAD_DIM * 3 + NUM_HEADS
        self.linear = torch.nn.Linear(HIDDEN_SIZE, out_features, bias=False, dtype=torch.bfloat16)

    def forward(self, x: torch.Tensor, modality_mapping: torch.Tensor):
        dispatcher = ModalityDispatcherMinimal(modality_mapping, 3)
        chunks = dispatcher.dispatch(x)
        x = dispatcher.undispatch(*chunks)

        qkv = self.linear(x)
        q_size = NUM_HEADS * HEAD_DIM
        kv_size = NUM_HEADS * HEAD_DIM
        split_sizes = [q_size, kv_size, kv_size, NUM_HEADS]
        _, k, _, g = torch.split(qkv, split_sizes, dim=1)
        # Kept (although unused below) so the graph matches BuggyModel up to the last step.
        k = k.view(-1, NUM_HEADS, HEAD_DIM)
        # FIX: unsqueeze avoids -1 dimension inference on unbacked symbols
        g = g.unsqueeze(-1)
        return g.sum()


def _make_inputs(seq_len: int, modality_sizes: list[int], device: str):
    """Build a random input and a matching modality mapping.

    Args:
        seq_len: Total number of rows; must equal ``sum(modality_sizes)``.
        modality_sizes: Number of rows per modality; modality ``i`` gets id ``i``.
        device: Device on which both tensors are created.

    Returns:
        ``(x, modality_mapping)`` where ``x`` is a bfloat16 tensor of shape
        ``(seq_len, HIDDEN_SIZE)`` and ``modality_mapping`` is a ``long``
        tensor holding the modality ids in ascending blocks.
    """
    assert sum(modality_sizes) == seq_len
    id_blocks = [
        torch.full((size,), mod_id, dtype=torch.long, device=device)
        for mod_id, size in enumerate(modality_sizes)
    ]
    modality_mapping = torch.cat(id_blocks)
    x = torch.randn(seq_len, HIDDEN_SIZE, dtype=torch.bfloat16, device=device)
    return x, modality_mapping


def test_unbacked_symbol_guard_error():
    """The original view(k.shape[0], NUM_HEADS, -1) MUST raise GuardOnDataDependentSymNode."""
    # Import the exception explicitly: reaching it through the attribute chain
    # ``torch._inductor.exc`` only works if some other import has already
    # loaded that submodule, and the lookup happens before anything is compiled.
    from torch._inductor.exc import InductorError

    torch._dynamo.reset()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = BuggyModel().to(device)
    compiled = torch.compile(model, dynamic=True, fullgraph=False)

    x, mm = _make_inputs(150, [100, 30, 20], device)
    with pytest.raises(InductorError, match="GuardOnDataDependentSymNode"):
        compiled(x, mm)


def test_unbacked_symbol_guard_fixed():
    """The fixed unsqueeze(-1) version must succeed for multiple dynamic shapes."""
    torch._dynamo.reset()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compiled = torch.compile(FixedModel().to(device), dynamic=True, fullgraph=False)

    # Two different total lengths and group splits, run through the same compiled module.
    cases = (
        (150, [100, 30, 20]),
        (200, [120, 50, 30]),
    )
    for seq_len, modality_sizes in cases:
        x, mapping = _make_inputs(seq_len, modality_sizes, device)
        result = compiled(x, mapping)
        assert result.shape == ()


# When run as a script, execute this file's tests and exit with pytest's
# status code so that failures are visible to the calling shell or CI.
if __name__ == "__main__":
    raise SystemExit(pytest.main([__file__, "-v"]))