Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions Deeploy/Targets/GAP9/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,23 @@
# Import templates from PULPOpen and Generic
from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, DequantTemplate, FloatReduceMeanTemplate, \
FloatReduceSumTemplate, GatherTemplate, QuantTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DequantChecker, \
GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, MulChecker, \
QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, SGDChecker, \
SliceChecker, SoftmaxChecker, SoftmaxCrossEntropyLossChecker, TransposeChecker
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, AdamChecker, ConcatChecker, ConvChecker, \
DequantChecker, GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, \
MulChecker, QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, \
SGDChecker, SliceChecker, SoftmaxChecker, SoftmaxCrossEntropyLossChecker, TransposeChecker
from Deeploy.Targets.PULPOpen.Bindings import ForkClosure, L3MemoryAwareFunctionCallClosure, \
MemoryAwareForkTransformer, MemoryAwareFunctionCallClosure, TilingCallClosure
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterSynch import PULPSynchCoresPass
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterTiling import PULPClusterTiling
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPL3Tiling import PULPL3Tiling
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPProfileUntiled import PULPProfileUntiled
from Deeploy.Targets.PULPOpen.DataTypes import PULPDMAFuture
from Deeploy.Targets.PULPOpen.Templates import ConvTemplate, DMASliceTemplate, FloatAddTemplate, FloatConvTemplate, \
FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, \
FloatMulTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, MatrixVectorTemplate, MaxPoolTemplate, \
MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, RQAddTemplate, RQSiHardswishTemplate, \
SGDTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \
iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.PULPOpen.Templates import ConvTemplate, DMASliceTemplate, FloatAddTemplate, FloatAdamTemplate, \
FloatConvTemplate, FloatGELUTemplate, FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, \
FloatMaxPoolTemplate, FloatMulTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, \
MatrixVectorTemplate, MaxPoolTemplate, MulTemplate, ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, \
RQAddTemplate, RQSiHardswishTemplate, SGDTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, \
TransposeTemplate, UniformRequantShiftTemplate, iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \
PULPRequantShiftChecker
from Deeploy.TilingExtension.CodeTransformationPasses.TilingVariableReplacement import TilingVariableReplacement, \
Expand Down Expand Up @@ -317,6 +317,17 @@
SGDTemplate.referenceTemplate, GAP9Transformer)
]

GAP9AdamBindings = [
NodeBinding(
AdamChecker(
[PointerClass(float32_t), PointerClass(int32_t),
PointerClass(float32_t), PointerClass(float32_t),
PointerClass(float32_t), PointerClass(float32_t)], # R, T, X, G, V, H
[PointerClass(float32_t)] # X_new
),
FloatAdamTemplate.referenceTemplate, GAP9Transformer)
]

GAP9TransposeBindings = [
NodeBinding(TransposeChecker([PointerClass(type)], [PointerClass(type)]), TransposeTemplate.referenceTemplate,
GAP9Transformer) for type in IntegerDataTypes
Expand Down
47 changes: 25 additions & 22 deletions Deeploy/Targets/GAP9/Platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,29 @@
from Deeploy.MemoryLevelExtension.NetworkDeployers.MemoryLevelDeployer import MemoryPlatform, MemoryPlatformWrapper
from Deeploy.Targets.GAP9.Templates import AllocateTemplate, FreeTemplate
# Import GAP9-specific tiler bindings
from Deeploy.Targets.GAP9.Tiler import GAP9AddTilingReadyBindings, GAP9ConcatTilingReadyBindings, \
GAP9Conv2DTilingReadyBindings, GAP9DWConv2DTilingReadyBindings, GAP9FlattenTilingReadyBindings, \
GAP9FPGELUTilingReadyBindings, GAP9FPGEMMTilingReadyBindings, GAP9GatherTilingReadyBindings, \
GAP9iHardswishTilingReadyBindings, GAP9iRMSNormTilingReadyBindings, GAP9iRQSGELUTilingReadyBindings, \
GAP9LayernormTilingReadyBindings, GAP9MatMulTilingReadyBindings, GAP9MaxPool2DTilingReadyBindings, \
GAP9MulTilingReadyBindings, GAP9ReduceSumTilingReadyBindings, GAP9ReluTilingReadyBindings, \
GAP9RQAddTilingReadyBindings, GAP9RQSConv2DTilingReadyBindings, GAP9RQSDWConv2DTilingReadyBindings, \
GAP9RQSGEMMTilingReadyBindings, GAP9RQSiHardswishTilingReadyBindings, GAP9RQSMatrixVecTilingReadyBindings, \
GAP9RQSTallGEMMTilingReadyBindings, GAP9RQSTilingReadyBindings, GAP9SGDTilingReadyBindings, \
GAP9SoftmaxCrossEntropyGradTilingReadyBindings, GAP9SoftmaxCrossEntropyTilingReadyBindings, \
GAP9SoftmaxGradTilingReadyBindings, GAP9SoftmaxTilingReadyBindings, GAP9TransposeTilingReadyBindings, \
GAP9UniformRQSTilingReadyBindings
from Deeploy.Targets.GAP9.Tiler import GAP9AdamTilingReadyBindings, GAP9AddTilingReadyBindings, \
GAP9ConcatTilingReadyBindings, GAP9Conv2DTilingReadyBindings, GAP9DWConv2DTilingReadyBindings, \
GAP9FlattenTilingReadyBindings, GAP9FPGELUTilingReadyBindings, GAP9FPGEMMTilingReadyBindings, \
GAP9GatherTilingReadyBindings, GAP9iHardswishTilingReadyBindings, GAP9iRMSNormTilingReadyBindings, \
GAP9iRQSGELUTilingReadyBindings, GAP9LayernormTilingReadyBindings, GAP9MatMulTilingReadyBindings, \
GAP9MaxPool2DTilingReadyBindings, GAP9MulTilingReadyBindings, GAP9ReduceSumTilingReadyBindings, \
GAP9ReluTilingReadyBindings, GAP9RQAddTilingReadyBindings, GAP9RQSConv2DTilingReadyBindings, \
GAP9RQSDWConv2DTilingReadyBindings, GAP9RQSGEMMTilingReadyBindings, GAP9RQSiHardswishTilingReadyBindings, \
GAP9RQSMatrixVecTilingReadyBindings, GAP9RQSTallGEMMTilingReadyBindings, GAP9RQSTilingReadyBindings, \
GAP9SGDTilingReadyBindings, GAP9SoftmaxCrossEntropyGradTilingReadyBindings, \
GAP9SoftmaxCrossEntropyTilingReadyBindings, GAP9SoftmaxGradTilingReadyBindings, \
GAP9SoftmaxTilingReadyBindings, GAP9TransposeTilingReadyBindings, GAP9UniformRQSTilingReadyBindings
from Deeploy.Targets.Generic.Bindings import BasicGEMMBindings, BasicPad1DBindings, BasicPad2DBindings, \
BasicRQIntegerDivBinding
from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELULayer, GEMMLayer, \
LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \
ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, RQSiHardswishLayer, SGDLayer, \
SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, SoftmaxGradLayer, SoftmaxLayer, \
TransposeLayer, iHardswishLayer, iRMSNormLayer
from Deeploy.Targets.Generic.Parsers import AddParser, ConcatParser, DequantParser, FlattenParser, GatherParser, \
GELUParser, GEMMParser, LayerNormParser, MatMulParser, MaxPool2DParser, MulParser, Pad1DParser, Pad2DParser, \
QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, RQAddParser, \
RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, SGDParser, SliceParser, \
from Deeploy.Targets.Generic.Layers import AdamLayer, AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELULayer, \
GEMMLayer, LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, \
ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, \
RQSiHardswishLayer, SGDLayer, SliceLayer, SoftmaxCrossEntropyLossGradLayer, SoftmaxCrossEntropyLossLayer, \
SoftmaxGradLayer, SoftmaxLayer, TransposeLayer, iHardswishLayer, iRMSNormLayer
from Deeploy.Targets.Generic.Parsers import AdamParser, AddParser, ConcatParser, DequantParser, FlattenParser, \
GatherParser, GELUParser, GEMMParser, LayerNormParser, MatMulParser, MaxPool2DParser, MulParser, Pad1DParser, \
Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, ReshapeParser, \
RQAddParser, RQIntegerDivParser, RQSiGELUParser, RQSiHardswishParser, SGDParser, SliceParser, \
SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, SoftmaxParser, \
TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser
from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate
Expand Down Expand Up @@ -90,6 +90,7 @@
GAP9_SoftmaxCrossEntropyLossGradMapper = NodeMapper(SoftmaxCrossEntropyLossGradParser(),
GAP9SoftmaxCrossEntropyGradTilingReadyBindings)
GAP9_SGDMapper = NodeMapper(SGDParser(), GAP9SGDTilingReadyBindings)
GAP9_AdamMapper = NodeMapper(AdamParser(), GAP9AdamTilingReadyBindings)
GAP9_QuantMapper = NodeMapper(QuantParser(), BasicQuantBindings)
GAP9_DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings)
GAP9_GEMMDequantMapper = NodeMapper(PULPGEMMParser(), BasicGEMMBindings)
Expand Down Expand Up @@ -171,7 +172,9 @@
'SoftmaxCrossEntropyLossGrad':
SoftmaxCrossEntropyLossGradLayer([GAP9_SoftmaxCrossEntropyLossGradMapper]),
'SGD':
SGDLayer([GAP9_SGDMapper])
SGDLayer([GAP9_SGDMapper]),
'Adam':
AdamLayer([GAP9_AdamMapper])
}


Expand Down
20 changes: 12 additions & 8 deletions Deeploy/Targets/GAP9/Tiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

import copy

from Deeploy.Targets.GAP9.Bindings import GAP9AddBindings, GAP9ConcatBindings, GAP9FloatConv2DBindings, \
GAP9FloatDWConv2DBindings, GAP9FloatGELUBinding, GAP9FloatGEMMBindings, GAP9GatherBindings, \
GAP9iHardswishBindings, GAP9iRMSNormBindings, GAP9iRQSGELUBindings, GAP9LayernormBinding, GAP9MatMulBindings, \
GAP9MaxPool2DBindings, GAP9MulBindings, GAP9ReduceSumBindings, GAP9ReluBinding, GAP9ReshapeBindings, \
GAP9RQAddBindings, GAP9RQSBindings, GAP9RQSConv2DBindings, GAP9RQSDWConv2DBindings, GAP9RQSGEMMBindings, \
GAP9RQSiHardswishBindings, GAP9RQSMatrixVecBindings, GAP9RQSTallGEMMBindings, GAP9SGDBindings, \
GAP9SoftmaxBindings, GAP9SoftmaxCrossEntropyLossBindings, GAP9SoftmaxCrossEntropyLossGradBindings, \
GAP9SoftmaxGradBindings, GAP9TransposeBindings, GAP9UniformRQSBindings
from Deeploy.Targets.GAP9.Bindings import GAP9AdamBindings, GAP9AddBindings, GAP9ConcatBindings, \
GAP9FloatConv2DBindings, GAP9FloatDWConv2DBindings, GAP9FloatGELUBinding, GAP9FloatGEMMBindings, \
GAP9GatherBindings, GAP9iHardswishBindings, GAP9iRMSNormBindings, GAP9iRQSGELUBindings, GAP9LayernormBinding, \
GAP9MatMulBindings, GAP9MaxPool2DBindings, GAP9MulBindings, GAP9ReduceSumBindings, GAP9ReluBinding, \
GAP9ReshapeBindings, GAP9RQAddBindings, GAP9RQSBindings, GAP9RQSConv2DBindings, GAP9RQSDWConv2DBindings, \
GAP9RQSGEMMBindings, GAP9RQSiHardswishBindings, GAP9RQSMatrixVecBindings, GAP9RQSTallGEMMBindings, \
GAP9SGDBindings, GAP9SoftmaxBindings, GAP9SoftmaxCrossEntropyLossBindings, \
GAP9SoftmaxCrossEntropyLossGradBindings, GAP9SoftmaxGradBindings, GAP9TransposeBindings, GAP9UniformRQSBindings
from Deeploy.Targets.Generic.TileConstraints.AddTileConstraint import AddTileConstraint
from Deeploy.Targets.Generic.TileConstraints.ConcatTileConstraint import ConcatTileConstraint
from Deeploy.Targets.Generic.TileConstraints.iHardswishTileConstraint import iHardswishTileConstraint
Expand All @@ -39,6 +39,7 @@
from Deeploy.Targets.PULPOpen.TileConstraints.MatMulTileConstraint import MatMulTileConstraint
from Deeploy.Targets.PULPOpen.TileConstraints.MaxPoolTileConstraint import MaxPoolCTileConstraint
from Deeploy.Targets.PULPOpen.TileConstraints.RequantShiftTileConstraint import RequantShiftTileConstraint
from Deeploy.Targets.PULPOpen.TileConstraints.AdamTileConstraint import AdamTileConstraint
from Deeploy.Targets.PULPOpen.TileConstraints.SGDTileConstraint import SGDTileConstraint
from Deeploy.Targets.PULPOpen.TileConstraints.SoftmaxCrossEntropyTileConstraint import \
SoftmaxCrossEntropyGradTileConstraint, SoftmaxCrossEntropyTileConstraint
Expand Down Expand Up @@ -142,3 +143,6 @@

GAP9SGDTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9SGDBindings,
tileConstraint = SGDTileConstraint())

GAP9AdamTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9AdamBindings,
tileConstraint = AdamTileConstraint())
33 changes: 22 additions & 11 deletions Deeploy/Targets/Generic/Bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@
from Deeploy.DeeployTypes import CodeTransformation, NodeBinding
from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration
from Deeploy.Targets.Generic.Templates import AddTemplate, BatchNormalizationTemplate, ConcatTemplate, ConvTemplate, \
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \
FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \
FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \
FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, \
GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \
MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAdamTemplate, \
FloatAddTemplate, FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, \
FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, \
FloatPadTemplate, FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, \
FloatSqrtTemplate, GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, \
MatMulTemplate, MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \
RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \
iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \
DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \
LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \
ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, SliceChecker, \
SoftmaxChecker, TransposeChecker
from Deeploy.Targets.Generic.TypeCheckers import AdamChecker, AddChecker, BatchNormChecker, ConcatChecker, \
ConvChecker, DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, \
GEMMChecker, LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, \
ReduceMeanChecker, ReduceSumChecker, ReluChecker, RequantShiftChecker, ReshapeChecker, RQIntegerDivChecker, \
SliceChecker, SoftmaxChecker, TransposeChecker

BasicTransformer = CodeTransformation([ArgumentStructGeneration(), MemoryManagementGeneration(), FutureGeneration()])

Expand Down Expand Up @@ -312,6 +312,17 @@
for type in FloatDataTypes
]

BasicAdamBindings = [
NodeBinding(
AdamChecker(
# Note: ONNX spec defines T as int64, but we use int32 for embedded compatibility
[PointerClass(float32_t), PointerClass(int32_t), PointerClass(float32_t), PointerClass(float32_t),
PointerClass(float32_t), PointerClass(float32_t)], # R, T, X, G, V, H
[PointerClass(float32_t)] # X_new only
),
FloatAdamTemplate.referenceTemplate, BasicTransformer)
]

BasicConvTransposeBindings = [
NodeBinding(
ConvChecker(
Expand Down
16 changes: 16 additions & 0 deletions Deeploy/Targets/Generic/Layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,22 @@ def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)


class AdamLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)

def computeOps(self):
size = self.mapper.parser.operatorRepresentation['size']
# Per element:
# m (V) update : 2 mul + 1 add = 3 ops
# v (H) update : 3 mul + 1 add = 4 ops (includes G*G)
# weight update: 1 sqrt + 1 div +
# 1 mul + 1 sub = 4 ops (epsilon=0, +eps eliminated)
# Total = 11 ops
return size * 11
Comment on lines +500 to +508
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

AdamLayer.computeOps undercounts ops when epsilon is non-zero.

At Line 505–507, the comment and formula assume epsilon=0. But AdamParser captures epsilon, so non-zero epsilon should add one extra add op per element (sqrt(v) + eps). This can skew op-based profiling/scheduling.

Proposed fix
     def computeOps(self):
         size = self.mapper.parser.operatorRepresentation['size']
+        epsilon = self.mapper.parser.operatorRepresentation.get('epsilon', 0)
         # Per element:
         #   m (V) update : 2 mul + 1 add          = 3 ops
         #   v (H) update : 3 mul + 1 add          = 4 ops  (includes G*G)
         #   weight update: 1 sqrt + 1 div +
-        #                  1 mul + 1 sub           = 4 ops  (epsilon=0, +eps eliminated)
-        #   Total                                  = 11 ops
-        return size * 11
+        #                  1 mul + 1 sub           = 4 ops
+        #                  +1 add if epsilon != 0
+        ops_per_element = 11 + (1 if epsilon != 0 else 0)
+        return size * ops_per_element
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Deeploy/Targets/Generic/Layers.py` around lines 500 - 508, The computeOps
method in AdamLayer currently assumes epsilon=0 and returns size * 11; update
computeOps to read epsilon from self.mapper.parser.operatorRepresentation (or
self.mapper.parser.operatorRepresentation['epsilon']) and add one extra add op
per element when epsilon is non-zero, adjusting the comment and final return to
size * (11 + 1_if_epsilon_nonzero) so the op count reflects the additional
sqrt(v)+eps add when epsilon != 0.



class LinearAttentionLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
Expand Down
45 changes: 45 additions & 0 deletions Deeploy/Targets/Generic/Parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2697,6 +2697,51 @@ def parseNodeCtxt(self,
return ctxt, True


class AdamParser(NodeParser):

def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
n_inputs = len(node.inputs)
n_outputs = len(node.outputs)
num_tensors = (n_inputs - 2) // 4
valid_inputs = n_inputs >= 6 and (n_inputs - 2) % 4 == 0
valid_outputs = n_outputs >= 1 and n_outputs == num_tensors
valid_attrs = all(a in node.attrs for a in ['alpha', 'beta', 'epsilon', 'norm_coefficient', 'norm_coefficient_post'])

return all([valid_inputs, valid_outputs, valid_attrs])

Comment on lines +2705 to +2714
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

AdamParser.parseNode accepts forms that parseNodeCtxt cannot represent.

Line 2709-Line 2710 allow multi-group Adam signatures (n_inputs > 6, n_outputs > 1), but parseNodeCtxt (Line 2720-Line 2735) only binds the first R/T/X/G/V/H group and first X_new. This can silently drop extra groups.

🛠️ Safe fix for current implementation scope
     def parseNode(self, node: gs.Node) -> bool:
         n_inputs = len(node.inputs)
         n_outputs = len(node.outputs)
-        num_tensors = (n_inputs - 2) // 4
-        valid_inputs = n_inputs >= 6 and (n_inputs - 2) % 4 == 0
-        valid_outputs = n_outputs >= 1 and n_outputs == num_tensors
+        # Current bindings/templates support exactly one tensor group:
+        # inputs:  R, T, X, G, V, H
+        # outputs: X_new
+        valid_inputs = (n_inputs == 6)
+        valid_outputs = (n_outputs == 1)
         valid_attrs = all(a in node.attrs for a in ['alpha', 'beta', 'epsilon', 'norm_coefficient', 'norm_coefficient_post'])

         return all([valid_inputs, valid_outputs, valid_attrs])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Deeploy/Targets/Generic/Parsers.py` around lines 2705 - 2714, parseNode
currently accepts multi-group Adam signatures (n_inputs > 6 or n_outputs > 1)
while parseNodeCtxt only binds the first R/T/X/G/V/H group and first X_new,
silently dropping extras; fix by restricting parseNode to only allow the
single-group form that parseNodeCtxt handles—change the validity checks in
AdamParser.parseNode so valid_inputs requires exactly 6 inputs (or num_tensors
== 1) and valid_outputs requires exactly 1 output, and keep references to
parseNode and parseNodeCtxt so reviewers can verify both now match.

def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
Comment on lines +2715 to +2718
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Resolve ARG002 for channels_first in AdamParser.parseNodeCtxt.

Line 2718 introduces an unused argument warning in Ruff.

🔧 Minimal lint fix
     def parseNodeCtxt(self,
                       ctxt: NetworkContext,
                       node: gs.Node,
                       channels_first: bool = True) -> Tuple[NetworkContext, bool]:
+        _ = channels_first
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
_ = channels_first
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 2718-2718: Unused method argument: channels_first

(ARG002)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@Deeploy/Targets/Generic/Parsers.py` around lines 2715 - 2718, The function
AdamParser.parseNodeCtxt currently has an unused parameter channels_first
causing ARG002; either remove the parameter if callers never pass it, or rename
it to _channels_first (or prefix with an underscore) to mark it intentionally
unused. Update the signature in AdamParser.parseNodeCtxt and any internal
references/call sites accordingly (or drop the argument from callers if you
remove it) so the linter no longer reports an unused parameter.


R = ctxt.lookup(node.inputs[0].name)
T = ctxt.lookup(node.inputs[1].name)
X = ctxt.lookup(node.inputs[2].name)
G = ctxt.lookup(node.inputs[3].name)
V = ctxt.lookup(node.inputs[4].name)
H = ctxt.lookup(node.inputs[5].name)

X_new = ctxt.lookup(node.outputs[0].name)

self.operatorRepresentation['R'] = R.name
self.operatorRepresentation['T'] = T.name
self.operatorRepresentation['X'] = X.name
self.operatorRepresentation['G'] = G.name
self.operatorRepresentation['V'] = V.name
self.operatorRepresentation['H'] = H.name
self.operatorRepresentation['X_new'] = X_new.name
self.operatorRepresentation['size'] = np.prod(X.shape)
self.operatorRepresentation['alpha'] = node.attrs['alpha']
self.operatorRepresentation['beta'] = node.attrs['beta']
self.operatorRepresentation['epsilon'] = node.attrs['epsilon']
self.operatorRepresentation['norm_coefficient'] = node.attrs['norm_coefficient']
self.operatorRepresentation['norm_coefficient_post'] = node.attrs['norm_coefficient_post']
return ctxt, True


class BatchNormParser(NodeParser):

def __init__(self):
Expand Down
Loading
Loading