diff --git a/test/modules/op/conv3d.py b/test/modules/op/conv3d.py
index 0f75aca1..4ebf8f67 100644
--- a/test/modules/op/conv3d.py
+++ b/test/modules/op/conv3d.py
@@ -306,3 +306,23 @@ def get_example_inputs(self):
             torch.randn(2, IC, 8, 16, 16),
             torch.randn(OC, IC // groups, 3, 3, 3),
         ), {}
+
+
+class Conv3dWithPerfectFitKernel(torch.nn.Module):
+    """Conv3D with perfect fitting kernel"""
+
+    def __init__(self):
+        super().__init__()
+        self.conv3d = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=1024,
+            kernel_size=(2, 16, 16),
+            stride=(2, 16, 16),
+            padding=(0, 0, 0),
+        )
+
+    def forward(self, input):
+        return self.conv3d(input)
+
+    def get_example_inputs(self):
+        return (torch.randn(5, 3, 2, 16, 16),), {}
diff --git a/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py b/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
index f5fd038d..5af54636 100644
--- a/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
+++ b/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
@@ -264,3 +264,32 @@ def test_pass(self):
         self.run_value_test(ConvertConv3dToConv2d())
         self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
         self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 2)
+
+
+class Conv3dPerfectFitKernel(torch.nn.Module):
+    """Conv3D with perfect fitting kernel"""
+
+    def __init__(self):
+        super().__init__()
+        self.conv3d = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=1024,
+            kernel_size=(2, 16, 16),
+            stride=(2, 16, 16),
+            padding=(0, 0, 0),
+        )
+
+    def forward(self, input):
+        return self.conv3d(input)
+
+    def get_example_inputs(self):
+        return (torch.randn(5, 3, 2, 16, 16),), {}
+
+
+class ConvertConv3dPerfectFitKernelTest(SinglePassValueTest):
+    def test_pass(self):
+        self.setup(Conv3dPerfectFitKernel())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 1)
+        self.run_value_test(ConvertConv3dToConv2d())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
+        self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 1)
diff --git a/tico/passes/convert_conv3d_to_conv2d.py b/tico/passes/convert_conv3d_to_conv2d.py
index 5664c72f..9d04d314 100644
--- a/tico/passes/convert_conv3d_to_conv2d.py
+++ b/tico/passes/convert_conv3d_to_conv2d.py
@@ -403,6 +403,100 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo
 
         return modified
 
+    def optimized_convert(
+        self, exported_program: ExportedProgram, node: torch.fx.Node
+    ) -> bool:
+        """Fast-path conversion for a "perfect fit" conv3d.
+
+        When the kernel exactly covers the spatio-temporal extent of the
+        input (kT == T_in, kH == H_in, kW == W_in, groups == 1, zero
+        padding, unit dilation), each output element is a single dot
+        product, so the conv3d collapses to one conv2d over flattened
+        tensors instead of a per-slice decomposition.
+
+        Returns True when the node was rewritten, False when the shape
+        does not qualify (the caller then falls back to `convert`).
+        """
+        logger = logging.getLogger(__name__)
+        modified = False
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+
+        # Extract conv3d arguments
+        args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+        input = args.input
+        weight = args.weight
+        bias = args.bias
+        groups = args.groups
+        padding = args.padding
+        dilation = args.dilation
+
+        input_shape = extract_shape(input)
+        weight_shape = extract_shape(weight)
+
+        if not (len(input_shape) == 5):
+            raise NotYetSupportedError(
+                f"Only support 5D input tensor: node's input shape: {input_shape}"
+            )
+
+        if not (len(weight_shape) == 5):
+            raise NotYetSupportedError(
+                f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
+            )
+
+        N, C_in, T_in, H_in, W_in = input_shape
+        C_out, C_in_weight, kT, kH, kW = weight_shape
+
+        # The one-shot decomposition below is only valid when the kernel covers
+        # the input exactly: zero padding and unit dilation. Stride needs no
+        # check because an exact-fit kernel admits a single window position.
+        # A string padding ("same"/"valid") never satisfies `p == 0`, so the
+        # conv3d.padding overload safely falls through to the generic path.
+        if (
+            T_in == kT
+            and H_in == kH
+            and W_in == kW
+            and groups == 1
+            and all(p == 0 for p in padding)
+            and all(d == 1 for d in dilation)
+        ):
+            with graph.inserting_before(node):
+                input_reshape = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    args=(input, [1, 1, N, C_in * T_in * H_in * W_in]),
+                    origin=node,
+                )
+                weight_reshape = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    args=(weight, [C_out, 1, 1, C_in_weight * kT * kH * kW]),
+                    origin=node,
+                )
+                conv2d = create_node(
+                    graph,
+                    torch.ops.aten.conv2d.default,
+                    args=(
+                        input_reshape,
+                        weight_reshape,
+                        bias,
+                        [1, 1],  # stride
+                        [0, 0],  # padding
+                        [1, 1],  # dilation
+                        groups,
+                    ),
+                    origin=node,
+                )
+                conv2d_permute = create_node(
+                    graph,
+                    torch.ops.aten.permute.default,
+                    args=(conv2d, [2, 1, 0, 3]),
+                    origin=node,
+                )
+                conv2d_reshape = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    args=(conv2d_permute, [N, C_out, 1, 1, 1]),
+                    origin=node,
+                )
+
+                # Replace the original node
+                node.replace_all_uses_with(conv2d_reshape, propagate_meta=False)
+                logger.debug(
+                    f"{node.name} is replaced with optimized conv2d decomposition"
+                )
+                modified = True
+
+        return modified
+
     def call(self, exported_program: ExportedProgram) -> PassResult:
         target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
         graph_module = exported_program.graph_module
@@ -414,7 +520,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
         for node in graph.nodes:
             if not is_target_node(node, target_conv_op):
                 continue
-            modified |= self.convert(exported_program, node)
+            modified |= self.optimized_convert(exported_program, node) or self.convert(
+                exported_program, node
+            )
 
         graph.eliminate_dead_code()
         graph.lint()