diff --git a/src/qonnx/custom_op/general/multithreshold.py b/src/qonnx/custom_op/general/multithreshold.py index 3cdf80df..07e4ddad 100644 --- a/src/qonnx/custom_op/general/multithreshold.py +++ b/src/qonnx/custom_op/general/multithreshold.py @@ -100,7 +100,7 @@ def get_nodeattr_types(self): "out_dtype": ("s", True, ""), "out_scale": ("f", False, 1.0), "out_bias": ("f", False, 0.0), - "data_layout": ("s", False, "NCHW"), + "data_layout": ("s", False, ""), } def make_shape_compatible_op(self, model): @@ -130,12 +130,32 @@ def execute_node(self, context, graph): # retrieve attributes if output scaling is used out_scale = self.get_nodeattr("out_scale") out_bias = self.get_nodeattr("out_bias") - # transpose input if NHWC data layout is chosen + + # Consider the data layout for transposing the input into the format + # accepted by the multithreshold function above, i.e., the channel + # dimension is along the axis with index 1. data_layout = self.get_nodeattr("data_layout") - channels_last = True if data_layout[-1] == "C" else False - # calculate output + # If there is no layout annotation, guess based on rank of the + # tensor + if not data_layout and len(v.shape) < 5: + # Maps tensor rank to layout annotation + rank_to_layout = {0: None, 1: None, 2: "NC", 3: "NWC", 4: "NCHW"} + # Look up the layout required by this input shape + data_layout = rank_to_layout[len(v.shape)] + # Look up the index of the channel dimension in the data layout + # Note: Assumes there is at most one "C" which denotes the channel + # dimension + if data_layout is not None: + cdim = data_layout.index("C") if "C" in data_layout else 1 + else: + cdim = 1 + # Rearrange the input to the expected (N, C, ...) 
layout orig_shape = v.shape - output = multithreshold(v, thresholds, out_scale, out_bias, channels_last) + v = v.swapaxes(cdim, 1) + # Now we can use the multithreshold function to calculate output + output = multithreshold(v, thresholds, out_scale, out_bias) + # Rearrange the output back to the original layout + output = output.swapaxes(cdim, 1) assert output.shape == orig_shape, "Shape changed during thresholding!" context[node.output[0]] = output