diff --git a/magi_compiler/_api.py b/magi_compiler/_api.py
index 442b67c..c96386a 100644
--- a/magi_compiler/_api.py
+++ b/magi_compiler/_api.py
@@ -314,6 +314,18 @@ def _mark_dynamic_shapes(state: MagiCompileState, bound):
 
             final_dims = [arg.ndim + d if d < 0 else d for d in dims]
 
+            for d in final_dims:
+                dim_size = arg.shape[d]
+                if dim_size <= 1:
+                    raise ValueError(
+                        f"Argument '{k}' has size {dim_size} on dynamic dim {d}. "
+                        "PyTorch Dynamo specializes on 0/1 sizes (see "
+                        "https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/compile/"
+                        "dynamic_shapes_zero_one_specialization.html), "
+                        "so this dimension will NOT be treated as dynamic. "
+                        "Use an initial input with size >= 2 on dynamic dims to enable shape generalization."
+                    )
+
             torch._dynamo.mark_dynamic(arg, final_dims)
             dynamic_records[id(arg)] = set(final_dims)
 
diff --git a/magi_compiler/magi_backend/piecewise_backend.py b/magi_compiler/magi_backend/piecewise_backend.py
index 95f8e40..752db95 100644
--- a/magi_compiler/magi_backend/piecewise_backend.py
+++ b/magi_compiler/magi_backend/piecewise_backend.py
@@ -22,6 +22,7 @@ import torch.fx as fx
 
 
 from magi_compiler.config import CompileConfig
+from magi_compiler.utils import magi_logger
 
 if TYPE_CHECKING:
     from .magi_backend import CompilerManager
@@ -85,7 +86,9 @@ def __call__(self, *args) -> Any:
             self.check_for_ending_compilation()
             return self.compiled_graph_for_general_shape(*args)
 
-        assert len(self.sym_shape_indices) != 0, "No symbolic shape indices found"
+        if len(self.sym_shape_indices) == 0:
+            magi_logger.info("No symbolic shape indices found, falling back to general shape compiled graph")
+            return self.compiled_graph_for_general_shape(*args)
         runtime_shape = args[self.sym_shape_indices[0]]
         if runtime_shape not in self.concrete_size_entries:
             # we don't need to do anything for this shape
diff --git a/tests/model_tests/test_mlp_infer.py b/tests/model_tests/test_mlp_infer.py
index c591a73..e30211e 100644
--- a/tests/model_tests/test_mlp_infer.py
+++ b/tests/model_tests/test_mlp_infer.py
@@ -45,6 +45,16 @@ def test_mlp_basic_inference(device, mlp_config, mlp_model):
     assert output.dtype == torch.float32, f"Output data type should be torch.float32, but got {output.dtype}"
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
+def test_mlp_batch1_first_call_raises(device, mlp_config, mlp_model):
+    """First call with batch=1 should raise ValueError due to zero-one specialization"""
+    input_tensor = torch.randn(1, mlp_config.hidden_size, device=device, dtype=torch.bfloat16)
+
+    with torch.no_grad():
+        with pytest.raises(ValueError, match="specializes on 0/1 sizes"):
+            mlp_model(input_tensor)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
 def test_mlp_different_input_shapes(device, mlp_config, mlp_model):
     """Test different input shapes"""