Optimize aten::min/max.dim with TopK op #2780
danielhumanmod wants to merge 6 commits into microsoft:main from
Conversation
@microsoft-github-policy-service agree
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2780 +/- ##
==========================================
+ Coverage 70.46% 70.72% +0.25%
==========================================
Files 228 230 +2
Lines 27258 27443 +185
Branches 2761 2757 -4
==========================================
+ Hits 19208 19409 +201
+ Misses 7100 7096 -4
+ Partials 950 938 -12
☔ View full report in Codecov by Sentry.
Thanks so much for the review! That is a great point; I took some time to dig into the ONNX Runtime implementations to see how they handle this.
To the best of my knowledge, TopK might bring more instruction overhead but less IO. I would appreciate your thoughts here: which approach aligns more with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is not used at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do the fusion only when both outputs are used (if not, the second output will be removed by the dead code elimination pass).
Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
Hey @justinchuby, I've added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I'd appreciate your thoughts on it. Thanks!
Pull request overview
Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.
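For reference, a minimal sketch of the subgraph shape before and after this fusion, built with onnx.helper. The tensor names, axis value, and the use of the max variant are illustrative assumptions, not taken from the rule's source:

```python
# Minimal sketch of the fused pattern (illustrative names, axis=1); not the rule's implementation.
from onnx import helper

# Before: two separate reductions, each scanning X independently.
reduce_max = helper.make_node("ReduceMax", ["X", "axes"], ["max_values"], keepdims=1)
arg_max = helper.make_node("ArgMax", ["X"], ["max_indices"], axis=1, keepdims=1)

# After: a single TopK with k=1 produces both outputs in one pass.
# When the original ops used keepdims=0, a Squeeze on the reduced axis follows.
top_k = helper.make_node(
    "TopK", ["X", "k"], ["max_values", "max_indices"], axis=1, largest=1
)
squeeze = helper.make_node("Squeeze", ["max_values", "axes"], ["max_values_squeezed"])
```

The min variant would use ReduceMin/ArgMin on the pattern side and TopK with largest=0 on the replacement side.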
Changes:
- Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
- Adds extensive unit tests covering success and failure conditions across opset 13 and 18.
- Validates numerical equivalence and serialized-model correctness for rewritten graphs.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | Implements the Reduce+Arg → TopK fusion rules for both max and min cases. |
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | Adds unit tests for the new fusion rules, including opset and attribute/input variants. |
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet


class FuseReduceArgToTopKBase(RewriteRuleClassBase):

Suggested change:
- class FuseReduceArgToTopKBase(RewriteRuleClassBase):
+ class _FuseReduceArgToTopKBase(RewriteRuleClassBase):
# ONNX default: keepdims = 1 for both Reduce and Arg operations
reduce_keepdims = (
    reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1
)
arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1
@justinchuby is there a cleaner way to get the default? (just curious)
I think there should be a get_int method
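A minimal sketch of what that could look like, using the variables from the snippet above; whether the attribute container actually exposes a get_int helper with this exact name and signature is an assumption based on the comment, not verified here:

```python
# Hypothetical helper suggested above; falls back to the ONNX default of keepdims=1.
reduce_keepdims = reduce_node.attributes.get_int("keepdims", 1)  # assumed API
arg_keepdims = arg_node.attributes.get_int("keepdims", 1)        # assumed API
```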
# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity? (We have the version converter.) I guess it's just a matter of whether we expect the rule to be applied standalone or not.
That makes sense to remove; I see this rule should mostly be used in the pipeline. Thanks for the suggestion!
Only opset 18+ is fine
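To illustrate the opset difference being discussed, a small sketch using onnx.helper; the tensor names and axis value are illustrative:

```python
# Sketch of the axes difference: in opset 13-17 axes is an attribute,
# in opset 18+ it is a second input tensor (e.g. an int64 initializer named "axes").
from onnx import helper

# Opset 13-17 style: axes passed as an attribute.
reduce_13 = helper.make_node("ReduceMax", ["X"], ["Y"], axes=[1], keepdims=1)

# Opset 18+ style: axes passed as an input.
reduce_18 = helper.make_node("ReduceMax", ["X", "axes"], ["Y"], keepdims=1)
```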
# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
I wonder if symbolic shapes could work in this case? @justinchuby
Skipping when the shape is None means this does not support dynamic shapes at the moment. But symbolic inference should be able to handle the equality check.
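For context, the normalization that step 7 performs amounts to mapping negative axes onto their non-negative equivalents once the rank is statically known; a minimal sketch (not the rule's source):

```python
# Minimal sketch of negative-axis normalization when the rank is statically known.
def normalize_axis(axis: int, rank: int) -> int:
    # Negative axes count from the end, e.g. axis=-1 on a rank-3 tensor is axis 2.
    return axis + rank if axis < 0 else axis

assert normalize_axis(-1, 3) == 2
assert normalize_axis(1, 3) == 1
```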
You will have to enable it here:
I don't think we want to enable this by default. It is unclear if this is generally more performant. @danielhumanmod you may simply expose the rule in https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/rules/common/__init__.py
*_fuse_pad_into_conv.rules,
*_fuse_batchnorm.rules,
*_remove_optional_bias.rules,
*_fuse_reduce_arg_to_topk.rules,

Suggested change (remove the new rule from the default set):
- *_fuse_reduce_arg_to_topk.rules,
Do not expose as default unless it is measured to be performant
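Since the rule ends up exposed but not enabled by default, applying it would be an explicit opt-in step. A rough usage sketch, assuming the module path introduced in this PR and that the exported rule set supports apply_to_model(); the file paths are illustrative and the API details are assumptions to adjust against the actual code:

```python
# Rough opt-in usage sketch; apply_to_model() on the rule set is assumed, not verified here.
import onnx
from onnxscript import ir
from onnxscript.rewriter.rules.common import _fuse_reduce_arg_to_topk

model = ir.from_proto(onnx.load("model.onnx"))        # input path is illustrative
_fuse_reduce_arg_to_topk.rules.apply_to_model(model)  # explicit opt-in, not a default rule
onnx.save(ir.to_proto(model), "model_fused.onnx")
```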
Fix pytorch/pytorch#76344
Context
As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP, microsoft/onnxruntime#11348). In addition, since torch.min(dim=...) follows a similar pattern to max, I apply the same optimization to it.
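To make the pattern concrete, a small numpy sketch of the equivalence this relies on (illustrative, not the exported ONNX graph): torch.max(x, dim=d) returns both values and indices, and a single top-k pass with k=1 yields both, whereas ReduceMax and ArgMax each scan the input separately.

```python
# Illustrative numpy check of the values/indices equivalence (not the exported graph).
import numpy as np

x = np.random.rand(4, 5).astype(np.float32)
dim = 1

# What ReduceMax + ArgMax compute (two passes over x):
values_two_pass = x.max(axis=dim)
indices_two_pass = x.argmax(axis=dim)

# TopK-style computation: take k=1 along `dim`, then squeeze the reduced axis.
top1_indices = np.argsort(-x, axis=dim, kind="stable")[:, :1]
top1_values = np.take_along_axis(x, top1_indices, axis=dim)

assert np.array_equal(values_two_pass, top1_values.squeeze(dim))
assert np.array_equal(indices_two_pass, top1_indices.squeeze(dim))
```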
torch.min(dim=...)has the similar pattern with max, I also apply this optimization to it.Verification
Successfully passed existing OpInfo consistency tests: