
Add complete ONNX export support for SleepConViT #4

Merged
runwangdl merged 5 commits into pulp-platform:devel from runwangdl:sleepvit_run
Feb 14, 2026

Conversation

@runwangdl
Collaborator

Key Features

  • New Model: SleepConViT for sleep stage classification (5 classes: Wake, N1, N2, N3, REM)

    • Input: 1D time-series (3000 samples)
    • Architecture: ConvStem (3 parallel branches) + Transformer Encoder + MLP classifier
    • Model dim: 48, Heads: 6, Patches: 94
  • ONNX Exporter (sleep_convit_exporter.py)

    • Transformer optimization pipeline (LayerNorm fusion, GELU optimization, etc.)
    • Smart test data generation with ONNX weight loading
  • PyTorch Model (pytorch_models/sleep_convit/)

    • Modular architecture: ConvStem, TransformerEncoder, MLPHead
    • CascadedConcat custom op for branch concatenation
    • Batch size 1 optimization (direct concat without expand)

Technical Highlights

Correct outputs.npz Generation

The exporter's save_test_data() method handles ONNX-specific optimizations:

  1. MatMul Weight Transpose: ONNX MatMul weights (Y = X @ W) are transposed for PyTorch Linear layers (Y = X @ W.T + b)

  2. Weight-Sharing Optimization: ONNX optimizer reuses encoder.ln_1 weights for:

    • encoder.ln_2.weight/bias
    • norm.weight/bias
    • All Linear layer biases (mha.out_proj, ff.ff1, ff.ff2)

This ensures outputs.npz matches the exported ONNX model exactly.
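The MatMul transpose step can be illustrated with a small self-contained sketch; the shapes and the bare `nn.Linear` layer here are illustrative, not the exporter's actual code:

```python
import numpy as np
import torch

# ONNX MatMul computes Y = X @ W with W shaped (in_features, out_features);
# torch.nn.Linear stores weight as (out_features, in_features) and computes
# Y = X @ W.T + b, so the ONNX weight must be transposed when loaded.
onnx_w = np.random.randn(48, 192).astype(np.float32)  # (in, out), ONNX layout

linear = torch.nn.Linear(48, 192, bias=False)
linear.weight.data.copy_(torch.from_numpy(onnx_w).T)  # (out, in), PyTorch layout

x = torch.randn(1, 48)
expected = x.numpy() @ onnx_w        # ONNX MatMul semantics
actual = linear(x).detach().numpy()  # PyTorch Linear semantics
assert np.allclose(expected, actual, atol=1e-4)
```

Loading weights without this transpose silently produces wrong activations whenever in_features != out_features would not even raise a shape error, which is why the exporter handles it explicitly.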

Validation

Deeploy Integration Test: 0/5 errors (PASSED)

  • Runtime: 22909628 cycles on SIRACUSA target
  • All intermediate layer outputs match expected values

This commit adds support for the ONNX Sub (subtraction) operator:

- Created SubOperatorTest class in onnx4deeploy/operators/sub.py
  * Implements element-wise binary subtraction (A - B)
  * Supports Numpy-style broadcasting
  * Based on ONNX Opset 14 specification
  * Uses SimpleElementwiseOperator base class

- Registered Sub operator in onnx4deeploy/operators/__init__.py

- Added default configuration in onnx4deeploy/operators/config.yaml
  * Default input shape: [1, 64, 32, 32] (NCHW format)
  * Opset version: 14

- Added comprehensive pytest tests in tests/operators/test_operators.py
  * test_sub_basic: Verifies basic functionality
  * test_sub_different_shapes: Tests multiple input shapes
  * test_sub_result_range: Validates output range correctness
  * All tests passing ✓

Generated test files validated:
- network.onnx: Sub operator graph
- inputs.npz: Two input tensors (input_a, input_b)
- outputs.npz: Expected output (A - B)
- Numerical verification: max difference = 0.000000e+00

Reference: https://onnx.ai/onnx/operators/onnx__Sub.html
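The Sub semantics the tests verify can be sketched in plain NumPy; the shape follows the config.yaml default and the array names mirror the npz files, but this is not the repository's test code:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 64, 32, 32)).astype(np.float32)  # input_a
b = rng.standard_normal((1, 64, 1, 1)).astype(np.float32)    # input_b
# ONNX Sub (opset 14): element-wise A - B with NumPy-style
# (multidirectional) broadcasting, so b broadcasts over H and W.
expected = a - b

assert expected.shape == (1, 64, 32, 32)
# "max difference" style check, as used when validating outputs.npz
assert np.max(np.abs(expected - (a - b))) == 0.0
```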

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Contributor

Copilot AI left a comment


Pull request overview

This pull request adds ONNX export support for the SleepConViT model, a vision transformer for sleep stage classification. It introduces new operator test implementations (Sub, Concat), enhances Conv2D with asymmetric padding support, and implements a sophisticated weight loading mechanism for ONNX-to-PyTorch conversion that handles MatMul weight transposition and weight-sharing optimizations.

Changes:

  • Added Sub operator implementation with comprehensive tests
  • Added Concat operator implementation supporting multiple inputs with cascaded concatenation
  • Enhanced Conv2D operator to support asymmetric padding (e.g., [1, 2, 1, 2]) for flexible convolution configurations
  • Implemented CascadedConcat custom operation in SleepConViT to prevent ONNX optimizer from fusing 3-input concatenations
  • Added ONNX weight loading mechanism in sleep_convit_exporter.py that handles MatMul transposition and weight-sharing optimizations
  • Updated model to remove .expand() for batch_size=1 optimization
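The cascaded-concatenation idea can be mimicked in plain PyTorch: a 3-input concatenation is expressed as two pairwise `torch.cat` calls, so no single 3-input Concat node exists for the optimizer to fuse. The tensor shapes below echo the dim-48 / 94-patch configuration but are illustrative, and `cascaded_concat` is a sketch, not the PR's `CascadedConcat` op:

```python
import torch

def cascaded_concat(a, b, c, dim=1):
    # Two pairwise concats instead of one 3-input concat.
    return torch.cat([torch.cat([a, b], dim=dim), c], dim=dim)

# Three branch outputs of 16 channels each -> 48 channels total.
a, b, c = (torch.randn(1, 16, 94) for _ in range(3))
out = cascaded_concat(a, b, c)
assert out.shape == (1, 48, 94)
# Numerically identical to the direct 3-input concatenation.
assert torch.equal(out, torch.cat([a, b, c], dim=1))
```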

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 13 comments.

Show a summary per file
File Description
onnx4deeploy/operators/__init__.py Added imports and exports for Sub, Concat, and Selu operators
onnx4deeploy/operators/sub.py New Sub operator implementation following SimpleElementwiseOperator pattern
onnx4deeploy/operators/concat.py New Concat operator implementation with shape validation and multiple input support
onnx4deeploy/operators/config.yaml Default configurations for new operators (Sub, Selu, Conv2D with asymmetric padding, Concat)
onnx4deeploy/operators/conv2d.py Enhanced to support asymmetric padding and changed default use_bias to False
tests/operators/test_operators.py Added comprehensive test suite for Sub operator (basic, different shapes, range validation)
onnx4deeploy/models/sleep_convit_exporter.py Implemented ONNX weight loading with MatMul transposition and weight-sharing logic
onnx4deeploy/models/pytorch_models/sleep_convit/sleep_convit.py Added CascadedConcat custom operation and optimized cls_token concatenation for batch_size=1
Onnx4Deeploy.py Updated help text to document Concat and Conv2D capabilities
Comments suppressed due to low confidence (1)

onnx4deeploy/operators/__init__.py:35

  • The import from .selu import SeluOperatorTest will fail because the file onnx4deeploy/operators/selu.py does not exist in this pull request. This import statement references a non-existent module and will cause an ImportError when the module is loaded. Either add the missing selu.py file or remove this import and its references from __all__.
from .sgd import SGDOperatorTest


runwangdl and others added 3 commits February 14, 2026 22:55
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Run Wang <52746141+runwangdl@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Run Wang <52746141+runwangdl@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Run Wang <52746141+runwangdl@users.noreply.github.com>
@runwangdl runwangdl requested a review from Copilot February 14, 2026 22:04
@runwangdl runwangdl merged commit 7bd999f into pulp-platform:devel Feb 14, 2026
16 checks passed
@runwangdl runwangdl deleted the sleepvit_run branch February 14, 2026 22:11
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (1)

Onnx4Deeploy.py:179

  • The Sub operator is implemented and imported in the operators module, but it's not listed in the list_available_operators() function. This means users cannot discover or use it through the command-line interface. Add an entry for "Sub": "Subtraction operator" to the operators dictionary.
def list_available_operators():
    """List available operators"""
    operators = {
        # Basic operators
        "Add": "Addition operator",
        "Relu": "ReLU activation function",
        "Transpose": "Tensor transpose",
        "Concat": "Tensor concatenation (supports 3 inputs)",
        "Split": "Tensor split",
        # Matrix operations
        "Gemm": "General matrix multiplication",
        "MatMul": "Matrix multiplication",
        # Pooling
        "MaxPool": "Max pooling",
        "AveragePool": "Average pooling",
        "AveragePoolGrad": "Average pooling gradient",
        # Normalization
        "LayerNorm": "Layer normalization",
        "LayerNormGrad": "Layer normalization gradient",
        "GroupNorm": "Group normalization",
        "GroupNormGradX": "Group normalization input gradient",
        "GroupNormGradW": "Group normalization weight gradient",
        # Convolution
        "Conv2D": "2D convolution (supports asymmetric padding)",
        "ConvGradX": "Convolution input gradient",
        "ConvGradW": "Convolution weight gradient",
        "ConvGradB": "Convolution bias gradient",
        # Others
        "ReduceSum": "Sum reduction",
        "SoftmaxCrossEntropy": "Softmax cross entropy",
        "ReluGrad": "ReLU gradient",
    }
    return operators


Comment on lines +76 to +80
# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    pad_top, pad_left, pad_bottom, pad_right = self.padding
else:
    pad_top = pad_left = pad_bottom = pad_right = self.padding

Copilot AI Feb 14, 2026


If self.padding is a list with a length other than 4 (e.g., [1, 2] for symmetric padding in h and w), the unpacking on line 78 will raise a ValueError. Either validate the list length in load_config() with a clear error message, or additionally handle length-2 lists as symmetric h/w padding.

Comment on lines +120 to +124
# Conv node - padding format is [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    pads_list = list(self.padding)
else:
    pads_list = [self.padding, self.padding, self.padding, self.padding]

Copilot AI Feb 14, 2026


The ONNX Conv pads attribute uses the format [x1_begin, x2_begin, ..., x1_end, x2_end], which for 2D convolution is [begin_h, begin_w, end_h, end_w], i.e. exactly [top, left, bottom, right]. The comment and implementation therefore already match ONNX's expected format, but the equivalence is easy to miss; consider adding a comment stating explicitly that this ordering is the ONNX pads layout for 2D inputs.
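The correspondence between [top, left, bottom, right] padding and a PyTorch convolution can be checked with a short sketch; the shapes are arbitrary and `F.pad` is used here only to emulate asymmetric padding explicitly:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)   # NCHW input
w = torch.randn(4, 3, 3, 3)   # 3x3 kernel, 4 output channels

# ONNX 2D Conv pads = [begin_h, begin_w, end_h, end_w] = [top, left, bottom, right]
pad_top, pad_left, pad_bottom, pad_right = 1, 2, 1, 2

# F.pad takes (left, right, top, bottom) for the last two dimensions.
x_padded = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom))
y = F.conv2d(x_padded, w)

# Output spatial dims: H + top + bottom - kH + 1, W + left + right - kW + 1
assert y.shape == (1, 4, 8 + 1 + 1 - 3 + 1, 8 + 2 + 2 - 3 + 1)  # (1, 4, 8, 10)
```

Note that `F.pad` and ONNX order the four padding values differently, which is a common source of off-by-one-side bugs.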

Comment on lines +276 to +295
if hasattr(model.encoder, "ln_2"):
    model.encoder.ln_2.weight.data.copy_(ln_weight)
    model.encoder.ln_2.bias.data.copy_(ln_bias)

if hasattr(model, "norm"):
    model.norm.weight.data.copy_(ln_weight)
    model.norm.bias.data.copy_(ln_bias)

# All Linear biases use encoder.ln_1.bias in ONNX
if (
    hasattr(model.encoder.mha.out_proj, "bias")
    and model.encoder.mha.out_proj.bias is not None
):
    model.encoder.mha.out_proj.bias.data.copy_(ln_bias)

if hasattr(model.encoder.ff.ff1, "bias") and model.encoder.ff.ff1.bias is not None:
    model.encoder.ff.ff1.bias.data.copy_(ln_bias)

if hasattr(model.encoder.ff.ff2, "bias") and model.encoder.ff.ff2.bias is not None:
    model.encoder.ff.ff2.bias.data.copy_(ln_bias)

Copilot AI Feb 14, 2026


The weight-sharing optimization code uses hasattr checks but doesn't verify that the attributes are actually LayerNorm or Linear layers with the expected structure. If model architecture changes and these attributes become different types (e.g., a Module instead of LayerNorm), the .weight.data.copy_() calls could fail at runtime. Consider adding type checks or try-except blocks to make this more robust.
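A type-guarded variant along the lines the review suggests could look like the sketch below. The module paths (encoder.ln_2, norm, encoder.mha.out_proj, encoder.ff.ff1/ff2) follow the quoted snippet, the dummy model exists only to make the sketch runnable, and none of this is the exporter's actual code:

```python
import types

import torch
import torch.nn as nn

def copy_shared_ln(model, ln_weight, ln_bias):
    """Copy shared LayerNorm params, verifying module types first."""
    ln_2 = getattr(model.encoder, "ln_2", None)
    if isinstance(ln_2, nn.LayerNorm):  # fail closed on unexpected types
        ln_2.weight.data.copy_(ln_weight)
        ln_2.bias.data.copy_(ln_bias)

    norm = getattr(model, "norm", None)
    if isinstance(norm, nn.LayerNorm):
        norm.weight.data.copy_(ln_weight)
        norm.bias.data.copy_(ln_bias)

    for layer in (model.encoder.mha.out_proj, model.encoder.ff.ff1, model.encoder.ff.ff2):
        if isinstance(layer, nn.Linear) and layer.bias is not None:
            layer.bias.data.copy_(ln_bias)

# Minimal dummy model (dims chosen so every bias matches ln_bias's size).
enc = types.SimpleNamespace(
    ln_2=nn.LayerNorm(4),
    mha=types.SimpleNamespace(out_proj=nn.Linear(4, 4)),
    ff=types.SimpleNamespace(ff1=nn.Linear(4, 4), ff2=nn.Linear(4, 4)),
)
model = types.SimpleNamespace(encoder=enc, norm=nn.LayerNorm(4))
copy_shared_ln(model, torch.ones(4), torch.full((4,), 2.0))
assert torch.equal(model.norm.weight, torch.ones(4))
assert torch.equal(model.encoder.ff.ff1.bias, torch.full((4,), 2.0))
```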

Comment on lines +59 to +77
if isinstance(self.kernel_size, (list, tuple)):
    kernel_h, kernel_w = self.kernel_size
else:
    kernel_h = kernel_w = self.kernel_size

# Handle stride: can be int or list [h, w]
if isinstance(self.stride, (list, tuple)):
    stride_h, stride_w = self.stride
else:
    stride_h = stride_w = self.stride

# Handle dilation: can be int or list [h, w]
if isinstance(self.dilation, (list, tuple)):
    dilation_h, dilation_w = self.dilation
else:
    dilation_h = dilation_w = self.dilation

# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):

Copilot AI Feb 14, 2026


If self.kernel_size, self.stride, or self.dilation are provided as lists with length other than 2, the unpacking operations on lines 60, 66, and 72 will raise ValueError. Consider adding validation in load_config() to ensure these lists have exactly 2 elements, or add error handling with a clear error message.

Suggested change (the quoted block, with length validation added):

if isinstance(self.kernel_size, (list, tuple)):
    if len(self.kernel_size) != 2:
        raise ValueError(
            f"kernel_size must be an int or a sequence of length 2 [h, w], "
            f"but got {self.kernel_size!r} (length {len(self.kernel_size)})"
        )
    kernel_h, kernel_w = self.kernel_size
else:
    kernel_h = kernel_w = self.kernel_size

# Handle stride: can be int or list [h, w]
if isinstance(self.stride, (list, tuple)):
    if len(self.stride) != 2:
        raise ValueError(
            f"stride must be an int or a sequence of length 2 [h, w], "
            f"but got {self.stride!r} (length {len(self.stride)})"
        )
    stride_h, stride_w = self.stride
else:
    stride_h = stride_w = self.stride

# Handle dilation: can be int or list [h, w]
if isinstance(self.dilation, (list, tuple)):
    if len(self.dilation) != 2:
        raise ValueError(
            f"dilation must be an int or a sequence of length 2 [h, w], "
            f"but got {self.dilation!r} (length {len(self.dilation)})"
        )
    dilation_h, dilation_w = self.dilation
else:
    dilation_h = dilation_w = self.dilation

# Handle padding: can be int or list [top, left, bottom, right]
if isinstance(self.padding, (list, tuple)):
    if len(self.padding) != 4:
        raise ValueError(
            f"padding must be an int or a sequence of length 4 "
            f"[top, left, bottom, right], but got {self.padding!r} "
            f"(length {len(self.padding)})"
        )
    ReluOperatorTest,
    SGDOperatorTest,
    SplitOperatorTest,
    SubOperatorTest,

Copilot AI Feb 14, 2026


The ConcatOperatorTest is imported in the operators module but there are no corresponding test cases in this test file. All other operators (Add, Sub, Split, etc.) have test classes, but TestConcatOperator is missing. This creates an incomplete test suite for the new Concat operator functionality.

Comment on lines +12 to +17
# Selu operator configuration
selu:
  input_shape: [1, 64, 32, 32]  # NCHW format
  opset_version: 13
  # Note: alpha and gamma use ONNX default values


Copilot AI Feb 14, 2026


The configuration file includes a "selu" operator configuration (lines 12-16), but there is no corresponding SeluOperatorTest implementation in the operators module. This orphaned configuration could be confusing and suggests incomplete implementation or leftover code from a removed feature.

Suggested change (delete these lines):

# Selu operator configuration
selu:
  input_shape: [1, 64, 32, 32]  # NCHW format
  opset_version: 13
  # Note: alpha and gamma use ONNX default values
Comment on lines +53 to +65
def _validate_shapes(self) -> bool:
    """Validate that input shapes can be concatenated along the specified axis."""
    if len(self.input_shapes) < 2:
        return False

    reference_shape = list(self.input_shapes[0])
    for shape in self.input_shapes[1:]:
        if len(shape) != len(reference_shape):
            return False
        for i, (dim1, dim2) in enumerate(zip(reference_shape, shape)):
            if i != self.axis and dim1 != dim2:
                return False
    return True

Copilot AI Feb 14, 2026


The _validate_shapes method doesn't validate that self.axis is within the valid range for the input shapes. If axis is negative or greater than or equal to the number of dimensions, the validation will pass but may cause runtime errors later. Consider adding validation like: if self.axis < 0 or self.axis >= len(reference_shape): return False
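The proposed axis-range check can be sketched as a standalone function (the real logic lives as a method on the Concat operator class; `validate_concat_shapes` is a hypothetical name):

```python
def validate_concat_shapes(input_shapes, axis):
    """Return True iff the shapes can be concatenated along `axis`."""
    if len(input_shapes) < 2:
        return False

    reference_shape = list(input_shapes[0])
    # Reject out-of-range axes up front. (Negative axes could instead be
    # normalized with `axis % len(reference_shape)` if they should be allowed.)
    if axis < 0 or axis >= len(reference_shape):
        return False

    for shape in input_shapes[1:]:
        if len(shape) != len(reference_shape):
            return False
        for i, (dim1, dim2) in enumerate(zip(reference_shape, shape)):
            # All dims except the concat axis must match exactly.
            if i != axis and dim1 != dim2:
                return False
    return True

assert validate_concat_shapes([(1, 2, 3), (1, 4, 3)], axis=1)
assert not validate_concat_shapes([(1, 2, 3), (1, 4, 3)], axis=5)   # out of range
assert not validate_concat_shapes([(1, 2), (1, 2, 3)], axis=0)      # rank mismatch
```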
