Add @code_xla_llvm macro #2911
Open
Antipath1 wants to merge 1 commit into
Open
Conversation
Member
|
cc @gbaraldi |
Contributor
Author
|
Also, here is an example of the IR that I generated from this function:
```julia
function test_math_reduce(x, y)
    return sum(sin.(x) .+ cos.(y) .^ 2)
end
vec_jl = rand(Float32, 100)
vec_ra = Reactant.to_rarray(vec_jl)
f_reduce = Reactant.@code_xla_llvm test_math_reduce(vec_ra, vec_ra)
println(f_reduce)
```
Click to expand LLVM IR
; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
@0 = private unnamed_addr constant [4 x i8] zeroinitializer, align 4
!xla_cpu_memory_region_name = !{!0}
!0 = !{!"ir_emitter"}
; ModuleID = '__compute_module___compute_module_multiply_add_fusion'
source_filename = "__compute_module___compute_module_multiply_add_fusion"
%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }
define ptr @multiply_add_fusion(ptr %0) {
%2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
%3 = load ptr, ptr %2, align 8, !invariant.load !2
%4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
%5 = load ptr, ptr %4, align 8, !invariant.load !2
%6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
%7 = load ptr, ptr %6, align 8, !invariant.load !2
%8 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
%9 = load ptr, ptr %8, align 8
%10 = getelementptr inbounds %kernel_dim3, ptr %9, i32 0, i32 0
%11 = load i64, ptr %10, align 4, !invariant.load !2
%12 = call i64 @llvm.smax.i64(i64 %11, i64 0)
%13 = mul nuw nsw i64 %12, 7
%14 = call i64 @llvm.smin.i64(i64 %13, i64 7)
br label %15
15: ; preds = %18, %1
%16 = phi i64 [ %19, %18 ], [ %14, %1 ]
%17 = icmp slt i64 %16, 7
br i1 %17, label %18, label %20
18: ; preds = %15
call void @multiply_add_fusion_impl(ptr %5, ptr %5, i64 0, i64 100, i64 1, ptr %7, ptr %7, i64 0, i64 100, i64 1, i64 %16)
%19 = add i64 %16, 1
br label %15
20: ; preds = %15
ret ptr null
}
; Function Attrs: alwaysinline
define internal void @multiply_add_fusion_impl(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9, i64 %10) #0 {
%12 = mul nsw i64 %10, 16
%13 = sub i64 100, %12
%14 = call i64 @llvm.smin.i64(i64 %13, i64 16)
%15 = icmp eq i64 %14, 16
%16 = alloca float, i64 16, align 4
%17 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %16, 0
%18 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, ptr %16, 1
%19 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %18, i64 0, 2
%20 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %19, i64 16, 3, 0
%21 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, i64 1, 4, 0
%22 = alloca float, i64 16, align 4
%23 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %22, 0
%24 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %23, ptr %22, 1
%25 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %24, i64 0, 2
%26 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %25, i64 16, 3, 0
%27 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %26, i64 1, 4, 0
%28 = select i1 %15, { ptr, ptr, i64, [1 x i64], [1 x i64] } %21, { ptr, ptr, i64, [1 x i64], [1 x i64] } %27
br i1 %15, label %29, label %31
29: ; preds = %11
%30 = getelementptr float, ptr %1, i64 %12
call void @llvm.memcpy.p0.p0.i64(ptr %16, ptr %30, i64 64, i1 false)
br label %42
31: ; preds = %11
br label %32
32: ; preds = %35, %31
%33 = phi i64 [ %40, %35 ], [ 0, %31 ]
%34 = icmp slt i64 %33, %14
br i1 %34, label %35, label %41
35: ; preds = %32
%36 = getelementptr float, ptr %1, i64 %12
%37 = getelementptr inbounds nuw float, ptr %36, i64 %33
%38 = load float, ptr %37, align 4
%39 = getelementptr inbounds nuw float, ptr %22, i64 %33
store float %38, ptr %39, align 4
%40 = add i64 %33, 1
br label %32
41: ; preds = %32
br label %42
42: ; preds = %29, %41
%43 = alloca float, i64 16, align 64
%44 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %28, 1
br label %45
45: ; preds = %48, %42
%46 = phi i64 [ %56, %48 ], [ 0, %42 ]
%47 = icmp slt i64 %46, 16
br i1 %47, label %48, label %57
48: ; preds = %45
%49 = getelementptr float, ptr %44, i64 %46
%50 = load <8 x float>, ptr %49, align 4
%51 = call <8 x float> @llvm.cos.v8f32(<8 x float> %50)
%52 = call <8 x float> @llvm.sin.v8f32(<8 x float> %50)
%53 = fmul <8 x float> %51, %51
%54 = fadd <8 x float> %52, %53
%55 = getelementptr float, ptr %43, i64 %46
store <8 x float> %54, ptr %55, align 4
%56 = add i64 %46, 8
br label %45
57: ; preds = %45
br i1 %15, label %58, label %60
58: ; preds = %57
%59 = getelementptr float, ptr %6, i64 %12
call void @llvm.memcpy.p0.p0.i64(ptr %59, ptr %43, i64 64, i1 false)
br label %71
60: ; preds = %57
br label %61
61: ; preds = %64, %60
%62 = phi i64 [ %69, %64 ], [ 0, %60 ]
%63 = icmp slt i64 %62, %14
br i1 %63, label %64, label %70
64: ; preds = %61
%65 = getelementptr inbounds nuw float, ptr %43, i64 %62
%66 = load float, ptr %65, align 4
%67 = getelementptr float, ptr %6, i64 %12
%68 = getelementptr inbounds nuw float, ptr %67, i64 %62
store float %66, ptr %68, align 4
%69 = add i64 %62, 1
br label %61
70: ; preds = %61
br label %71
71: ; preds = %58, %70
ret void
}
; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare i64 @llvm.smax.i64(i64, i64) #1
; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare i64 @llvm.smin.i64(i64, i64) #1
; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
declare void @llvm.memcpy.p0.p0.i64(ptr noalias writeonly captures(none), ptr noalias readonly captures(none), i64, i1 immarg) #2
; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare <8 x float> @llvm.cos.v8f32(<8 x float>) #1
; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare <8 x float> @llvm.sin.v8f32(<8 x float>) #1
attributes #0 = { alwaysinline }
attributes #1 = { nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none) }
attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }
!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__tiled_emitter__hlo_opcode__fusion"}
!2 = !{}
; ModuleID = '__compute_module_wrapped_reduce_kernel_module'
source_filename = "__compute_module_wrapped_reduce_kernel_module"
%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }
; Function Attrs: uwtable
define ptr @wrapped_reduce(ptr %0) #0 {
%2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
%3 = load ptr, ptr %2, align 8, !invariant.load !2
%4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
%5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
%6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
%7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
%8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
%9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !4
%10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
%11 = load ptr, ptr %10, align 8
%12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
%13 = load i64, ptr %12, align 4, !invariant.load !2
%14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
%15 = load i64, ptr %14, align 4, !invariant.load !2
%16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
%17 = load i64, ptr %16, align 4, !invariant.load !2
call void @wrapped_reduce_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
ret ptr null
}
; Function Attrs: alwaysinline
define internal void @wrapped_reduce_wrapped(ptr noalias align 64 dereferenceable(16) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(4) %2, i64 %3, i64 %4, i64 %5) #1 {
%7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
%8 = load float, ptr %7, align 4, !invariant.load !2
br label %9
9: ; preds = %13, %6
%10 = phi i64 [ %17, %13 ], [ 0, %6 ]
%11 = phi float [ %16, %13 ], [ %8, %6 ]
%12 = icmp slt i64 %10, 4
br i1 %12, label %13, label %18
13: ; preds = %9
%14 = getelementptr inbounds [4 x float], ptr %0, i32 0, i64 %10
%15 = load float, ptr %14, align 4, !invariant.load !2
%16 = fadd reassoc float %11, %15
%17 = add i64 %10, 1
br label %9
18: ; preds = %9
%19 = getelementptr inbounds [1 x float], ptr %2, i32 0, i32 0
store float %11, ptr %19, align 4
ret void
}
attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }
!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 16}
!4 = !{i64 4}
; ModuleID = '__compute_module_wrapped_reduce-window_kernel_module'
source_filename = "__compute_module_wrapped_reduce-window_kernel_module"
%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }
; Function Attrs: uwtable
define ptr @wrapped_reduce-window(ptr %0) #0 {
%2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
%3 = load ptr, ptr %2, align 8, !invariant.load !2
%4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
%5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
%6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
%7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
%8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
%9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !5
%10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
%11 = load ptr, ptr %10, align 8
%12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
%13 = load i64, ptr %12, align 4, !invariant.load !2
%14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
%15 = load i64, ptr %14, align 4, !invariant.load !2
%16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
%17 = load i64, ptr %16, align 4, !invariant.load !2
call void @wrapped_reduce-window_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
ret ptr null
}
; Function Attrs: alwaysinline
define internal void @wrapped_reduce-window_wrapped(ptr noalias align 64 dereferenceable(400) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(16) %2, i64 %3, i64 %4, i64 %5) #1 {
%7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
%8 = load float, ptr %7, align 4, !invariant.load !2
br label %9
9: ; preds = %33, %6
%10 = phi i64 [ %35, %33 ], [ 0, %6 ]
%11 = icmp slt i64 %10, 4
br i1 %11, label %12, label %36
12: ; preds = %9
%13 = mul nsw i64 %10, 32
br label %14
14: ; preds = %31, %12
%15 = phi i64 [ %32, %31 ], [ 0, %12 ]
%16 = phi float [ %30, %31 ], [ %8, %12 ]
%17 = icmp slt i64 %15, 32
br i1 %17, label %18, label %33
18: ; preds = %14
%19 = add nsw i64 %13, %15
%20 = icmp sge i64 %19, 14
%21 = icmp sle i64 %19, 113
%22 = and i1 %20, %21
br i1 %22, label %23, label %28
23: ; preds = %18
%24 = add nsw i64 %19, -14
%25 = getelementptr inbounds [100 x float], ptr %0, i32 0, i64 %24
%26 = load float, ptr %25, align 4, !invariant.load !2
%27 = fadd reassoc float %16, %26
br label %29
28: ; preds = %18
br label %29
29: ; preds = %23, %28
%30 = phi float [ %16, %28 ], [ %27, %23 ]
br label %31
31: ; preds = %29
%32 = add i64 %15, 1
br label %14
33: ; preds = %14
%34 = getelementptr inbounds [4 x float], ptr %2, i32 0, i64 %10
store float %16, ptr %34, align 4
%35 = add i64 %10, 1
br label %9, !llvm.loop !6
36: ; preds = %9
ret void
}
attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }
!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}
!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 400}
!4 = !{i64 4}
!5 = !{i64 16}
!6 = distinct !{!6, !7}
!7 = !{!"llvm.loop.unroll.disable"}
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds a macro, @code_xla_llvm, that lets you inspect the LLVM IR that XLA generates for CPUs. Importantly, this does not yet implement the functionality from createLLVMMod, since I have not gotten that to work yet; I still have to open an issue in Enzyme-Jax for that. (Yes, I already opened a PR for this, but I accidentally closed that one, whoops.)
Ref #1412