From 642f87f8dc56724d438eab11685b216f330eefc6 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 3 Dec 2024 08:22:00 -0600 Subject: [PATCH 1/7] Fix higher order codegen --- src/compiler.jl | 26 ++- src/compiler/interpreter.jl | 25 +-- src/compiler/validation.jl | 368 +++--------------------------------- src/llvm/transforms.jl | 188 ++++++++++++++++++ src/rules/parallelrules.jl | 4 +- 5 files changed, 238 insertions(+), 373 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 5fcc53dbde..fed1f54eb3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5223,12 +5223,12 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType, + TapeType ) end @@ -5284,7 +5284,12 @@ function _thunk(job, postopt::Bool = true) end # Run post optimization pipeline - if postopt + prepost = if postopt + mstr = if job.config.params.ABI <: InlineABI + "" + else + string(mod) + end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -5293,12 +5298,17 @@ function _thunk(job, postopt::Bool = true) else propagate_returned!(mod) end + mstr + else + "" end - return (mod, adjoint_name, primal_name, meta.TapeType) + return (mod, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() + const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult key = hash(job) @@ -5310,9 +5320,15 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) + if obj.adjoint isa Ptr{Nothing} + autodiff_cache[obj.adjoint] = (asm[2], asm[1]) + end + if obj.primal isa Ptr{Nothing} + autodiff_cache[obj.primal] = (asm[3], asm[1]) + end cache[key] = obj end - obj + nothing finally unlock(cache_lock) end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..2f9d1fbf60 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -44,7 +44,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool - deferred_lower::Bool broadcast_rewrite::Bool handler::T end @@ -55,7 +54,6 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing ) @@ -83,7 +81,6 @@ function EnzymeInterpreter( IdDict{Any, Bool}(), forward_rules, reverse_rules, - deferred_lower, broadcast_rewrite, handler ) @@ -94,10 +91,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler) Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -865,25 +861,6 @@ function abstract_call_known( end end - if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && - widenconst(argtypes[3]) <: Enzyme.Annotation && - widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : - [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], - ) - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - Enzyme.autodiff_deferred::Any, - arginfo2::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) - end - end if interp.handler != nothing return interp.handler(interp, f, arginfo, si, sv, max_methods) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index e109415d0f..efdc5ec8f4 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -2,161 +2,10 @@ using LLVM using ObjectFile using Libdl -module FFI -using LLVM -module BLASSupport -# TODO: LAPACK handling -using LinearAlgebra -using ObjectFile -using Libdl -function __init__() - global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) -end -function get_blas_symbols() - symbols = BLAS.get_config().exported_symbols - if BLAS.USE_BLAS64 - return map(Base.Fix2(*, "64_"), symbols) - end - return symbols -end - -function lookup_blas_symbol(name::String) - Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) -end -end - -const ptr_map = Dict{Ptr{Cvoid},String}() - -function __init__() - known_names = ( - "jl_alloc_array_1d", - "jl_alloc_array_2d", - "jl_alloc_array_3d", - "ijl_alloc_array_1d", - "ijl_alloc_array_2d", - "ijl_alloc_array_3d", - "jl_new_array", - "ijl_new_array", - "jl_array_copy", - "ijl_array_copy", - "jl_alloc_string", - "jl_in_threaded_region", - "jl_enter_threaded_region", - "jl_exit_threaded_region", - "jl_set_task_tid", - "jl_new_task", - "malloc", - "memmove", - "memcpy", - "memset", - "jl_array_grow_beg", - "ijl_array_grow_beg", - "jl_array_grow_end", - "ijl_array_grow_end", - "jl_array_grow_at", - "ijl_array_grow_at", - "jl_array_del_beg", - "ijl_array_del_beg", - "jl_array_del_end", - "ijl_array_del_end", - "jl_array_del_at", - "ijl_array_del_at", - "jl_array_ptr", - "ijl_array_ptr", - "jl_value_ptr", - "jl_get_ptls_states", - "jl_gc_add_finalizer_th", - "jl_symbol_n", - "jl_", - "jl_object_id", - "jl_reshape_array", - "ijl_reshape_array", - "jl_matching_methods", - "ijl_matching_methods", - "jl_array_sizehint", - "ijl_array_sizehint", - "jl_get_keyword_sorter", - "ijl_get_keyword_sorter", - "jl_ptr_to_array", - "jl_box_float32", - "ijl_box_float32", - "jl_box_float64", - "ijl_box_float64", - "jl_ptr_to_array_1d", - "jl_eqtable_get", - "ijl_eqtable_get", - "memcmp", - "memchr", - "jl_get_nth_field_checked", - "ijl_get_nth_field_checked", - "jl_stored_inline", - "ijl_stored_inline", - "jl_array_isassigned", - "ijl_array_isassigned", - "jl_array_ptr_copy", - "ijl_array_ptr_copy", - "jl_array_typetagdata", - "ijl_array_typetagdata", - "jl_idtable_rehash", - ) - for name in known_names - sym = LLVM.find_symbol(name) - if sym == C_NULL - continue - end - if haskey(ptr_map, sym) - # On MacOS memcpy and memmove seem to collide? - if name == "memcpy" - continue - end - end - @assert !haskey(ptr_map, sym) - ptr_map[sym] = name - end - for sym in BLASSupport.get_blas_symbols() - ptr = BLASSupport.lookup_blas_symbol(sym) - if ptr !== nothing - if haskey(ptr_map, ptr) - if ptr_map[ptr] != sym - @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] - end - continue - end - ptr_map[ptr] = sym - end - end -end - -function memoize!(ptr::Ptr{Cvoid}, fn::String)::String - fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else - @assert ptr_map[ptr] == fn - end - return fn -end -end - import GPUCompiler: IRError, InvalidIRError function restore_lookups(mod::LLVM.Module)::Nothing T_size_t = convert(LLVM.LLVMType, Int) - for (v, k) in FFI.ptr_map - if haskey(functions(mod), k) - f = functions(mod)[k] - replace_uses!( - f, - LLVM.Value( - LLVM.API.LLVMConstIntToPtr( - ConstantInt(T_size_t, convert(UInt, v)), - value_type(f), - ), - ), - ) - eraseInst(mod, f) - end - end for f in functions(mod) for fattr in collect(function_attributes(f)) if isa(fattr, LLVM.StringAttribute) @@ -185,194 +34,6 @@ function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module) end end -# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA -function rewrite_ccalls!(mod::LLVM.Module) - for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction,Instruction}[] - for bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.CallInst) - fn = called_operand(inst) - changed = false - B = IRBuilder() - position!(B, inst) - if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" - uservals = LLVM.Value[] - for lval in collect(arguments(inst)) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - changed = true - end - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - if !isdefined(LLVM, :OperandBundleDef) - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(operand_bundles(inst)), - prevname, - ) - else - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), - prevname, - ) - end - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - continue - end - if !isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) - if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end - - if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end - uservals = LLVM.Value[] - subchanged = false - for lval in LLVM.inputs(bunduse) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - subchanged = true - end - if !subchanged - push!(newbundles, bunduse) - continue - end - changed = true - if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) - else - push!( - newbundles, - OperandBundleDef(LLVM.tag_name(bunduse), uservals), - ) - end - end - changed = false - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - newinst = call!( - B, - called_type(inst), - called_operand(inst), - collect(arguments(inst)), - newbundles, - prevname, - ) - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - end - end - for (inst, newinst) in replaceAndErase - replace_uses!(inst, newinst) - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end -end - function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -390,7 +51,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end - rewrite_ccalls!(mod) + Compiler.rewrite_ccalls!(mod) del = LLVM.Function[] for f in collect(functions(mod)) @@ -1211,14 +872,34 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ptr_val = convert(Int, ptr_arg) ptr = Ptr{Cvoid}(ptr_val) + @show ptr, autodiff_cache + if haskey(autodiff_cache, ptr) + pmod, pname = autodiff_cache[ptr] + + pmod = parse(LLVM.Module, pmod) + + for fn in functions(pmod) + if !isempty(LLVM.blocks(fn)) + linkage!(functions(mod)[pmod], fn == pname ? LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) + end + end + + GPUCompiler.link_library!(mod, inmod) + + replaceWith = functions(mod)[pname] + push!(function_attributes(replaceWith), EnumAttribute("alwaysinline")) + linkage!(functions(mod)[pname], LLVM.API.LLVMInternalLinkage) + replace_uses!(ptr_arg, LLVM.const_pointercast(b, replaceWith, value_type(ptr_arg))) + return errors + end + # look it up in the Julia JIT cache frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) - # Remember pointer in our global map - fn = FFI.memoize!(ptr, string(fn)) + fn = string(fn) if length(fn) > 1 && fromC mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) @@ -1229,6 +910,9 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp fn, LLVM.API.LLVMGetCalledFunctionType(inst), ) + # Remember pointer for subsequent restoration + push!(function_attributes(lfn), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) + @show string(inst), string(lfn), ptr else lfn = LLVM.API.LLVMConstBitCast( lfn, diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index aebb8bab5c..2f9c61c0b4 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -1,4 +1,192 @@ +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction,Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + fn = called_operand(inst) + changed = false + B = IRBuilder() + position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + if !isdefined(LLVM, :OperandBundleDef) + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + if !isdefined(LLVM, :OperandBundleDef) + newbundles = OperandBundle[] + else + newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) + if isdefined(LLVM, :OperandBundleDef) + bunduse = LLVM.OperandBundleDef(bunduse) + end + + if !isdefined(LLVM, :OperandBundleDef) + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + if !isdefined(LLVM, :OperandBundleDef) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + else + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) + end + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function force_recompute!(mod::LLVM.Module) for f in functions(mod), bb in blocks(f) iter = LLVM.API.LLVMGetFirstInstruction(bb) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 78c9cd9ce8..d4356aba61 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -275,7 +275,7 @@ end world, ) - cmod, fwdmodenm, _, _ = _thunk(ejob, false) #=postopt=# + cmod, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) @@ -334,7 +334,7 @@ end world, ) - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# + cmod, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) From 0d9791cf1aeeab20a16d5618e20c88f6efb976b2 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 3 Dec 2024 08:29:12 -0600 Subject: [PATCH 2/7] fix --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fed1f54eb3..f316e6c4b6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5321,10 +5321,10 @@ const cache_lock = ReentrantLock() asm = _thunk(job) obj = _link(job, asm...) if obj.adjoint isa Ptr{Nothing} - autodiff_cache[obj.adjoint] = (asm[2], asm[1]) + autodiff_cache[obj.adjoint] = (asm[2], asm[5]) end if obj.primal isa Ptr{Nothing} - autodiff_cache[obj.primal] = (asm[3], asm[1]) + autodiff_cache[obj.primal] = (asm[3], asm[5]) end cache[key] = obj end From 340da7fdba9860215f12d301c1edc269df4facbd Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 3 Dec 2024 08:40:49 -0600 Subject: [PATCH 3/7] fix --- src/compiler.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f316e6c4b6..7893f168b4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5266,7 +5266,7 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true) +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf @@ -5320,15 +5320,16 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) + @show obj if obj.adjoint isa Ptr{Nothing} autodiff_cache[obj.adjoint] = (asm[2], asm[5]) end - if obj.primal isa Ptr{Nothing} + if obj.primal isa Ptr{Nothing} && asm[3] isa String autodiff_cache[obj.primal] = (asm[3], asm[5]) end cache[key] = obj end - nothing + obj finally unlock(cache_lock) end From a5cc5d6f0b8ecc81a9f79c2fe664e33f9f0610d9 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 3 Dec 2024 09:27:32 -0600 Subject: [PATCH 4/7] working --- src/compiler.jl | 1 - src/compiler/validation.jl | 24 +++++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7893f168b4..16fd8509b3 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5320,7 +5320,6 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) - @show obj if obj.adjoint isa Ptr{Nothing} autodiff_cache[obj.adjoint] = (asm[2], asm[5]) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index efdc5ec8f4..d5f30a71cb 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -58,7 +58,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -69,7 +69,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -78,7 +78,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod return errors end -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module) calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) @@ -304,7 +304,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp while length(calls) > 0 inst = pop!(calls) - check_ir!(job, errors, imported, inst, calls) + check_ir!(job, errors, imported, inst, calls, mod) end return errors end @@ -351,7 +351,7 @@ end import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module) world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -872,24 +872,27 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ptr_val = convert(Int, ptr_arg) ptr = Ptr{Cvoid}(ptr_val) - @show ptr, autodiff_cache if haskey(autodiff_cache, ptr) - pmod, pname = autodiff_cache[ptr] + pname, pmod = autodiff_cache[ptr] + + @assert !haskey(functions(mod), pname) pmod = parse(LLVM.Module, pmod) + @assert haskey(functions(pmod), pname) + for fn in functions(pmod) if !isempty(LLVM.blocks(fn)) - linkage!(functions(mod)[pmod], fn == pname ? LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) + linkage!(fn, LLVM.name(fn) != pname ? LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) end end - GPUCompiler.link_library!(mod, inmod) + GPUCompiler.link_library!(mod, pmod) replaceWith = functions(mod)[pname] push!(function_attributes(replaceWith), EnumAttribute("alwaysinline")) linkage!(functions(mod)[pname], LLVM.API.LLVMInternalLinkage) - replace_uses!(ptr_arg, LLVM.const_pointercast(b, replaceWith, value_type(ptr_arg))) + replace_uses!(ptr_arg, LLVM.const_pointercast(replaceWith, value_type(ptr_arg))) return errors end @@ -912,7 +915,6 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ) # Remember pointer for subsequent restoration push!(function_attributes(lfn), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) - @show string(inst), string(lfn), ptr else lfn = LLVM.API.LLVMConstBitCast( lfn, From 4e62957abcb65feb827ae747ddad88b74dc168ef Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Dec 2024 09:48:28 -0600 Subject: [PATCH 5/7] Update validation.jl --- src/compiler/validation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index d5f30a71cb..a4b00563f3 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -914,7 +914,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp LLVM.API.LLVMGetCalledFunctionType(inst), ) # Remember pointer for subsequent restoration - push!(function_attributes(lfn), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) + push!(function_attributes(LLVM.Function(lfn)), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) else lfn = LLVM.API.LLVMConstBitCast( lfn, From 39ac9eb3d1550ad2531df133c8cd02a0df538ab3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 3 Dec 2024 19:57:25 -0500 Subject: [PATCH 6/7] handle, again --- src/compiler/validation.jl | 136 ++++++++++++++++++++++++++++++++++++- 1 file changed, 135 insertions(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index a4b00563f3..b4e2970b39 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -2,6 +2,140 @@ using LLVM using ObjectFile using Libdl +module FFI +using LLVM +module BLASSupport +# TODO: LAPACK handling +using LinearAlgebra +using ObjectFile +using Libdl +function __init__() + global blas_handle = Libdl.dlopen(BLAS.libblastrampoline) +end +function get_blas_symbols() + symbols = BLAS.get_config().exported_symbols + if BLAS.USE_BLAS64 + return map(Base.Fix2(*, "64_"), symbols) + end + return symbols +end + +function lookup_blas_symbol(name::String) + Libdl.dlsym(blas_handle::Ptr{Cvoid}, name; throw_error = false) +end +end + +const ptr_map = Dict{Ptr{Cvoid},String}() + +function __init__() + known_names = ( + "jl_alloc_array_1d", + "jl_alloc_array_2d", + "jl_alloc_array_3d", + "ijl_alloc_array_1d", + "ijl_alloc_array_2d", + "ijl_alloc_array_3d", + "jl_new_array", + "ijl_new_array", + "jl_array_copy", + "ijl_array_copy", + "jl_alloc_string", + "jl_in_threaded_region", + "jl_enter_threaded_region", + "jl_exit_threaded_region", + "jl_set_task_tid", + "jl_new_task", + "malloc", + "memmove", + "memcpy", + "memset", + "jl_array_grow_beg", + "ijl_array_grow_beg", + "jl_array_grow_end", + "ijl_array_grow_end", + "jl_array_grow_at", + "ijl_array_grow_at", + "jl_array_del_beg", + "ijl_array_del_beg", + "jl_array_del_end", + "ijl_array_del_end", + "jl_array_del_at", + "ijl_array_del_at", + "jl_array_ptr", + "ijl_array_ptr", + "jl_value_ptr", + "jl_get_ptls_states", + "jl_gc_add_finalizer_th", + "jl_symbol_n", + "jl_", + "jl_object_id", + "jl_reshape_array", + "ijl_reshape_array", + "jl_matching_methods", + "ijl_matching_methods", + "jl_array_sizehint", + "ijl_array_sizehint", + "jl_get_keyword_sorter", + "ijl_get_keyword_sorter", + "jl_ptr_to_array", + "jl_box_float32", + "ijl_box_float32", + "jl_box_float64", + "ijl_box_float64", + "jl_ptr_to_array_1d", + "jl_eqtable_get", + "ijl_eqtable_get", + "memcmp", + "memchr", + "jl_get_nth_field_checked", + "ijl_get_nth_field_checked", + "jl_stored_inline", + "ijl_stored_inline", + "jl_array_isassigned", + "ijl_array_isassigned", + "jl_array_ptr_copy", + "ijl_array_ptr_copy", + "jl_array_typetagdata", + "ijl_array_typetagdata", + "jl_idtable_rehash", + ) + for name in known_names + sym = LLVM.find_symbol(name) + if sym == C_NULL + continue + end + if haskey(ptr_map, sym) + # On MacOS memcpy and memmove seem to collide? + if name == "memcpy" + continue + end + end + @assert !haskey(ptr_map, sym) + ptr_map[sym] = name + end + for sym in BLASSupport.get_blas_symbols() + ptr = BLASSupport.lookup_blas_symbol(sym) + if ptr !== nothing + if haskey(ptr_map, ptr) + if ptr_map[ptr] != sym + @warn "Duplicated symbol in ptr_map" ptr, sym, ptr_map[ptr] + end + continue + end + ptr_map[ptr] = sym + end + end +end + +function memoize!(ptr::Ptr{Cvoid}, fn::String)::String + fn = get(ptr_map, ptr, fn) + if haskey(ptr_map, ptr) + @assert ptr_map[ptr] == fn + end + return fn +end +end + import GPUCompiler: IRError, InvalidIRError function restore_lookups(mod::LLVM.Module)::Nothing @@ -902,7 +1036,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) - fn = string(fn) + fn = FFI.memoize!(ptr, string(fn)) if length(fn) > 1 && fromC mod = LLVM.parent(LLVM.parent(LLVM.parent(inst))) From d629c3057119ec353a482b09ada5c4a86f4334bf Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 3 Dec 2024 21:42:45 -0600 Subject: [PATCH 7/7] Update validation.jl --- src/compiler/validation.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index b4e2970b39..525e4d874c 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -140,6 +140,21 @@ import GPUCompiler: IRError, InvalidIRError function restore_lookups(mod::LLVM.Module)::Nothing T_size_t = convert(LLVM.LLVMType, Int) + for (v, k) in FFI.ptr_map + if haskey(functions(mod), k) + f = functions(mod)[k] + replace_uses!( + f, + LLVM.Value( + LLVM.API.LLVMConstIntToPtr( + ConstantInt(T_size_t, convert(UInt, v)), + value_type(f), + ), + ), + ) + eraseInst(mod, f) + end + end for f in functions(mod) for fattr in collect(function_attributes(f)) if isa(fattr, LLVM.StringAttribute)