EnzymeRules: bind free TypeVars in has_{f,r}rule_from_sig query type#3121
Open
Keno wants to merge 1 commit into
Open
EnzymeRules: bind free TypeVars in has_{f,r}rule_from_sig query type#3121Keno wants to merge 1 commit into
has_{f,r}rule_from_sig query type#3121Keno wants to merge 1 commit into
Conversation
`_annotate_tt` synthesizes a method-table query like
`Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}` from the
caller's signature. The `tt...` entries came from `_annotate`, which builds
fresh `TypeVar`s but never wraps them in a `UnionAll`; `ft` (the first
parameter of `unwrap_unionall(TT0)`) likewise could contain `TypeVar`s bound by
`TT0`'s outer `where` clauses. Both sources left the constructed query with
free `TypeVar`s, which used to be tolerated by Julia's subtype/intersection
but no longer is under JuliaLang/julia#61876 (free `TypeVar`s now compare by
identity, so e.g. `typeintersect(Tuple{…,Const{Float64}},
Tuple{…,V<:Annotation{Float64}})` collapses to `Union{}`).
`_annotate_tt` now returns the new `TypeVar`s alongside `ft, tt`, with `ft`
rewrapped in `TT0`'s `UnionAll` chain. Each `has_*_from_sig` caller wraps the
constructed query in a `UnionAll` per fresh `TypeVar` via `foldr(UnionAll, …)`.
The internal callers in `Enzyme.Compiler.tfunc` are updated to the new return
shape.
The behavior of free TypeVars in subtyping was previously undefined (in effect
a bug); JuliaLang/julia#61876 makes them well-defined singleton identities,
which is what surfaces this latent issue.
This change was AI-generated and should be reviewed carefully before merging.
Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl
index d3a87f3a..440a6d25 100644
--- a/lib/EnzymeCore/src/rules.jl
+++ b/lib/EnzymeCore/src/rules.jl
@@ -355,7 +355,7 @@ function has_frule_from_sig(@nospecialize(TT);
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt, tvars = _annotate_tt(TT)
- TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
+ TT = foldr(UnionAll, tvars; init = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
return isapplicable(forward, TT; world, method_table, caller)
end
@@ -364,7 +364,7 @@ function has_rrule_from_sig(@nospecialize(TT);
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt, tvars = _annotate_tt(TT)
- TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
+ TT = foldr(UnionAll, tvars; init = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
return isapplicable(augmented_primal, TT; world, method_table, caller)
end
diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl
index 34ba158d..982b198a 100644
--- a/src/compiler/tfunc.jl
+++ b/src/compiler/tfunc.jl
@@ -4,7 +4,7 @@ import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal,
function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool
ft, tt, tvars = _annotate_tt(TT)
- TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...})
+ TT = foldr(UnionAll, tvars; init = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
return isapplicable(interp, forward, TT, sv, fwd_sig)
end
@@ -12,7 +12,7 @@ end
function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool
ft, tt, tvars = _annotate_tt(TT)
- TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...})
+ TT = foldr(UnionAll, tvars; init = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
return isapplicable(interp, augmented_primal, TT, sv, rev_sig)
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
_annotate_ttsynthesizes a method-table query likeTuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}from the caller's signature. Thett...entries came from_annotate, which builds freshTypeVars but never wraps them in aUnionAll;ft(the first parameter ofunwrap_unionall(TT0)) likewise could containTypeVars bound byTT0's outerwhereclauses. Both sources left the constructed query with freeTypeVars, which used to be tolerated by Julia's subtype/intersection but no longer is under JuliaLang/julia#61876 (freeTypeVars now compare by identity, so e.g.typeintersect(Tuple{…,Const{Float64}}, Tuple{…,V<:Annotation{Float64}})collapses toUnion{})._annotate_ttnow returns the newTypeVars alongsideft, tt, withftrewrapped inTT0'sUnionAllchain. Eachhas_*_from_sigcaller wraps the constructed query in aUnionAllper freshTypeVarviafoldr(UnionAll, …). The internal callers inEnzyme.Compiler.tfuncare updated to the new return shape.The behavior of free TypeVars in subtyping was previously undefined (in effect a bug); JuliaLang/julia#61876 makes them well-defined singleton identities, which is what surfaces this latent issue.
This change was AI-generated and should be reviewed carefully before merging.