Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions magi_compiler/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,18 @@ def _mark_dynamic_shapes(state: MagiCompileState, bound):

final_dims = [arg.ndim + d if d < 0 else d for d in dims]

for d in final_dims:
dim_size = arg.shape[d]
if dim_size <= 1:
raise ValueError(
f"Argument '{k}' has size {dim_size} on dynamic dim {d}. "
"PyTorch Dynamo specializes on 0/1 sizes (see "
"https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/compile/"
"dynamic_shapes_zero_one_specialization.html), "
"so this dimension will NOT be treated as dynamic. "
"Use an initial input with size >= 2 on dynamic dims to enable shape generalization."
)

torch._dynamo.mark_dynamic(arg, final_dims)

dynamic_records[id(arg)] = set(final_dims)
Expand Down
5 changes: 4 additions & 1 deletion magi_compiler/magi_backend/piecewise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch.fx as fx

from magi_compiler.config import CompileConfig
from magi_compiler.utils import magi_logger

if TYPE_CHECKING:
from .magi_backend import CompilerManager
Expand Down Expand Up @@ -85,7 +86,9 @@ def __call__(self, *args) -> Any:
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)

assert len(self.sym_shape_indices) != 0, "No symbolic shape indices found"
if len(self.sym_shape_indices) == 0:
magi_logger.info("No symbolic shape indices found, falling back to general shape compiled graph")
return self.compiled_graph_for_general_shape(*args)
runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
Expand Down
10 changes: 10 additions & 0 deletions tests/model_tests/test_mlp_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ def test_mlp_basic_inference(device, mlp_config, mlp_model):
assert output.dtype == torch.float32, f"Output data type should be torch.float32, but got {output.dtype}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_batch1_first_call_raises(device, mlp_config, mlp_model):
"""First call with batch=1 should raise ValueError due to zero-one specialization"""
input_tensor = torch.randn(1, mlp_config.hidden_size, device=device, dtype=torch.bfloat16)

with torch.no_grad():
with pytest.raises(ValueError, match="specializes on 0/1 sizes"):
mlp_model(input_tensor)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA support")
def test_mlp_different_input_shapes(device, mlp_config, mlp_model):
"""Test different input shapes"""
Expand Down