Onnxscript implementation of BatchNormToAffine #185
alanbacellar wants to merge 2 commits into fastmachinelearning:main
Conversation
… match in replace pattern
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 Xilinx, Inc.
+# Copyright (c) 2025 Advanced Micro Devices, Inc.
Please do not remove older copyrights; you can just add the new one for 2025 AMD as the second line.
# and/or other materials provided with the distribution.
#
-# * Neither the name of Xilinx nor the names of its
+# * Neither the name of AMD nor the names of its
actually this one should be QONNX, not Xilinx or AMD
# Get epsilon from matched pattern
batch_norm = kwargs['match'].nodes[0]
epsilon_attr = batch_norm.attributes.get('epsilon', None)
epsilon_value = 1e-5 if epsilon_attr is None else epsilon_attr.value
It would be better to make the 1e-5 default configurable, e.g. with a top-level variable in the module (or, if possible, passed in as an optional argument to the Transformation with this default value).
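One way to address this comment could look like the sketch below: a module-level default plus an optional constructor argument. The names (`DEFAULT_BN_EPSILON`, `resolve_epsilon`, the `BatchNormToAffine` wrapper) are illustrative, not the actual qonnx class layout; the attribute objects are assumed to expose a `.value` field as in the snippet above.

```python
# Sketch only: expose the epsilon fallback instead of hard-coding 1e-5 inline.
DEFAULT_BN_EPSILON = 1e-5  # module-level default, easy to override in one place


def resolve_epsilon(attributes, default=DEFAULT_BN_EPSILON):
    """Return the BatchNorm node's epsilon attribute value, or the fallback."""
    epsilon_attr = attributes.get("epsilon", None)
    return default if epsilon_attr is None else epsilon_attr.value


class BatchNormToAffine:  # hypothetical wrapper, mirrors the suggested API
    def __init__(self, default_epsilon=DEFAULT_BN_EPSILON):
        # optional arg on the Transformation, as suggested in the review
        self.default_epsilon = default_epsilon
```

With this, callers that need a non-standard fallback can pass `BatchNormToAffine(default_epsilon=1e-3)` while the common case stays unchanged.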
input_shape = x.shape
assert input_shape is not None and len(input_shape) >= 2
n_spatial_dims = len(input_shape) - 2
axes = [0] + [i + 2 for i in range(n_spatial_dims)]
A = op.Unsqueeze(A, axes=axes)
B = op.Unsqueeze(B, axes=axes)
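To see why these axes give correct broadcasting, here is a NumPy sketch of the same math (a per-channel tensor unsqueezed at `[0] + [i + 2 ...]` becomes `(1, C, 1, ..., 1)`, matching an NC... layout). This mimics the transformation's arithmetic only; it does not use qonnx or onnxscript.

```python
import numpy as np

# Inference-mode BatchNorm is a per-channel affine transform:
#   y = scale * (x - mean) / sqrt(var + eps) + bias = A * x + B
# with A = scale / sqrt(var + eps) and B = bias - mean * A.
eps = 1e-5
C = 4
rng = np.random.default_rng(0)
x = rng.random((2, C, 8, 8)).astype(np.float32)  # NCHW input, 2 spatial dims
scale, bias = rng.random(C), rng.random(C)
mean, var = rng.random(C), rng.random(C)

A = scale / np.sqrt(var + eps)
B = bias - mean * A

# Same axes as the transformation: unsqueeze at batch dim and spatial dims.
n_spatial_dims = x.ndim - 2
axes = tuple([0] + [i + 2 for i in range(n_spatial_dims)])
A_b = np.expand_dims(A, axes)  # shape (1, C, 1, 1), broadcasts over N/H/W
B_b = np.expand_dims(B, axes)

y_affine = A_b * x + B_b

# Reference BatchNorm computation with explicit per-channel reshapes.
ch = (1, C) + (1,) * n_spatial_dims
y_bn = scale.reshape(ch) * (x - mean.reshape(ch)) / np.sqrt(var.reshape(ch) + eps) + bias.reshape(ch)
assert np.allclose(y_affine, y_bn)
```

The same `axes` formula works for any number of spatial dims (1D, 2D, 3D), since only the channel axis (1) is left unexpanded.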
- Is this always safe? Does it make sense to add more test cases with different dimensionalities? (It's now much easier to create dummy models with onnxscript.)
- The original transformation removes surrounding Squeeze/Unsqueeze nodes if they were present around the BatchNorm; does this new version still have the same effect? (Also another good thing to test/check.)
assert (output_original == output_lowered).all()

op_types = list(map(lambda x: x.op_type, model_lowered.graph.node))
assert "BatchNormalization" not in op_types
check for no Unsqueeze/Squeeze left here as well?
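The extra check suggested here could be factored into a small helper; a sketch (the function name is hypothetical, and it only assumes the onnx `ModelProto`-style `.graph.node[*].op_type` access already used in the test above):

```python
from types import SimpleNamespace


def assert_fully_lowered(model):
    """Check that lowering removed the BatchNormalization node and left no
    stray Squeeze/Unsqueeze wrappers in the graph."""
    op_types = [node.op_type for node in model.graph.node]
    for op in ("BatchNormalization", "Squeeze", "Unsqueeze"):
        assert op not in op_types, f"unexpected {op} node after lowering"


# Works with any object exposing .graph.node[*].op_type, e.g. an onnx
# ModelProto; a stand-in object is used here for a self-contained demo.
lowered = SimpleNamespace(graph=SimpleNamespace(
    node=[SimpleNamespace(op_type="Mul"), SimpleNamespace(op_type="Add")]))
assert_fully_lowered(lowered)
```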
batch_norm = kwargs['match'].nodes[0]
epsilon_attr = batch_norm.attributes.get('epsilon', None)
can this be handled by something like https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/cast_constant_of_shape.py#L13-L22 to get the attribute instead, removing the need for the special util?