Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/modules/op/conv3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,23 @@ def get_example_inputs(self):
torch.randn(2, IC, 8, 16, 16),
torch.randn(OC, IC // groups, 3, 3, 3),
), {}


class Conv3dWithPerfectFitKernel(torch.nn.Module):
    """Conv3D with perfect fitting kernel.

    The kernel size equals the example input's (T, H, W), with zero padding,
    so each sample produces exactly one output position: output shape is
    (N, 1024, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        # Kernel/stride exactly cover the (2, 16, 16) input volume.
        conv_cfg = dict(
            in_channels=3,
            out_channels=1024,
            kernel_size=(2, 16, 16),
            stride=(2, 16, 16),
            padding=(0, 0, 0),
        )
        self.conv3d = torch.nn.Conv3d(**conv_cfg)

    def forward(self, input):
        return self.conv3d(input)

    def get_example_inputs(self):
        # Batch of 5; spatial dims (2, 16, 16) match the kernel exactly.
        example = torch.randn(5, 3, 2, 16, 16)
        return (example,), {}
29 changes: 29 additions & 0 deletions test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,32 @@ def test_pass(self):
self.run_value_test(ConvertConv3dToConv2d())
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 2)


class Conv3dPerfectFitKernel(torch.nn.Module):
    """Conv3D with perfect fitting kernel.

    Kernel size matches the example input's spatial extent (2, 16, 16) with
    no padding, so the convolution has a single valid window per sample and
    yields a (N, 1024, 1, 1, 1) output.
    """

    def __init__(self):
        super().__init__()
        self.conv3d = torch.nn.Conv3d(
            3,  # in_channels
            1024,  # out_channels
            (2, 16, 16),  # kernel_size == input (T, H, W)
            stride=(2, 16, 16),
            padding=(0, 0, 0),
        )

    def forward(self, input):
        return self.conv3d(input)

    def get_example_inputs(self):
        return (torch.randn(5, 3, 2, 16, 16),), {}


class ConvertConv3dPerfectFitKernelTest(SinglePassValueTest):
    """Verify ConvertConv3dToConv2d on a conv3d whose kernel exactly fits the input.

    The perfect-fit case should be lowered through the optimized path: the
    single conv3d is removed and replaced by (at least) one conv2d.
    """

    def test_pass(self):
        self.setup(Conv3dPerfectFitKernel())
        # Before the pass: exactly one aten.conv3d in the exported graph.
        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 1)
        # run_value_test applies the pass and checks numerical equivalence.
        self.run_value_test(ConvertConv3dToConv2d())
        # After the pass: conv3d fully eliminated, conv2d present.
        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
        self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 1)
86 changes: 85 additions & 1 deletion tico/passes/convert_conv3d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,88 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo

return modified

def optimized_convert(
    self, exported_program: ExportedProgram, node: torch.fx.Node
) -> bool:
    """Lower a "perfect fit" conv3d node to a single conv2d.

    When the kernel exactly covers the input volume (T_in == kT,
    H_in == kH, W_in == kW), with zero padding, unit dilation, and
    groups == 1, each sample yields exactly one output position, so the
    conv3d degenerates to a per-sample dot product.  That is expressed
    here as a 1xK conv2d over a flattened input, followed by a
    permute/reshape back to the (N, C_out, 1, 1, 1) conv3d output shape.

    Args:
        exported_program: program whose graph is rewritten in place.
        node: an aten.conv3d node (caller guarantees the target op).

    Returns:
        True iff the graph was modified (i.e. the node matched the
        perfect-fit pattern and was replaced).

    Raises:
        NotYetSupportedError: if input or weight is not 5-D.
    """
    logger = logging.getLogger(__name__)
    modified = False
    graph_module = exported_program.graph_module
    graph = graph_module.graph

    # Extract conv3d arguments
    args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

    input = args.input
    weight = args.weight
    bias = args.bias
    groups = args.groups

    input_shape = extract_shape(input)
    weight_shape = extract_shape(weight)

    if not (len(input_shape) == 5):
        raise NotYetSupportedError(
            f"Only support 5D input tensor: node's input shape: {input_shape}"
        )

    if not (len(weight_shape) == 5):
        raise NotYetSupportedError(
            f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
        )

    N, C_in, T_in, H_in, W_in = input_shape
    C_out, C_in_weight, kT, kH, kW = weight_shape

    # The rewrite below is only valid when the conv has exactly one output
    # position per sample.  That requires the kernel to span the whole
    # input volume AND zero padding AND unit dilation; stride is then
    # irrelevant because only one window fits.  The original check omitted
    # padding/dilation, which would silently produce wrong results for
    # e.g. padding=(1, 1, 1) with a matching kernel size.
    # NOTE(review): Conv3DArgs is assumed to expose `padding`/`dilation`
    # per the aten.conv3d schema — getattr defaults hedge against the
    # conv3d.padding string variant and missing fields; confirm.
    padding = getattr(args, "padding", [0, 0, 0])
    dilation = getattr(args, "dilation", [1, 1, 1])
    if isinstance(padding, str):
        # aten.conv3d.padding variant: "valid" means no padding.
        zero_padding = padding == "valid"
    else:
        zero_padding = all(p == 0 for p in padding)
    unit_dilation = all(d == 1 for d in dilation)

    perfect_fit = (
        (T_in, H_in, W_in) == (kT, kH, kW)
        and groups == 1
        and zero_padding
        and unit_dilation
    )

    if perfect_fit:
        with graph.inserting_before(node):
            # Flatten each sample to one row: (1, 1, N, C_in*T*H*W).
            input_reshape = create_node(
                graph,
                torch.ops.aten.reshape.default,
                args=(input, [1, 1, N, C_in * T_in * H_in * W_in]),
                origin=node,
            )
            # Flatten each filter to one 1xK kernel: (C_out, 1, 1, K).
            weight_reshape = create_node(
                graph,
                torch.ops.aten.reshape.default,
                args=(weight, [C_out, 1, 1, C_in_weight * kT * kH * kW]),
                origin=node,
            )
            # conv2d over (1, 1, N, K) with a (1, K) kernel computes the
            # per-sample dot products; output shape is (1, C_out, N, 1).
            conv2d = create_node(
                graph,
                torch.ops.aten.conv2d.default,
                args=(
                    input_reshape,
                    weight_reshape,
                    bias,
                    [1, 1],  # stride
                    [0, 0],  # padding
                    [1, 1],  # dilation
                    groups,
                ),
                origin=node,
            )
            # (1, C_out, N, 1) -> (N, C_out, 1, 1)
            conv2d_permute = create_node(
                graph,
                torch.ops.aten.permute.default,
                args=(conv2d, [2, 1, 0, 3]),
                origin=node,
            )
            # Restore the conv3d output shape (N, C_out, 1, 1, 1).
            conv2d_reshape = create_node(
                graph,
                torch.ops.aten.reshape.default,
                args=(conv2d_permute, [N, C_out, 1, 1, 1]),
                origin=node,
            )

        # Replace the original node
        node.replace_all_uses_with(conv2d_reshape, propagate_meta=False)
        logger.debug(
            f"{node.name} is replaced with optimized conv2d decomposition"
        )
        modified = True

    return modified

def call(self, exported_program: ExportedProgram) -> PassResult:
target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
graph_module = exported_program.graph_module
Expand All @@ -414,7 +496,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
for node in graph.nodes:
if not is_target_node(node, target_conv_op):
continue
modified |= self.convert(exported_program, node)
modified |= self.optimized_convert(exported_program, node) or self.convert(
exported_program, node
)

graph.eliminate_dead_code()
graph.lint()
Expand Down