[DRAFT] Add Adam Support for Generic and Siracusa Platforms #169
base: devel
@@ -2697,6 +2697,51 @@ def parseNodeCtxt(self,
         return ctxt, True


+class AdamParser(NodeParser):
+
+    def __init__(self):
+        super().__init__()
+
+    def parseNode(self, node: gs.Node) -> bool:
+        n_inputs = len(node.inputs)
+        n_outputs = len(node.outputs)
+        num_tensors = (n_inputs - 2) // 4
+        valid_inputs = n_inputs >= 6 and (n_inputs - 2) % 4 == 0
+        valid_outputs = n_outputs >= 1 and n_outputs == num_tensors
+        valid_attrs = all(a in node.attrs for a in ['alpha', 'beta', 'epsilon', 'norm_coefficient', 'norm_coefficient_post'])
+
+        return all([valid_inputs, valid_outputs, valid_attrs])
+
Comment on lines +2705 to +2714
Contributor
Line 2709-Line 2710 allow multi-group Adam signatures, while the current bindings/templates support exactly one tensor group.
Safe fix for current implementation scope:
     def parseNode(self, node: gs.Node) -> bool:
         n_inputs = len(node.inputs)
         n_outputs = len(node.outputs)
-        num_tensors = (n_inputs - 2) // 4
-        valid_inputs = n_inputs >= 6 and (n_inputs - 2) % 4 == 0
-        valid_outputs = n_outputs >= 1 and n_outputs == num_tensors
+        # Current bindings/templates support exactly one tensor group:
+        # inputs: R, T, X, G, V, H
+        # outputs: X_new
+        valid_inputs = (n_inputs == 6)
+        valid_outputs = (n_outputs == 1)
         valid_attrs = all(a in node.attrs for a in ['alpha', 'beta', 'epsilon', 'norm_coefficient', 'norm_coefficient_post'])
         return all([valid_inputs, valid_outputs, valid_attrs])
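The difference between the two input/output checks in this comment can be illustrated with a small standalone sketch. The helper names `accepts_multi_group` and `accepts_single_group` are hypothetical and exist only for this illustration; they copy the count arithmetic from the diff and from the suggested fix, without the attribute check.

```python
# Hypothetical helpers for illustration only; they are not part of the PR.

def accepts_multi_group(n_inputs: int, n_outputs: int) -> bool:
    """Check as written in the diff: 2 + 4*k inputs and k outputs, k >= 1."""
    num_tensors = (n_inputs - 2) // 4
    valid_inputs = n_inputs >= 6 and (n_inputs - 2) % 4 == 0
    valid_outputs = n_outputs >= 1 and n_outputs == num_tensors
    return valid_inputs and valid_outputs


def accepts_single_group(n_inputs: int, n_outputs: int) -> bool:
    """Check from the suggested fix: exactly R, T, X, G, V, H in, X_new out."""
    return n_inputs == 6 and n_outputs == 1


if __name__ == "__main__":
    # (n_inputs, n_outputs) pairs: one group, two groups, three groups, malformed
    for n_in, n_out in [(6, 1), (10, 2), (14, 3), (7, 1)]:
        print(n_in, n_out,
              accepts_multi_group(n_in, n_out),
              accepts_single_group(n_in, n_out))
```

A node with 10 inputs and 2 outputs passes the first check, yet parseNodeCtxt in this diff only looks up inputs 0 to 5 and output 0, which is the mismatch the stricter check avoids.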
+    def parseNodeCtxt(self,
+                      ctxt: NetworkContext,
+                      node: gs.Node,
+                      channels_first: bool = True) -> Tuple[NetworkContext, bool]:
Comment on lines +2715 to +2718
Contributor
Resolve ARG002 for channels_first. Line 2718 introduces an unused argument warning in Ruff.
Minimal lint fix:
     def parseNodeCtxt(self,
                       ctxt: NetworkContext,
                       node: gs.Node,
                       channels_first: bool = True) -> Tuple[NetworkContext, bool]:
+        _ = channels_first
Tools: Ruff (0.15.2) [warning] 2718-2718: Unused method argument: channels_first (ARG002)
+
+        R = ctxt.lookup(node.inputs[0].name)
+        T = ctxt.lookup(node.inputs[1].name)
+        X = ctxt.lookup(node.inputs[2].name)
+        G = ctxt.lookup(node.inputs[3].name)
+        V = ctxt.lookup(node.inputs[4].name)
+        H = ctxt.lookup(node.inputs[5].name)
+
+        X_new = ctxt.lookup(node.outputs[0].name)
+
+        self.operatorRepresentation['R'] = R.name
+        self.operatorRepresentation['T'] = T.name
+        self.operatorRepresentation['X'] = X.name
+        self.operatorRepresentation['G'] = G.name
+        self.operatorRepresentation['V'] = V.name
+        self.operatorRepresentation['H'] = H.name
+        self.operatorRepresentation['X_new'] = X_new.name
+        self.operatorRepresentation['size'] = np.prod(X.shape)
+        self.operatorRepresentation['alpha'] = node.attrs['alpha']
+        self.operatorRepresentation['beta'] = node.attrs['beta']
+        self.operatorRepresentation['epsilon'] = node.attrs['epsilon']
+        self.operatorRepresentation['norm_coefficient'] = node.attrs['norm_coefficient']
+        self.operatorRepresentation['norm_coefficient_post'] = node.attrs['norm_coefficient_post']
+        return ctxt, True
+
+
 class BatchNormParser(NodeParser):

     def __init__(self):
AdamLayer.computeOps undercounts ops when epsilon is non-zero.
At Line 505–507, the comment and formula assume epsilon=0. But AdamParser captures epsilon, so non-zero epsilon should add one extra add op per element (sqrt(v) + eps). This can skew op-based profiling/scheduling.
Proposed fix:
     def computeOps(self):
         size = self.mapper.parser.operatorRepresentation['size']
+        epsilon = self.mapper.parser.operatorRepresentation.get('epsilon', 0)
         # Per element:
         #   m (V) update : 2 mul + 1 add = 3 ops
         #   v (H) update : 3 mul + 1 add = 4 ops (includes G*G)
         #   weight update: 1 sqrt + 1 div +
-        #                  1 mul + 1 sub = 4 ops (epsilon=0, +eps eliminated)
-        # Total = 11 ops
-        return size * 11
+        #                  1 mul + 1 sub = 4 ops
+        #                  +1 add if epsilon != 0
+        ops_per_element = 11 + (1 if epsilon != 0 else 0)
+        return size * ops_per_element
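The per-element op count in this comment can be checked against a small NumPy sketch. This is an illustration under stated assumptions, not the PR's kernel: it follows only the op breakdown quoted above (no bias correction, no norm coefficients, T unused, and 1 - alpha and 1 - beta treated as precomputed constants). The names `adam_step`, `ops_per_element` and `compute_ops` are hypothetical.

```python
import numpy as np


def adam_step(R, X, G, V, H, alpha, beta, epsilon=0.0):
    """Simplified Adam update matching the op breakdown in the comment.

    Assumptions (not confirmed by the PR): no bias correction, no
    norm_coefficient terms, and (1 - alpha), (1 - beta) count as constants.
    """
    V_new = alpha * V + (1.0 - alpha) * G        # 2 mul + 1 add
    H_new = beta * H + (1.0 - beta) * G * G      # 3 mul + 1 add
    denom = np.sqrt(H_new)                       # 1 sqrt
    if epsilon != 0:
        denom = denom + epsilon                  # 1 extra add when epsilon != 0
    X_new = X - R * V_new / denom                # 1 mul + 1 div + 1 sub
    return X_new, V_new, H_new


def ops_per_element(epsilon: float) -> int:
    """11 ops per element, plus one add when epsilon is non-zero."""
    return 11 + (1 if epsilon != 0 else 0)


def compute_ops(size: int, epsilon: float) -> int:
    """Total op estimate for a tensor with `size` elements."""
    return int(size) * ops_per_element(epsilon)


if __name__ == "__main__":
    X = np.ones(4)
    G = np.full(4, 0.5)
    V = np.zeros(4)
    H = np.zeros(4)
    X_new, _, _ = adam_step(0.01, X, G, V, H, alpha=0.9, beta=0.999)
    print(X_new)
    print(compute_ops(X.size, 0.0), compute_ops(X.size, 1e-8))
```

With epsilon equal to zero the add in the denominator disappears, which is why the original formula uses 11 ops; any non-zero epsilon makes it 12.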