[quantization] Quantization-Aware Shape Operation Folding #499
Description
From #491
Constant Folding Logic
To note, below is the ConstProp pass's core logic.
```python
# propagate constant because all of its args are constant tensors.
with torch.no_grad():
    prop_constant_tensor = node.target(*args_data, **kwargs_data)
const_node_to_tensor[node] = prop_constant_tensor
```

Quantized Operation's Constant Folding
- `reshape.default(quantized_a, shape)`: folds correctly.
- `add.Tensor(quantized_a, quantized_b)` # ERR!: any operation that performs arithmetic raises an error, as shown below.
```python
import torch

input_a = torch.quantize_per_tensor(torch.tensor([1.0, 2.0], dtype=torch.float32), 0.1, 10, torch.qint8)
input_b = torch.quantize_per_tensor(torch.tensor([3.0, 4.0], dtype=torch.float32), 0.1, 10, torch.qint8)

target_op_add = torch.ops.aten.add.Tensor
target_op_reshape = torch.ops.aten.reshape.default

try:
    with torch.no_grad():
        result = target_op_add(input_a, input_b)  # ERROR!
        # Error: Could not run 'aten::add.out' with arguments from the 'QuantizedCPU'
        # backend. This could be because the operator doesn't exist for this backend,
        # or was omitted during the selective/custom build process (if using custom
        # build). If you are a Facebook employee using PyTorch on mobile, please
        # visit https://fburl.com/ptmfixes for possible resolutions.
        result = target_op_reshape(input_a, torch.tensor([1, 1, 2], dtype=torch.int32))  # PASSES
        print(f"Result Scale: {result.q_scale()}, ZP: {result.q_zero_point()}")
except Exception as e:
    print(f"Error: {e}")
```

Therefore, the problematic case will be caught by this exception.
I believe we could implement constant propagation with an allow-list of shape-relevant operators only.
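A minimal sketch of that idea (the names `ALLOWED_SHAPE_OPS` and `can_fold` are hypothetical, not existing TICO code): before folding, check whether any constant argument is quantized, and if so only allow targets on a shape-op allow-list.

```python
import torch

# Hypothetical allow-list: shape-only ops are safe to fold on quantized
# tensors because they never dispatch arithmetic kernels.
ALLOWED_SHAPE_OPS = {
    torch.ops.aten.reshape.default,
    torch.ops.aten.view.default,
    torch.ops.aten.permute.default,
    torch.ops.aten.transpose.int,
}

def can_fold(target, args_data):
    """Return True if constant-folding `target` with these args is safe."""
    has_quantized = any(
        isinstance(a, torch.Tensor) and a.is_quantized for a in args_data
    )
    # Non-quantized constants fold as before; quantized ones only
    # through the shape-op allow-list.
    return (not has_quantized) or target in ALLOWED_SHAPE_OPS

q = torch.quantize_per_tensor(torch.tensor([1.0, 2.0]), 0.1, 10, torch.qint8)
print(can_fold(torch.ops.aten.add.Tensor, [q, q]))         # quantized add: skip
print(can_fold(torch.ops.aten.reshape.default, [q, [2]]))  # shape op: fold
```

The existing fold path for non-quantized constants stays untouched; the guard only narrows behavior for quantized inputs.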
Originally posted by @dayo09 in #491 (comment)
What?
Let's fold shape operators (reshape, view, permute, transpose, ...) with full consideration of channel-wise quantization!
Why is it needed?
When a "Quantization Boundary Break" occurs because shape-related operators are generated after a composite operation's decomposition, Circle is compiled with a weird pattern: (int, int) => Conv2d => (float). (Why has it not been validated in TICO, though? 🤔) Link: #491 (comment)
Since these are "shape-only" operations, they can be pre-calculated at compile-time to simplify the graph and improve inference performance.
How?
- Precondition: Circle IR doesn't restrict the channel-wise quantization axis to a fixed dimension; it exposes it as a field.
- Scale/zero-point arrays are aligned with that axis. We must track whether the axis is preserved through each shape op - neither merged with nor split across other dimensions.
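For reference, PyTorch models per-channel quantization with exactly this metadata (one scale/zero-point per slice along `axis`), which is what must stay consistent when a shape op is folded:

```python
import torch

# Per-channel quantized weight: one (scale, zero_point) pair per output channel.
w = torch.randn(4, 3)  # [out_channels, in_features]
scales = torch.tensor([0.1, 0.2, 0.1, 0.05])
zero_points = torch.zeros(4, dtype=torch.int64)

qw = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

print(qw.q_per_channel_axis())    # channel dimension index: 0
print(qw.q_per_channel_scales())  # 4 scales, one per channel
```

If a folded shape op moves or destroys that channel dimension without updating `axis`, every scale/zero-point pairing downstream becomes wrong.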
We need to extend ConstPropPass to handle aten.reshape, aten.transpose, and aten.view for quantized tensors by synchronizing the quantization metadata.
- Transpose / Permute Logic
  - Metadata Update: update the quantized_dimension (axis) to reflect the new position of the channel dimension.
- Reshape / View / StridedSlice / Concat Logic
  - Validation: if the channel dimension remains "independent" (not merged with other dimensions), simply update the axis index.
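The two rules above can be sketched as pure functions over shape metadata (the `sync_axis_*` names are hypothetical; this is a sketch, not existing pass code). For reshape, the leading dimensions must regroup exactly and the channel dimension must keep its size, otherwise folding is unsafe and returns `None`:

```python
import math

def sync_axis_permute(axis, perm):
    """After permute(perm), the old dimension `axis` lands at perm.index(axis)."""
    return perm.index(axis)

def sync_axis_reshape(axis, old_shape, new_shape):
    """Return the new channel axis if the channel dim stays independent
    (neither merged with nor split across other dims); else None (unsafe)."""
    before = math.prod(old_shape[:axis])  # elements preceding one channel slice
    size = old_shape[axis]
    acc, new_axis = 1, 0
    # The leading dims of new_shape must multiply out to exactly `before`.
    for d in new_shape:
        if acc == before:
            break
        acc *= d
        new_axis += 1
    if acc != before:
        return None  # channel dim was merged/split with leading dims
    # Skip inserted size-1 dims (e.g. an unsqueeze before the channel dim).
    while new_axis < len(new_shape) and new_shape[new_axis] == 1 and size != 1:
        new_axis += 1
    if new_axis >= len(new_shape) or new_shape[new_axis] != size:
        return None  # channel dim itself changed size
    return new_axis

print(sync_axis_permute(0, [1, 0]))              # channel moves 0 -> 1
print(sync_axis_reshape(1, [2, 4, 3], [8, 3]))   # merged: unsafe -> None
print(sync_axis_reshape(0, [4, 3], [1, 4, 3]))   # unsqueeze: axis 0 -> 1
```

A `None` result would mean the node must be left unfolded (or fall back to per-tensor handling); transpose can reuse `sync_axis_permute` since `transpose(d0, d1)` is a permute that swaps two entries.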