Skip to content
4 changes: 3 additions & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,9 @@ end
ET_IllegalReplaceFicticiousPHIs = 8,
ET_GetIndexError = 9,
ET_NoTruncate = 10,
ET_GCRewrite = 11
ET_GCRewrite = 11,
ET_NaNError = 12,
ET_ShowInternalError = 12,
)

function EnzymeTypeAnalyzerToString(typeanalyzer)
Expand Down
103 changes: 89 additions & 14 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3569,7 +3569,17 @@ function create_abi_wrapper(
if existed[i] != 0
eval = val
if data[i] != -1
eval = extract_value!(builder, val, data[i])
eval = extract_value!(builder, val, data[i], "revprimal_extract_$(i)")
end
if i == 2 && actualRetType != literal_rt
if Base.isconcretetype(literal_rt) && !Base.isconcretetype(actualRetType)
eval = addrspacecast!(builder, eval, LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Derived))
lvalty = convert(LLVM.LLVMType, literal_rt)
eval = bitcast!(builder, eval, LLVM.PointerType(lvalty, Derived))
eval = load!(builder, lvalty, eval)
else
emit_error(builder, nothing, "Unexpected type inference from LLVM codegen. \nActual return type from GPUCompiler: $(actualRetType)\n Inferred return type: $(literal_rt)\n rettype=$(rettype)\n Mode=$Mode\n TT=$TT")
end
end
if i == 3
if rettype <: MixedDuplicated || rettype <: BatchMixedDuplicated
Expand Down Expand Up @@ -3612,6 +3622,27 @@ function create_abi_wrapper(
insert_value!(builder, ival, ires, idx - 1)
end
eval = ival
elseif actualRetType != literal_rt
if Base.isconcretetype(literal_rt) && !Base.isconcretetype(actualRetType)
lvalty = convert(LLVM.LLVMType, literal_rt)
ival = UndefValue(
LLVM.LLVMType(API.EnzymeGetShadowType(width, lvalty)),
)
for idx = 1:width
pv =
(width == 1) ? eval : extract_value!(builder, eval, idx - 1)
eval = addrspacecast!(builder, eval, LLVM.PointerType(LLVM.StructType(LLVM.LLVMType[]), Derived))
eval = bitcast!(builder, eval, LLVM.PointerType(lvalty, Derived))
eval = load!(builder, lvalty, eval)

ival =
(width == 1) ? eval :
insert_value!(builder, ival, eval, idx - 1)
end
eval = ival
else
emit_error(builder, nothing, "Unexpected type inference from LLVM codegen. \nActual return type from GPUCompiler: $(actualRetType)\n Inferred return type: $(literal_rt)\n rettype=$(rettype)\n Mode=$Mode\n TT=$TT")
end
end
end
eval = fixup_abi(i, eval)
Expand All @@ -3623,9 +3654,11 @@ function create_abi_wrapper(
LLVM.ConstantInt(LLVM.IntType(64), 0),
LLVM.ConstantInt(LLVM.IntType(32), returnNum),
],
"revprimal_1_wrap_sret_gep_$returnNum"
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
extract_struct_into!(builder, ptr, eval)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)),
"revprimal_1_wrap_sret_cast_$returnNum")
extract_struct_into!(builder, ptr, eval, "revprimal_1_wrap_sret_extract_$returnNum")
returnNum += 1
if i == 3 && shadow_init
shadows = LLVM.Value[]
Expand Down Expand Up @@ -3665,9 +3698,10 @@ function create_abi_wrapper(
LLVM.ConstantInt(LLVM.IntType(64), 0),
LLVM.ConstantInt(LLVM.IntType(32), returnNum),
],
"revprimal_2_wrap_sret_gep_$returnNum"
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
extract_struct_into!(builder, ptr, eval)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)), "revprimal_1_wrap_sret_cast_$returnNum")
extract_struct_into!(builder, ptr, eval, "revprimal_2_wrap_sret_extract_$returnNum")
returnNum += 1
end
end
Expand Down Expand Up @@ -3765,9 +3799,10 @@ function create_abi_wrapper(
LLVM.ConstantInt(LLVM.IntType(64), 0),
LLVM.ConstantInt(LLVM.IntType(32), returnNum),
],
"fwd_wrap_sret_gep_$returnNum"
)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)))
extract_struct_into!(builder, ptr, eval)
ptr = pointercast!(builder, ptr, LLVM.PointerType(value_type(eval)), "fwd_wrap_sret_cast_$returnNum")
extract_struct_into!(builder, ptr, eval, "fwd_wrap_sret_extract_$returnNum")
end
@assert count_Sret == numLLVMReturns
else
Expand Down Expand Up @@ -3797,14 +3832,16 @@ function create_abi_wrapper(
length(elements(jltype)) - 1,
),
],
"revcombined_wrap_sret_gep_$returnNum"
),
eval,
"revcombined_wrap_sret_extract_$returnNum"
)
returnNum += 1
end
end
end
for T in TT.parameters[2:end]
for (i, T) in enumerate(TT.parameters[2:end])
if T <: Active
T′ = eltype(T)
isboxed = GPUCompiler.deserves_argbox(T′)
Expand All @@ -3821,8 +3858,10 @@ function create_abi_wrapper(
LLVM.ConstantInt(LLVM.IntType(32), 0),
LLVM.ConstantInt(LLVM.IntType(32), activeNum),
],
"revcombined_wrap_sret_gep_active_$(i)_$(T′)"
),
eval,
"revcombined_wrap_sret_extract_active_$(i)_$(T′)"
)
returnNum += 1
end
Expand Down Expand Up @@ -4174,7 +4213,7 @@ function extract_nonjlvalues_into!(builder::LLVM.IRBuilder, jltype::LLVM.LLVMTyp
return nothing
end

function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLVM.Value)
function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLVM.Value, name::String)
count = 0
jltype = value_type(src)
todo = Tuple{Vector{Cuint},LLVM.LLVMType}[(
Expand Down Expand Up @@ -4223,8 +4262,8 @@ function extract_struct_into!(builder::LLVM.IRBuilder, dst::LLVM.Value, src::LLV
continue
end

dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi")
val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path)
dstloc = inbounds_gep!(builder, jltype, dst, to_llvm(path), "dstlocsi_$(name)_$(join(path, ","))")
val = length(path) == 0 ? src : Enzyme.API.e_extract_value!(builder, src, path, "srclocei_$(name)_$(join(path, ","))")
st = store!(builder, val, dstloc)
end

Expand Down Expand Up @@ -4668,7 +4707,13 @@ function lower_convention(
@assert elty == eltype(ty)
end

ptr = alloca!(builder, elty, LLVM.name(parm) * ".innerparm")
elty_foralloca = if VERSION >= v"1.12" && arg.rooted_typ !== nothing
strip_tracked_pointers(elty)
else
elty
end

ptr = alloca!(builder, elty_foralloca, LLVM.name(parm) * ".innerparm")
if TT !== nothing && TT.parameters[arg.arg_jl_i] <: Const
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
Expand Down Expand Up @@ -4859,6 +4904,13 @@ function lower_convention(
string(convert(UInt, unsafe_to_pointer(actualRetType))),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
"enzymejl_parmtype_str",
string(actualRetType),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
Expand All @@ -4885,6 +4937,13 @@ function lower_convention(
string(convert(UInt, unsafe_to_pointer(actualRetType))),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
"enzymejl_parmtype_str",
string(actualRetType),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
Expand Down Expand Up @@ -4920,6 +4979,13 @@ function lower_convention(
string(convert(UInt, unsafe_to_pointer(expected_RT))),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
"enzymejl_parmtype_str",
string(expected_RT),
),
)
push!(
return_attributes(wrapper_f),
StringAttribute(
Expand Down Expand Up @@ -6685,7 +6751,12 @@ const DumpLLVMCall = Ref(false)
callparams,
alloca!(builder, LLVM.ArrayType(T_prjlvalue, tracked.count), "enzyme_call.return_roots"),
)
pushfirst!(callparams, alloca!(builder, jltype, "enzyme_call.sret"))
jltype_foralloca = if VERSION >= v"1.12"
strip_tracked_pointers(jltype)
else
jltype
end
pushfirst!(callparams, alloca!(builder, jltype_foralloca, "enzyme_call.sret"))
end

if needs_tape && !(isghostty(TapeType) || Core.Compiler.isconstType(TapeType))
Expand Down Expand Up @@ -6748,7 +6819,11 @@ const DumpLLVMCall = Ref(false)
if !LLVM.is_opaque(value_type(callparams[1]))
@assert eltype(value_type(callparams[1])) == jltype
end
r = load!(builder, jltype, callparams[1])
r = @static if VERSION >= v"1.12"
recombine_value_ptr!(builder, jltype, callparams[1], callparams[2])
else
load!(builder, jltype, callparams[1])
end
end

if T_ret != T_void
Expand Down
3 changes: 3 additions & 0 deletions src/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,9 @@ function julia_error(
throw(IllegalFirstPointerException(msg, ir, bt))
elseif errtype == API.ET_InternalError
throw(EnzymeInternalError(msg, ir, bt))
elseif errtype == API.ET_ShowInternalError
Core.print(EnzymeInternalError(msg, ir, bt))
return C_NULL
elseif errtype == API.ET_GCRewrite
data2 = LLVM.Value(data2)
fn = LLVM.Function(LLVM.API.LLVMGetParamParent(data2::LLVM.Argument))
Expand Down
Loading
Loading