-
Notifications
You must be signed in to change notification settings - Fork 25
[passes] Optimize ConvertConv3dToConv2d for special case #518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -403,6 +403,88 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo | |
|
|
||
| return modified | ||
|
|
||
| def optimized_convert( | ||
| self, exported_program: ExportedProgram, node: torch.fx.Node | ||
| ) -> bool: | ||
| logger = logging.getLogger(__name__) | ||
| modified = False | ||
| graph_module = exported_program.graph_module | ||
| graph = graph_module.graph | ||
|
|
||
| # Extract conv3d arguments | ||
| args = Conv3DArgs(*node.args, **node.kwargs) # type: ignore[arg-type] | ||
|
|
||
| input = args.input | ||
| weight = args.weight | ||
| bias = args.bias | ||
| groups = args.groups | ||
|
|
||
| input_shape = extract_shape(input) | ||
| weight_shape = extract_shape(weight) | ||
|
|
||
| if not (len(input_shape) == 5): | ||
| raise NotYetSupportedError( | ||
| f"Only support 5D input tensor: node's input shape: {input_shape}" | ||
| ) | ||
|
|
||
| if not (len(weight_shape) == 5): | ||
| raise NotYetSupportedError( | ||
| f"Only support 5D weight tensor: node's weight shape: {weight_shape}" | ||
| ) | ||
|
|
||
| N, C_in, T_in, H_in, W_in = input_shape | ||
| C_out, C_in_weight, kT, kH, kW = weight_shape | ||
|
|
||
| if T_in == kT and H_in == kH and W_in == kW and groups == 1: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this pass work when input shape(T,H,W) is same with kernel's? |
||
| with graph.inserting_before(node): | ||
| input_reshape = create_node( | ||
| graph, | ||
| torch.ops.aten.reshape.default, | ||
| args=(input, [1, 1, N, C_in * T_in * H_in * W_in]), | ||
| origin=node, | ||
| ) | ||
| weight_reshape = create_node( | ||
| graph, | ||
| torch.ops.aten.reshape.default, | ||
| args=(weight, [C_out, 1, 1, C_in_weight * kT * kH * kW]), | ||
| origin=node, | ||
| ) | ||
| conv2d = create_node( | ||
| graph, | ||
| torch.ops.aten.conv2d.default, | ||
| args=( | ||
| input_reshape, | ||
| weight_reshape, | ||
| bias, | ||
| [1, 1], # stride | ||
| [0, 0], # padding | ||
| [1, 1], # dilation | ||
| groups, | ||
| ), | ||
| origin=node, | ||
| ) | ||
|
Comment on lines
+452
to
+465
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @llFreetimell We have considered Linear vs Conv2d for TRIV optimization, so, in the end, is Conv2d better? |
||
| conv2d_permute = create_node( | ||
| graph, | ||
| torch.ops.aten.permute.default, | ||
| args=(conv2d, [2, 1, 0, 3]), | ||
| origin=node, | ||
| ) | ||
| conv2d_reshape = create_node( | ||
| graph, | ||
| torch.ops.aten.reshape.default, | ||
| args=(conv2d_permute, [N, C_out, 1, 1, 1]), | ||
| origin=node, | ||
| ) | ||
|
|
||
| # Replace the original node | ||
| node.replace_all_uses_with(conv2d_reshape, propagate_meta=False) | ||
| logger.debug( | ||
| f"{node.name} is replaced with optimized conv2d decomposition" | ||
| ) | ||
| modified = True | ||
|
|
||
| return modified | ||
|
|
||
| def call(self, exported_program: ExportedProgram) -> PassResult: | ||
| target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding] | ||
| graph_module = exported_program.graph_module | ||
|
|
@@ -414,7 +496,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult: | |
| for node in graph.nodes: | ||
| if not is_target_node(node, target_conv_op): | ||
| continue | ||
| modified |= self.convert(exported_program, node) | ||
| modified |= self.optimized_convert(exported_program, node) or self.convert( | ||
| exported_program, node | ||
| ) | ||
|
|
||
| graph.eliminate_dead_code() | ||
| graph.lint() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(optional) How about adding it to op test too? test/modules/op/conv3d.py