Skip to content

mlp fused kernel + compiler improvements#4

Merged
AndreSlavescu merged 4 commits intomainfrom
kernels
Mar 18, 2026
Merged

mlp fused kernel + compiler improvements#4
AndreSlavescu merged 4 commits intomainfrom
kernels

Conversation

@AndreSlavescu
Copy link
Copy Markdown
Owner

No description provided.

@AndreSlavescu
Copy link
Copy Markdown
Owner Author

@codex review

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d05280b4c5

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +944 to +947
if const_val is not None:
epilogue.append(("binop", op.op, "rhs", const_val))
else:
epilogue.append(("scale",))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Don't coerce scalar add/sub/div epilogues into scale

When a post-GEMM tile/scalar BinOp has a non-constant RHS, _detect_epilogue() now always records it as ("scale",) regardless of op.op. In this compiler, everything between the dot loop and tile_store is realized only through the fused epilogue path, so a kernel such as acc = acc + alpha or acc = acc / alpha will be emitted as _v *= _scale in msl_emitter.py instead of preserving the requested operator. That turns valid fused-bias or normalization epilogues into either wrong math or a Metal compile failure as soon as the scalar comes from a kernel argument rather than a literal.

Useful? React with 👍 / 👎.

Comment thread metile/codegen/msl_emitter.py Outdated
Comment on lines +80 to +82
elif epi[0] == "binop":
_, op_name, const_side, const_val = epi
sym = _BINOP_SYMBOLS_EPILOGUE.get(op_name, "+")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle max/min epilogue binops instead of defaulting to '+'

The new epilogue emitter records arbitrary tile/constant BinOps, but _emit_epilogue_chain() only maps add/sub/mul/div and falls back to '+' for anything else. That means a fused GEMM epilogue written with the public metile.maximum()/minimum() helpers will now be silently emitted as addition (for example, maximum(acc, 0.0) becomes acc + 0.0f) because the post-loop ops are no longer lowered any other way. This is silent numerical corruption for any fused epilogue that uses max, min, or another non-arithmetic binop.

Useful? React with 👍 / 👎.

@AndreSlavescu AndreSlavescu merged commit 65c6976 into main Mar 18, 2026
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant