Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 98 additions & 32 deletions src/typeutils/make_zero.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,55 +197,121 @@ end
return res
end

@inline function EnzymeCore.make_zero(
@generated function EnzymeCore.make_zero(
::Type{RT},
seen::IdDict,
prev::RT,
::Val{copy_if_inactive} = Val(false),
)::RT where {copy_if_inactive, RT}
if guaranteed_const(RT)
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
if haskey(seen, prev)
return seen[prev]
return quote
return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev
end
end
@assert !Base.isabstracttype(RT)
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)
if ismutable(prev)
y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT
seen[prev] = y
if ismutabletype(RT)
exprs = []
for i in 1:nf
if isdefined(prev, i)
xi = getfield(prev, i)
T = Core.Typeof(xi)
xi = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
if Base.isconst(RT, i)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i - 1, xi)
else
setfield!(y, i, xi)
ST = fieldtype(RT, i)
is_const = Base.isconst(RT, i)
push!(exprs, quote
if isdefined(prev, $i)
xi = getfield(prev, $i)
T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi)))
xi_zero = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
if $is_const
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, $(i - 1), xi_zero)
else
setfield!(y, $i, xi_zero)
end
end
end)
end
return quote
if haskey(seen, prev)
return seen[prev]
end
y = ccall(:jl_new_struct_uninit, Any, (Any,), RT)::RT
seen[prev] = y
$(exprs...)
return y
end
return y
end
if nf == 0
return prev
end
flds = Vector{Any}(undef, nf)
for i in 1:nf
if isdefined(prev, i)
xi = getfield(prev, i)
xi = EnzymeCore.make_zero(Core.Typeof(xi), seen, xi, Val(copy_if_inactive))
flds[i] = xi
else
if nf == 0
return quote
return prev
end
end
if nf <= 16
decls = [:(local $(Symbol("f_", i))) for i in 1:nf]
evals = []
for i in 1:nf
ST = fieldtype(RT, i)
sym = Symbol("f_", i)
push!(evals, quote
if continue_flag && isdefined(prev, $i)
xi = getfield(prev, $i)
T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi)))
$sym = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
nf_defined = $i
else
continue_flag = false
end
end)
end
branches = []
for k in 0:nf
args = [Symbol("f_", i) for i in 1:k]
push!(branches, quote
if nf_defined == $k
y = $(Expr(:new, RT, args...))
seen[prev] = y
return y
end
end)
end
return quote
if haskey(seen, prev)
return seen[prev]
end
$(decls...)
nf_defined = 0
continue_flag = true
$(evals...)
$(branches...)
error("Unreachable")
end
else
nf = i - 1 # rest of tail must be undefined values
break
exprs = []
for i in 1:nf
ST = fieldtype(RT, i)
push!(exprs, quote
if continue_flag && isdefined(prev, $i)
xi = getfield(prev, $i)
T = $(isconcretetype(ST) ? ST : :(Core.Typeof(xi)))
flds[$i] = EnzymeCore.make_zero(T, seen, xi, Val(copy_if_inactive))
nf_defined = $i
else
continue_flag = false
end
end)
end
return quote
if haskey(seen, prev)
return seen[prev]
end
flds = Vector{Any}(undef, $nf)
nf_defined = 0
continue_flag = true
$(exprs...)
y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf_defined)
seen[prev] = y
return y
end
end
end
y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), RT, flds, nf)
seen[prev] = y
return y
end

function make_zero_immutable!(prev::T, seen::S)::T where {T <: AbstractFloat, S}
Expand Down
Loading