30 changes: 25 additions & 5 deletions src/qonnx/custom_op/general/multithreshold.py
@@ -100,7 +100,7 @@ def get_nodeattr_types(self):
             "out_dtype": ("s", True, ""),
             "out_scale": ("f", False, 1.0),
             "out_bias": ("f", False, 0.0),
-            "data_layout": ("s", False, "NCHW"),
+            "data_layout": ("s", False, ""),
         }

     def make_shape_compatible_op(self, model):

@@ -130,12 +130,32 @@ def execute_node(self, context, graph):
         # retrieve attributes if output scaling is used
         out_scale = self.get_nodeattr("out_scale")
         out_bias = self.get_nodeattr("out_bias")
-        # transpose input if NHWC data layout is chosen
+
+        # Consider the data layout for transposing the input into the format
+        # accepted by the multithreshold function above, i.e, the channel
+        # dimension is along the axis with index 1.
         data_layout = self.get_nodeattr("data_layout")
-        channels_last = True if data_layout[-1] == "C" else False
-        # calculate output
+        # If there is no layout annotation, guess based on rank of the
+        # tensor
+        if not data_layout and len(v.shape) < 5:
+            # Maps tensor rank to layout annotation
+            rank_to_layout = {0: None, 1: None, 2: "NC", 3: "NWC", 4: "NCHW"}
+            # Lookup the layout required by this input shape
+            data_layout = rank_to_layout[len(v.shape)]
+        # Lookup the index of the channel dimension in the data layout
+        # Note: Assumes there is at most one "C" which denotes the channel
+        # dimension
+        if data_layout is not None:
+            cdim = data_layout.index("C") if "C" in data_layout else 1
+        else:
+            cdim = 1
+        # Rearrange the input to the expected (N, C, ...) layout
+        orig_shape = v.shape
-        output = multithreshold(v, thresholds, out_scale, out_bias, channels_last)
+        v = v.swapaxes(cdim, 1)
Collaborator: Does this always work? Is it an improvement over using the channels_last parameter?

Collaborator (Author): For 1d, 2d, 3d and 4d input this is fine and matches the previous behavior. Looking at it now, I just realized that in case there is no fallback (>= 5d) or for 0d, it will try indexing into a None. Though, as far as I am aware, this would not affect any model we are supporting at the moment. Still, this should be addressed. I will add a quick workaround, though the entire layout mechanism is long overdue for a proper rework...
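The fallback behavior discussed in this reply can be illustrated with a small standalone sketch. Note that `guess_channel_dim` is a hypothetical helper written for illustration only, not part of qonnx; it mirrors the rank-based guess from the diff above, using `dict.get` so that 0d and >= 5d inputs fall back to channel axis 1 instead of failing:

```python
import numpy as np

# Hypothetical helper mirroring the rank-based layout fallback in the diff
# above (illustration only, not the actual qonnx implementation).
def guess_channel_dim(v, data_layout=""):
    # Maps tensor rank to a layout annotation; 0d and 1d have no channels
    rank_to_layout = {0: None, 1: None, 2: "NC", 3: "NWC", 4: "NCHW"}
    if not data_layout:
        # dict.get returns None for ranks not listed above (>= 5d)
        data_layout = rank_to_layout.get(v.ndim)
    if data_layout is not None:
        # At most one "C" is assumed; default to axis 1 if it is absent
        return data_layout.index("C") if "C" in data_layout else 1
    # No layout information at all: assume channels along axis 1
    return 1

print(guess_channel_dim(np.zeros((2, 3))))            # rank 2 -> "NC" -> 1
print(guess_channel_dim(np.zeros((4, 8, 3)), "NWC"))  # explicit layout -> 2
print(guess_channel_dim(np.zeros(())))                # rank 0 -> fallback 1
```

With this guard, both the 0d case and the >= 5d case resolve to the default channel axis 1 rather than attempting `None[...]`.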

Collaborator: Is it possible to just use the channels_last flag?

Collaborator (Author): Not really: channels_last can only differentiate between channels at axis 1 or -1, but we had instances where channels could end up at any index, at least temporarily while transforming the model. Usually we always end up with channels last (anything else cannot really be implemented in FINN at the moment), but to get there we might transition through a few mixed-up layouts, and we still want to be able to execute/simulate these.
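The swapaxes round-trip described in this reply can be sketched as follows. Here `per_channel_bias` is a made-up stand-in for the multithreshold call, which likewise expects channels along axis 1; the point is that any channel axis index works, not just 1 or -1:

```python
import numpy as np

# Made-up per-channel op standing in for multithreshold: it expects the
# channel dimension along axis 1, like the real function.
def per_channel_bias(x):
    bias = np.arange(x.shape[1]).reshape(1, -1, *([1] * (x.ndim - 2)))
    return x + bias

v = np.zeros((2, 4, 4, 3))  # NHWC input: channels at axis 3
cdim = 3
# Move channels to axis 1, apply the op, then restore the original layout
out = per_channel_bias(v.swapaxes(cdim, 1)).swapaxes(cdim, 1)
assert out.shape == v.shape
print(out[0, 0, 0, :])  # each channel got its own bias: [0. 1. 2.]
```

Since `swapaxes(cdim, 1)` is its own inverse, applying it once before and once after the op restores the original layout exactly, whatever `cdim` happens to be.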

+        # Now we can use the multithreshold function to calculate output
+        output = multithreshold(v, thresholds, out_scale, out_bias)
+        # Rearrange the output back to the original layout
+        output = output.swapaxes(cdim, 1)
+        assert output.shape == orig_shape, "Shape changed during thresholding!"
         context[node.output[0]] = output
