
[Fix]: handle zero-one specialization in PiecewiseBackend #24

Merged
jiahy0825 merged 1 commit into SandAI-org:main from cennn:fix/issue-23-zero-one-specialization
Apr 23, 2026
Collaborator

@cennn cennn commented Apr 22, 2026

When the first call uses batch size 0 or 1, PyTorch Dynamo specializes that dimension as a static constant, leaving sym_shape_indices empty and causing an AssertionError on subsequent calls with different shapes.

  • Raise ValueError early in _mark_dynamic_shapes when dim_size <= 1 so users get a clear, actionable error instead of a cryptic assert
  • Replace the hard assert in PiecewiseBackend.__call__ with a fallback to compiled_graph_for_general_shape + info log as a safety net
  • Add test_mlp_batch1_first_call_raises regression test

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

Problem

From: ISSUE-23
When the first call to a @magi_compile-decorated module uses batch size 0 or 1, PyTorch Dynamo's zero-one specialization treats that dimension as a static constant rather than a symbolic value. This causes sym_shape_indices to be empty in PiecewiseBackend, resulting in AssertionError: No symbolic shape indices found on subsequent calls with different shapes.
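
The failure mode can be illustrated with a toy model in plain Python. This is a sketch under assumptions, not Dynamo's actual implementation: `trace_shape` and `symbolic_indices` are hypothetical helpers that mimic only one rule, namely that a dimension whose first-seen size is 0 or 1 is baked in as a constant even when it was meant to be dynamic.

```python
# Toy model of zero-one specialization (hypothetical helpers, not PyTorch APIs).

def trace_shape(shape, dynamic_dims):
    """Return a traced shape: 's<i>' for symbolic dims, an int for static ones."""
    traced = []
    for i, size in enumerate(shape):
        # A dim requested as dynamic still becomes a constant when its
        # first-seen size is 0 or 1; this mimics zero-one specialization.
        if i in dynamic_dims and size > 1:
            traced.append(f"s{i}")
        else:
            traced.append(size)
    return traced


def symbolic_indices(traced):
    """Indices of the dimensions that stayed symbolic after tracing."""
    return [i for i, dim in enumerate(traced) if isinstance(dim, str)]


first_call_batch8 = trace_shape((8, 16), dynamic_dims={0})
first_call_batch1 = trace_shape((1, 16), dynamic_dims={0})

print(symbolic_indices(first_call_batch8))  # [0]
print(symbolic_indices(first_call_batch1))  # []
```

In this model, a first call with batch size 1 leaves no symbolic index at all, which corresponds to the empty sym_shape_indices described above.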

Changes

  • Raise early with a clear message: Upgrade the existing magi_logger.warning in _mark_dynamic_shapes to a ValueError, so users get an actionable error at the point of compilation rather than a cryptic assert deep in the piecewise backend.

  • Add fallback safety net: Replace the hard assert in PiecewiseBackend.__call__ with a graceful fallback to compiled_graph_for_general_shape plus an info log, covering any other scenario where sym_shape_indices might be empty.

  • Add regression test: test_mlp_batch1_first_call_raises verifies that calling with batch=1 as the first invocation raises ValueError with a clear message.
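
The first two changes can be sketched together in plain Python. This is an illustration under assumptions, not the merged code: `check_dynamic_dim`, `ToyPiecewiseBackend`, its constructor arguments, and the `specialized` lookup are hypothetical stand-ins. Only the names sym_shape_indices and compiled_graph_for_general_shape come from the description above.

```python
import logging

logger = logging.getLogger("toy_backend")


def check_dynamic_dim(dim_size: int) -> None:
    """Fail fast when a dimension meant to be dynamic would be specialized."""
    if dim_size <= 1:
        raise ValueError(
            f"Cannot mark a dimension of size {dim_size} as dynamic: sizes 0 "
            "and 1 are specialized to constants. Make the first call with "
            "size >= 2."
        )


class ToyPiecewiseBackend:
    """Hypothetical stand-in that shows only the fallback logic."""

    def __init__(self, general_graph, sym_shape_indices, specialized=None):
        self.compiled_graph_for_general_shape = general_graph
        self.sym_shape_indices = sym_shape_indices
        self.specialized = specialized or {}  # runtime size -> graph

    def __call__(self, *args):
        if not self.sym_shape_indices:
            # Previously a hard assert; now a logged fallback.
            logger.info("No symbolic shape indices; using the general-shape graph.")
            return self.compiled_graph_for_general_shape(*args)
        runtime_size = args[self.sym_shape_indices[0]]
        graph = self.specialized.get(
            runtime_size, self.compiled_graph_for_general_shape
        )
        return graph(*args)
```

With `ToyPiecewiseBackend(lambda n: ("general", n), sym_shape_indices=[])`, calling the backend with `4` returns `("general", 4)` instead of raising, while `check_dynamic_dim(1)` raises ValueError before any graph is built.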

Test

All 5 tests in tests/model_tests/test_mlp_infer.py pass, including the new regression test.
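
The shape of such a regression test can be sketched with the standard library's unittest. This is not the test from the PR: `first_call` is a hypothetical stand-in for the first invocation of a compiled module, and the error message is invented for illustration.

```python
import unittest


def first_call(batch_size: int) -> str:
    """Hypothetical stand-in for the first call to a compiled module."""
    if batch_size <= 1:
        raise ValueError(
            f"batch size {batch_size} would be specialized as a static constant"
        )
    return "compiled"


class FirstCallBatchOneTest(unittest.TestCase):
    def test_batch1_first_call_raises(self):
        # The error must be a ValueError with a recognizable message.
        with self.assertRaisesRegex(ValueError, "specialized"):
            first_call(1)

    def test_batch2_first_call_compiles(self):
        self.assertEqual(first_call(2), "compiled")
```

Saved as a module, the sketch can be run with `python -m unittest`.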

Collaborator

@jiahy0825 jiahy0825 left a comment

LGTM

@jiahy0825 jiahy0825 merged commit f1202a8 into SandAI-org:main Apr 23, 2026
3 of 6 checks passed