Add complete ONNX export support for SleepConViT#4
runwangdl merged 5 commits into pulp-platform:devel from
Conversation
This commit adds support for the ONNX Sub (subtraction) operator:

- Created `SubOperatorTest` class in `onnx4deeploy/operators/sub.py`
  - Implements element-wise binary subtraction (A - B)
  - Supports NumPy-style broadcasting
  - Based on the ONNX Opset 14 specification
  - Uses the `SimpleElementwiseOperator` base class
- Registered the Sub operator in `onnx4deeploy/operators/__init__.py`
- Added a default configuration in `onnx4deeploy/operators/config.yaml`
  - Default input shape: [1, 64, 32, 32] (NCHW format)
  - Opset version: 14
- Added comprehensive pytest tests in `tests/operators/test_operators.py`
  - `test_sub_basic`: verifies basic functionality
  - `test_sub_different_shapes`: tests multiple input shapes
  - `test_sub_result_range`: validates output range correctness
  - All tests passing ✓

Generated test files validated:
- `network.onnx`: Sub operator graph
- `inputs.npz`: two input tensors (`input_a`, `input_b`)
- `outputs.npz`: expected output (A - B)
- Numerical verification: max difference = 0.000000e+00

Reference: https://onnx.ai/onnx/operators/onnx__Sub.html

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
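The Sub semantics described in the commit message are plain NumPy subtraction; a minimal sketch of the reference computation, assuming the commit's default [1, 64, 32, 32] NCHW shape (the actual generation script is not reproduced here):

```python
import numpy as np

# Reference semantics for the ONNX Sub operator (Opset 14): element-wise
# A - B with NumPy-style broadcasting.
rng = np.random.default_rng(0)
input_a = rng.standard_normal((1, 64, 32, 32)).astype(np.float32)
input_b = rng.standard_normal((1, 64, 32, 32)).astype(np.float32)

expected = input_a - input_b  # same-shape subtraction

# Broadcasting also works, e.g. subtracting a per-channel value:
per_channel = rng.standard_normal((1, 64, 1, 1)).astype(np.float32)
broadcast_out = input_a - per_channel  # broadcast over H and W

assert expected.shape == (1, 64, 32, 32)
assert broadcast_out.shape == (1, 64, 32, 32)
```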
e5b1168 to cd5d182
Pull request overview
This pull request adds ONNX export support for the SleepConViT model, a vision transformer for sleep stage classification. It introduces new operator test implementations (Sub, Concat), enhances Conv2D with asymmetric padding support, and implements a sophisticated weight loading mechanism for ONNX-to-PyTorch conversion that handles MatMul weight transposition and weight-sharing optimizations.
Changes:
- Added Sub operator implementation with comprehensive tests
- Added Concat operator implementation supporting multiple inputs with cascaded concatenation
- Enhanced Conv2D operator to support asymmetric padding (e.g., [1, 2, 1, 2]) for flexible convolution configurations
- Implemented CascadedConcat custom operation in SleepConViT to prevent ONNX optimizer from fusing 3-input concatenations
- Added an ONNX weight loading mechanism in `sleep_convit_exporter.py` that handles MatMul transposition and weight-sharing optimizations
- Updated the model to remove `.expand()` for the batch_size=1 optimization
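The `.expand()` removal relies on the fact that for batch_size=1 the class token already has the right batch dimension; a NumPy sketch of the idea (tensor names and sizes are illustrative, not the model's actual attributes):

```python
import numpy as np

# With batch_size=1, a (1, 1, D) cls_token concatenates with the (1, N, D)
# patch sequence directly; no expand() to the batch dimension is needed,
# which keeps the exported ONNX graph simpler.
D, N = 8, 4  # illustrative embedding dim / sequence length
cls_token = np.zeros((1, 1, D), dtype=np.float32)
patches = np.ones((1, N, D), dtype=np.float32)

tokens = np.concatenate([cls_token, patches], axis=1)  # (1, N + 1, D)

# The general batch_size > 1 path would need the cls_token broadcast first:
B = 3
expanded = np.broadcast_to(cls_token, (B, 1, D))
tokens_b = np.concatenate([expanded, np.ones((B, N, D), np.float32)], axis=1)

assert tokens.shape == (1, N + 1, D)
assert tokens_b.shape == (B, N + 1, D)
```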
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| `onnx4deeploy/operators/__init__.py` | Added imports and exports for Sub, Concat, and Selu operators |
| `onnx4deeploy/operators/sub.py` | New Sub operator implementation following the SimpleElementwiseOperator pattern |
| `onnx4deeploy/operators/concat.py` | New Concat operator implementation with shape validation and multiple-input support |
| `onnx4deeploy/operators/config.yaml` | Default configurations for new operators (Sub, Selu, Conv2D with asymmetric padding, Concat) |
| `onnx4deeploy/operators/conv2d.py` | Enhanced to support asymmetric padding; changed default use_bias to False |
| `tests/operators/test_operators.py` | Added comprehensive test suite for the Sub operator (basic, different shapes, range validation) |
| `onnx4deeploy/models/sleep_convit_exporter.py` | Implemented ONNX weight loading with MatMul transposition and weight-sharing logic |
| `onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py` | Added CascadedConcat custom operation and optimized cls_token concatenation for batch_size=1 |
| `Onnx4Deeploy.py` | Updated help text to document Concat and Conv2D capabilities |
Comments suppressed due to low confidence (1)
onnx4deeploy/operators/__init__.py:35
- The import `from .selu import SeluOperatorTest` will fail because the file `onnx4deeploy/operators/selu.py` does not exist in this pull request. It references a non-existent module and will raise an ImportError when the package is imported. Either add the missing `selu.py` file or remove this import and its references from `__all__`.
from .sgd import SGDOperatorTest
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Run Wang <52746141+runwangdl@users.noreply.github.com>
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.
Comments suppressed due to low confidence (1)
Onnx4Deeploy.py:179
- The Sub operator is implemented and imported in the operators module, but it is not listed in the list_available_operators() function, so users cannot discover or use it through the command-line interface. Add an entry `"Sub": "Subtraction operator"` to the operators dictionary.
def list_available_operators():
"""List available operators"""
operators = {
# Basic operators
"Add": "Addition operator",
"Relu": "ReLU activation function",
"Transpose": "Tensor transpose",
"Concat": "Tensor concatenation (supports 3 inputs)",
"Split": "Tensor split",
# Matrix operations
"Gemm": "General matrix multiplication",
"MatMul": "Matrix multiplication",
# Pooling
"MaxPool": "Max pooling",
"AveragePool": "Average pooling",
"AveragePoolGrad": "Average pooling gradient",
# Normalization
"LayerNorm": "Layer normalization",
"LayerNormGrad": "Layer normalization gradient",
"GroupNorm": "Group normalization",
"GroupNormGradX": "Group normalization input gradient",
"GroupNormGradW": "Group normalization weight gradient",
# Convolution
"Conv2D": "2D convolution (supports asymmetric padding)",
"ConvGradX": "Convolution input gradient",
"ConvGradW": "Convolution weight gradient",
"ConvGradB": "Convolution bias gradient",
# Others
"ReduceSum": "Sum reduction",
"SoftmaxCrossEntropy": "Softmax cross entropy",
"ReluGrad": "ReLU gradient",
}
return operators
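Applying the suggestion is a one-line change; a sketch of the amended "Basic operators" section (only that section is shown):

```python
# Sketch of the suggested fix: register Sub in the discovery dictionary so
# it shows up in the command-line interface.
operators = {
    "Add": "Addition operator",
    "Sub": "Subtraction operator",  # suggested new entry
    "Relu": "ReLU activation function",
    "Transpose": "Tensor transpose",
    "Concat": "Tensor concatenation (supports 3 inputs)",
    "Split": "Tensor split",
}
assert "Sub" in operators
```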
# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    pad_top, pad_left, pad_bottom, pad_right = self.padding
else:
    pad_top = pad_left = pad_bottom = pad_right = self.padding
If self.padding is a list with length other than 4 (e.g., [1, 2] for symmetric padding in h and w), the unpacking on line 78 will raise a ValueError. Consider adding validation to check the list length, or handle lists of length 2 (for symmetric h/w padding) as well. The code should either validate the list length in load_config() or handle different list sizes gracefully.
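One way to address this is a normalization helper that accepts all three forms; a sketch (`normalize_padding` is a hypothetical name, not part of the PR):

```python
def normalize_padding(padding):
    """Normalize padding to [top, left, bottom, right].

    Accepts an int (same pad on all sides), a 2-element [h, w] pair
    (symmetric per-axis padding), or an explicit 4-element list.
    """
    if isinstance(padding, (list, tuple)):
        if len(padding) == 2:  # symmetric [h, w]
            h, w = padding
            return [h, w, h, w]
        if len(padding) == 4:  # explicit [top, left, bottom, right]
            return list(padding)
        raise ValueError(
            f"padding must be an int or a list of length 2 or 4, got {padding!r}"
        )
    return [padding] * 4

assert normalize_padding(1) == [1, 1, 1, 1]
assert normalize_padding([1, 2]) == [1, 2, 1, 2]
assert normalize_padding([1, 2, 3, 4]) == [1, 2, 3, 4]
```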
# Conv node - padding format is [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    pads_list = list(self.padding)
else:
    pads_list = [self.padding, self.padding, self.padding, self.padding]
The ONNX Conv operator expects pads in the format [x1_begin, x2_begin, ..., x1_end, x2_end, ...], which for a 2D convolution works out to [top, left, bottom, right]. A 4-element padding list in that order can therefore be used directly, and the current implementation is correct. Consider adding a comment clarifying that this ordering matches ONNX's expected format.
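The correspondence is easy to check numerically; a sketch using `np.pad` on an NCHW tensor, with the PR's asymmetric [1, 2, 1, 2] example:

```python
import numpy as np

# ONNX Conv pads are [x1_begin, x2_begin, ..., x1_end, x2_end, ...]; for a
# 2D conv on NCHW input that is [top, left, bottom, right]. The same
# padding expressed with np.pad uses per-axis (begin, end) pairs.
top, left, bottom, right = 1, 2, 1, 2
x = np.zeros((1, 3, 4, 4), dtype=np.float32)

padded = np.pad(x, ((0, 0), (0, 0), (top, bottom), (left, right)))

# H grows by top+bottom, W grows by left+right.
assert padded.shape == (1, 3, 4 + top + bottom, 4 + left + right)
```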
if hasattr(model.encoder, "ln_2"):
    model.encoder.ln_2.weight.data.copy_(ln_weight)
    model.encoder.ln_2.bias.data.copy_(ln_bias)

if hasattr(model, "norm"):
    model.norm.weight.data.copy_(ln_weight)
    model.norm.bias.data.copy_(ln_bias)

# All Linear biases use encoder.ln_1.bias in ONNX
if (
    hasattr(model.encoder.mha.out_proj, "bias")
    and model.encoder.mha.out_proj.bias is not None
):
    model.encoder.mha.out_proj.bias.data.copy_(ln_bias)

if hasattr(model.encoder.ff.ff1, "bias") and model.encoder.ff.ff1.bias is not None:
    model.encoder.ff.ff1.bias.data.copy_(ln_bias)

if hasattr(model.encoder.ff.ff2, "bias") and model.encoder.ff.ff2.bias is not None:
    model.encoder.ff.ff2.bias.data.copy_(ln_bias)
The weight-sharing optimization code uses hasattr checks but doesn't verify that the attributes are actually LayerNorm or Linear layers with the expected structure. If model architecture changes and these attributes become different types (e.g., a Module instead of LayerNorm), the .weight.data.copy_() calls could fail at runtime. Consider adding type checks or try-except blocks to make this more robust.
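A defensive variant along the lines the comment suggests, as a sketch with hypothetical names (a NumPy `Dummy` class stands in for the actual torch modules; real code would also check `isinstance(module, nn.LayerNorm)` / `nn.Linear` and use `tensor.data.copy_()`):

```python
import numpy as np

def safe_copy(module, attr, value):
    """Copy `value` into module.<attr> only when the target exists and the
    shapes match; return True on success, False if the copy was skipped."""
    target = getattr(module, attr, None)
    if target is None or getattr(target, "shape", None) != value.shape:
        return False
    target[:] = value  # stand-in for tensor.data.copy_(value)
    return True

class Dummy:
    """Stand-in for a LayerNorm-like module with weight and bias arrays."""
    def __init__(self, n):
        self.weight = np.zeros(n)
        self.bias = np.zeros(n)

m = Dummy(4)
ok = safe_copy(m, "weight", np.ones(4))      # shapes match: copied
bad = safe_copy(m, "missing", np.ones(4))    # attribute absent: skipped
mismatch = safe_copy(m, "bias", np.ones(5))  # shape mismatch: skipped

assert ok and not bad and not mismatch
assert m.weight.tolist() == [1.0, 1.0, 1.0, 1.0]
```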
if isinstance(self.kernel_size, (list, tuple)):
    kernel_h, kernel_w = self.kernel_size
else:
    kernel_h = kernel_w = self.kernel_size

# Handle stride: can be int or list [h, w]
if isinstance(self.stride, (list, tuple)):
    stride_h, stride_w = self.stride
else:
    stride_h = stride_w = self.stride

# Handle dilation: can be int or list [h, w]
if isinstance(self.dilation, (list, tuple)):
    dilation_h, dilation_w = self.dilation
else:
    dilation_h = dilation_w = self.dilation

# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
If self.kernel_size, self.stride, or self.dilation are provided as lists with length other than 2, the unpacking operations on lines 60, 66, and 72 will raise ValueError. Consider adding validation in load_config() to ensure these lists have exactly 2 elements, or add error handling with a clear error message.
Suggested change:

if isinstance(self.kernel_size, (list, tuple)):
    if len(self.kernel_size) != 2:
        raise ValueError(
            f"kernel_size must be an int or a sequence of length 2 [h, w], "
            f"but got {self.kernel_size!r} (length {len(self.kernel_size)})"
        )
    kernel_h, kernel_w = self.kernel_size
else:
    kernel_h = kernel_w = self.kernel_size
# Handle stride: can be int or list [h, w]
if isinstance(self.stride, (list, tuple)):
    if len(self.stride) != 2:
        raise ValueError(
            f"stride must be an int or a sequence of length 2 [h, w], "
            f"but got {self.stride!r} (length {len(self.stride)})"
        )
    stride_h, stride_w = self.stride
else:
    stride_h = stride_w = self.stride
# Handle dilation: can be int or list [h, w]
if isinstance(self.dilation, (list, tuple)):
    if len(self.dilation) != 2:
        raise ValueError(
            f"dilation must be an int or a sequence of length 2 [h, w], "
            f"but got {self.dilation!r} (length {len(self.dilation)})"
        )
    dilation_h, dilation_w = self.dilation
else:
    dilation_h = dilation_w = self.dilation
# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    if len(self.padding) != 4:
        raise ValueError(
            f"padding must be an int or a sequence of length 4 "
            f"[top, left, bottom, right], but got {self.padding!r} "
            f"(length {len(self.padding)})"
        )
    ReluOperatorTest,
    SGDOperatorTest,
    SplitOperatorTest,
    SubOperatorTest,
The ConcatOperatorTest is imported in the operators module but there are no corresponding test cases in this test file. All other operators (Add, Sub, Split, etc.) have test classes, but TestConcatOperator is missing. This creates an incomplete test suite for the new Concat operator functionality.
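A minimal test along these lines could start from the NumPy reference semantics (a sketch only; the actual `ConcatOperatorTest` API is not reproduced here). It also checks the cascaded-concat property the PR relies on:

```python
import numpy as np

# Reference check: concatenating three inputs pairwise (as CascadedConcat
# does to avoid ONNX optimizer fusion) must equal a single 3-input
# concatenation along the same axis.
rng = np.random.default_rng(0)
a = rng.standard_normal((1, 2, 4)).astype(np.float32)
b = rng.standard_normal((1, 3, 4)).astype(np.float32)
c = rng.standard_normal((1, 5, 4)).astype(np.float32)

single = np.concatenate([a, b, c], axis=1)  # one 3-input Concat
cascaded = np.concatenate([np.concatenate([a, b], axis=1), c], axis=1)

assert single.shape == (1, 10, 4)
assert np.array_equal(single, cascaded)
```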
# Selu operator configuration
selu:
  input_shape: [1, 64, 32, 32]  # NCHW format
  opset_version: 13
  # Note: alpha and gamma use ONNX default values
The configuration file includes a "selu" operator configuration (lines 12-16), but there is no corresponding SeluOperatorTest implementation in the operators module. This orphaned configuration could be confusing and suggests incomplete implementation or leftover code from a removed feature.
Suggested change: delete the orphaned selu configuration block.
def _validate_shapes(self) -> bool:
    """Validate that input shapes can be concatenated along the specified axis."""
    if len(self.input_shapes) < 2:
        return False

    reference_shape = list(self.input_shapes[0])
    for shape in self.input_shapes[1:]:
        if len(shape) != len(reference_shape):
            return False
        for i, (dim1, dim2) in enumerate(zip(reference_shape, shape)):
            if i != self.axis and dim1 != dim2:
                return False
    return True
The _validate_shapes method doesn't validate that self.axis is within the valid range for the input shapes. If axis is negative or greater than or equal to the number of dimensions, the validation will pass but may cause runtime errors later. Consider adding validation like: if self.axis < 0 or self.axis >= len(reference_shape): return False
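A sketch of the suggested check, extended with ONNX-style negative-axis normalization (`validate_axis` is a hypothetical helper name):

```python
def validate_axis(axis, rank):
    """Return the normalized axis, or None if out of range.

    ONNX allows negative axes in [-rank, rank - 1]; they count from the
    last dimension, so axis=-1 on a rank-3 shape means axis=2.
    """
    if axis < -rank or axis >= rank:
        return None
    return axis % rank

assert validate_axis(1, 3) == 1
assert validate_axis(-1, 3) == 2
assert validate_axis(3, 3) is None
assert validate_axis(-4, 3) is None
```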
Key Features

- New Model: SleepConViT for sleep stage classification (5 classes: Wake, N1, N2, N3, REM)
- ONNX Exporter (`sleep_convit_exporter.py`)
- PyTorch Model (`pytorch_models/sleep_convit/`)

Technical Highlights

Correct outputs.npz Generation

The exporter's `save_test_data()` method handles ONNX-specific optimizations:

- MatMul Weight Transpose: ONNX MatMul weights (`Y = X @ W`) are transposed for PyTorch Linear layers (`Y = X @ W.T + b`)
- Weight-Sharing Optimization: the ONNX optimizer reuses `encoder.ln_1` weights for:
  - `encoder.ln_2.weight`/`bias`
  - `norm.weight`/`bias`
  - Linear biases (`mha.out_proj`, `ff.ff1`, `ff.ff2`)

This ensures `outputs.npz` matches the exported ONNX model exactly.

Validation
✅ Deeploy Integration Test: 0/5 errors (PASSED)
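The MatMul weight transpose described under Technical Highlights can be verified numerically; a NumPy sketch (shapes are illustrative):

```python
import numpy as np

# ONNX MatMul computes Y = X @ W with W of shape (in, out). PyTorch's
# nn.Linear stores its weight as (out, in) and computes Y = X @ W_pt.T + b,
# so loading an ONNX MatMul weight into a Linear layer requires a transpose.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8)).astype(np.float32)
w_onnx = rng.standard_normal((8, 16)).astype(np.float32)  # (in, out)

y_onnx = x @ w_onnx

w_pt = w_onnx.T   # (out, in), the layout nn.Linear stores
y_pt = x @ w_pt.T  # Linear forward pass (bias omitted)

assert np.allclose(y_onnx, y_pt)
```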