# Patch notes (free-form preamble ahead of the `diff --git` header; not part of any hunk).
#
# Content: a git unified diff for src/typeutils/make_zero.jl with a single hunk
# (old lines 197-251 -> new lines 197-317, per `@@ -197,55 +197,121 @@`).
# NOTE(review): the hunk's line breaks appear to have been collapsed into spaces, so the
# context/`-`/`+` markers now sit inline; presumably the line structure has to be
# restored before this can be applied as a patch.
#
# What the hunk changes: the method
# `EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive})`
# goes from `@inline function` to `@generated function`. The removed body walked the
# fields of `prev` at run time; the added body inspects `RT` while generating code and
# returns one of these quoted bodies:
#   * `guaranteed_const(RT)`: return `prev`, or `Base.deepcopy_internal(prev, seen)`
#     when `copy_if_inactive` is true. As in the removed code, `seen` is not consulted.
#   * `ismutabletype(RT)`: return `seen[prev]` if present; otherwise allocate `y` with
#     `jl_new_struct_uninit`, store `seen[prev] = y` before the field statements run,
#     then execute one unrolled statement per field. Undefined fields are skipped;
#     fields for which `Base.isconst(RT, i)` holds are written with `jl_set_nth_field`,
#     the others with `setfield!`.
#   * immutable, `nf == 0`: return `prev`.
#   * immutable, `nf <= 16`: zero each field into a local `f_i` until the first
#     undefined field (tracked by `continue_flag`), then build the result with
#     `Expr(:new, RT, f_1, ..., f_k)` in the branch whose `k` equals `nf_defined`;
#     `error("Unreachable")` follows the branches.
#   * immutable, `nf > 16`: same field walk, but into a `Vector{Any}` that is passed to
#     `jl_new_structv` together with `nf_defined`.
# In every per-field statement the recursive call uses `fieldtype(RT, i)` when that
# type is concrete and `Core.Typeof(xi)` otherwise.
#
# NOTE(review): the removed code checked `haskey(seen, prev)` before the `nf == 0`
# return; the added code returns `prev` for field-less immutables without that lookup.
# Confirm no caller relies on a `seen` entry for such values.
# NOTE(review): `Core.Typeof(x)` gives `Type{x}` when `x` is itself a type, so for a
# concrete field type such as `DataType` the statically chosen `T` differs from what
# the removed code passed. Presumably harmless if such types are `guaranteed_const` --
# verify.
# NOTE(review): `guaranteed_const(RT)` and the two `@assert`s now run inside the
# generator instead of the function body; confirm `guaranteed_const` is safe to call
# from a generated function.
#
# The hunk's leading context (`return res` / `end`) and trailing context (the
# `make_zero_immutable!` signature) belong to neighbouring definitions that are not
# shown in full here.
diff --git a/src/typeutils/make_zero.jl b/src/typeutils/make_zero.jl index 75c497eb7f..c72a05f6cc 100644 --- a/src/typeutils/make_zero.jl +++ b/src/typeutils/make_zero.jl @@ -197,55 +197,121 @@ end return res end -@inline function EnzymeCore.make_zero( +@generated function EnzymeCore.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive} = Val(false), )::RT where {copy_if_inactive, RT} if guaranteed_const(RT) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - if haskey(seen, prev) - return seen[prev] + return quote + return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev + end end @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) - if ismutable(prev) - y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT - seen[prev] = y + if ismutabletype(RT) + exprs = [] for i in 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - T = Core.Typeof(xi) - xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) - if Base.isconst(RT, i) - ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, xi) - else - setfield!(y, i, xi) + ST = fieldtype(RT, i) + is_const = Base.isconst(RT, i) + push!(exprs, quote + if isdefined(prev, $i) + xi = getfield(prev, $i) + T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi))) + xi_zero = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) + if $is_const + ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, $(i - 1), xi_zero) + else + setfield!(y, $i, xi_zero) + end end + end) + end + return quote + if haskey(seen, prev) + return seen[prev] end + y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT + seen[prev] = y + $(exprs...) 
+ return y end - return y - end - if nf == 0 - return prev - end - flds = Vector{Any}(undef, nf) - for i in 1:nf - if isdefined(prev, i) - xi = getfield(prev, i) - xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive)) - flds[i] = xi + else + if nf == 0 + return quote + return prev + end + end + if nf <= 16 + decls = [:(local $(Symbol("f_", i))) for i in 1:nf] + evals = [] + for i in 1:nf + ST = fieldtype(RT, i) + sym = Symbol("f_", i) + push!(evals, quote + if continue_flag && isdefined(prev, $i) + xi = getfield(prev, $i) + T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi))) + $sym = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) + nf_defined = $i + else + continue_flag = false + end + end) + end + branches = [] + for k in 0:nf + args = [Symbol("f_", i) for i in 1:k] + push!(branches, quote + if nf_defined == $k + y = $(Expr(:new, RT, args...)) + seen[prev] = y + return y + end + end) + end + return quote + if haskey(seen, prev) + return seen[prev] + end + $(decls...) + nf_defined = 0 + continue_flag = true + $(evals...) + $(branches...) + error("Unreachable") + end else - nf = i - 1 # rest of tail must be undefined values - break + exprs = [] + for i in 1:nf + ST = fieldtype(RT, i) + push!(exprs, quote + if continue_flag && isdefined(prev, $i) + xi = getfield(prev, $i) + T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi))) + flds[$i] = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive)) + nf_defined = $i + else + continue_flag = false + end + end) + end + return quote + if haskey(seen, prev) + return seen[prev] + end + flds = Vector{Any}(undef, $nf) + nf_defined = 0 + continue_flag = true + $(exprs...) + y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf_defined) + seen[prev] = y + return y + end end end - y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf) - seen[prev] = y - return y end function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S}