Skip to content

Commit 12208d1

Browse files
committed
fix typo and handle negative dims
1 parent b2c5d08 commit 12208d1

1 file changed

Lines changed: 4 additions & 3 deletions

File tree

tico/passes/convert_permute_to_reshape.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,18 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
6363
dims = args.dims
6464

6565
input_shape = extract_shape(input)
66+
normalized_dims = [(d if d >= 0 else d + len(input_shape)) for d in dims]
6667

6768
# When permute dims with non-1 values have same order,
6869
# we can replace permute to reshape
6970
#
7071
# For example, if
7172
# - input.shape = [1, x, 1, y]
7273
# - torch.permute(input, [1, 2, 3, 0])
73-
# then permute dims 2 and 0 keeps same order for 'x' and 'y'.
74+
# then permute dims 1 and 3 keeps same order for 'x' and 'y'.
7475
is_same_order = True
7576
last_dim = -1
76-
for dim in dims:
77+
for dim in normalized_dims:
7778
if input_shape[dim] == 1:
7879
continue
7980

@@ -88,7 +89,7 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
8889
reshape = create_node(
8990
graph,
9091
torch.ops.aten.reshape.default,
91-
args=(input, [input_shape[dim] for dim in dims]),
92+
args=(input, [input_shape[dim] for dim in normalized_dims]),
9293
origin=node,
9394
)
9495

0 commit comments

Comments
 (0)