diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index db35f132e3..d3a87f3aac 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -332,35 +332,30 @@ end return augmented_rule_return_type(rct, RT.parameters[1]) end -function _annotate(@nospecialize(T)) +function _annotate(@nospecialize(T), tvars::Vector{TypeVar}) if isvarargtype(T) VA = T - T = _annotate(Core.Compiler.unwrapva(VA)) - if isdefined(VA, :N) - return Vararg{T, VA.N} - else - return Vararg{T} - end - else - return TypeVar(gensym(), Annotation{T}) + T = _annotate(Core.Compiler.unwrapva(VA), tvars) + return isdefined(VA, :N) ? Vararg{T, VA.N} : Vararg{T} end + tv = TypeVar(gensym(), Annotation{T}) + push!(tvars, tv) + return tv end function _annotate_tt(@nospecialize(TT0)) TT = Base.unwrap_unionall(TT0) - ft = TT.parameters[1] - tt = [] - for TTp in TT.parameters[2:end] - push!(tt, _annotate(Base.rewrap_unionall(TTp, TT0))) - end - return ft, tt + ft = Base.rewrap_unionall(TT.parameters[1], TT0) + tvars = TypeVar[] + tt = Any[_annotate(Base.rewrap_unionall(TTp, TT0), tvars) for TTp in TT.parameters[2:end]] + return ft, tt, tvars end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool - ft, tt = _annotate_tt(TT) - TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + ft, tt, tvars = _annotate_tt(TT) + TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}) return isapplicable(forward, TT; world, method_table, caller) end @@ -368,8 +363,8 @@ function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool - ft, tt = _annotate_tt(TT) - TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + ft, tt, tvars = _annotate_tt(TT) + TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}) return isapplicable(augmented_primal, TT; world, method_table, caller) end diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 57ec0053d1..34ba158d57 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -3,16 +3,16 @@ import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool - ft, tt = _annotate_tt(TT) - TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} + ft, tt, tvars = _annotate_tt(TT) + TT = foldr(UnionAll, tvars; init=Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...}) fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} return isapplicable(interp, forward, TT, sv, fwd_sig) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool - ft, tt = _annotate_tt(TT) - TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} + ft, tt, tvars = _annotate_tt(TT) + TT = foldr(UnionAll, tvars; init=Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...}) rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} return isapplicable(interp, augmented_primal, TT, sv, rev_sig) end