Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 14 additions & 19 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,44 +332,39 @@ end
return augmented_rule_return_type(rct, RT.parameters[1])
end

function _annotate(@nospecialize(T))
function _annotate(@nospecialize(T), tvars::Vector{TypeVar})
if isvarargtype(T)
VA = T
T = _annotate(Core.Compiler.unwrapva(VA))
if isdefined(VA, :N)
return Vararg{T, VA.N}
else
return Vararg{T}
end
else
return TypeVar(gensym(), Annotation{T})
T = _annotate(Core.Compiler.unwrapva(VA), tvars)
return isdefined(VA, :N) ? Vararg{T, VA.N} : Vararg{T}
end
tv = TypeVar(gensym(), Annotation{T})
push!(tvars, tv)
return tv
end
function _annotate_tt(@nospecialize(TT0))
TT = Base.unwrap_unionall(TT0)
ft = TT.parameters[1]
tt = []
for TTp in TT.parameters[2:end]
push!(tt, _annotate(Base.rewrap_unionall(TTp, TT0)))
end
return ft, tt
ft = Base.rewrap_unionall(TT.parameters[1], TT0)
tvars = TypeVar[]
tt = Any[_annotate(Base.rewrap_unionall(TTp, TT0), tvars) for TTp in TT.parameters[2:end]]
return ft, tt, tvars
end

function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
ft, tt, tvars = _annotate_tt(TT)
TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
return isapplicable(forward, TT; world, method_table, caller)
end

function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
ft, tt, tvars = _annotate_tt(TT)
TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...})
return isapplicable(augmented_primal, TT; world, method_table, caller)
end

Expand Down
8 changes: 4 additions & 4 deletions src/compiler/tfunc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal,

function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...}
ft, tt, tvars = _annotate_tt(TT)
TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...})
fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
return isapplicable(interp, forward, TT, sv, fwd_sig)
end

function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...}
ft, tt, tvars = _annotate_tt(TT)
TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...})
rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
return isapplicable(interp, augmented_primal, TT, sv, rev_sig)
end
Expand Down