Skip to content

Add @code_xla_llvm macro#2911

Open
Antipath1 wants to merge 1 commit into
EnzymeAD:mainfrom
Antipath1:la/Add_code_xla_llvm_macro_v2
Open

Add @code_xla_llvm macro#2911
Antipath1 wants to merge 1 commit into
EnzymeAD:mainfrom
Antipath1:la/Add_code_xla_llvm_macro_v2

Conversation

@Antipath1
Copy link
Copy Markdown
Contributor

Adds a macro @code_xla_llvm that allows to see the LLVM IR generated for CPUs by XLA. This importantly does not implement the functionality from createLLVMMod since I did not get that to work yet. Still have to open an issue in Enzyme-Jax for that. (Yes, I opened a PR for this already, but I accidentally closed that one, whoops)

Ref #1412

@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented May 19, 2026

cc @gbaraldi

@Antipath1
Copy link
Copy Markdown
Contributor Author

Also an example IR that I generated from this function

function test_math_reduce(x, y)
    return sum(sin.(x) .+ cos.(y) .^ 2)
end
vec_jl = rand(Float32, 100)
vec_ra = Reactant.to_rarray(vec_jl)
f_reduce = Reactant.@code_xla_llvm test_math_reduce(vec_ra, vec_ra)
println(f_reduce)
Click to expand LLVM IR
; ModuleID = '__compute_module'
source_filename = "__compute_module"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@0 = private unnamed_addr constant [4 x i8] zeroinitializer, align 4

!xla_cpu_memory_region_name = !{!0}

!0 = !{!"ir_emitter"}

; ModuleID = '__compute_module___compute_module_multiply_add_fusion'
source_filename = "__compute_module___compute_module_multiply_add_fusion"

%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }

define ptr @multiply_add_fusion(ptr %0) {
  %2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
  %3 = load ptr, ptr %2, align 8, !invariant.load !2
  %4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
  %5 = load ptr, ptr %4, align 8, !invariant.load !2<details>
<summary>Click to expand LLVM IR</summary>

```llvm
; Your long LLVM IR dump goes here...
define i32 @main() {
entry:
  ret i32 0
}
  %6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
  %7 = load ptr, ptr %6, align 8, !invariant.load !2
  %8 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
  %9 = load ptr, ptr %8, align 8
  %10 = getelementptr inbounds %kernel_dim3, ptr %9, i32 0, i32 0
  %11 = load i64, ptr %10, align 4, !invariant.load !2
  %12 = call i64 @llvm.sx.i64(i64 %11, i64 0)
  %13 = mul nuw nsw i64 %12, 7
  %14 = call i64 @llvm.smin.i64(i64 %13, i64 7)
  br label %15

15:                                               ; preds = %18, %1
  %16 = phi i64 [ %19, %18 ], [ %14, %1 ]
  %17 = icmp slt i64 %16, 7
  br i1 %17, label %18, label %20

18:                                               ; preds = %15
  call void @multiply_add_fusion_impl(ptr %5, ptr %5, i64 0, i64 100, i64 1, ptr %7, ptr %7, i64 0, i64 100, i64 1, i64 %16)
  %19 = add i64 %16, 1
  br label %15

20:                                               ; preds = %15
  ret ptr null
}

; Function Attrs: alwaysinline
define internal void @multiply_add_fusion_impl(ptr %0, ptr %1, i64 %2, i64 %3, i64 %4, ptr %5, ptr %6, i64 %7, i64 %8, i64 %9, i64 %10) #0 {
  %12 = mul nsw i64 %10, 16
  %13 = sub i64 100, %12
  %14 = call i64 @llvm.smin.i64(i64 %13, i64 16)
  %15 = icmp eq i64 %14, 16
  %16 = alloca float, i64 16, align 4
  %17 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %16, 0
  %18 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %17, ptr %16, 1
  %19 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %18, i64 0, 2
  %20 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %19, i64 16, 3, 0
  %21 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %20, i64 1, 4, 0
  %22 = alloca float, i64 16, align 4
  %23 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } poison, ptr %22, 0
  %24 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %23, ptr %22, 1
  %25 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %24, i64 0, 2
  %26 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %25, i64 16, 3, 0
  %27 = insertvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %26, i64 1, 4, 0
  %28 = select i1 %15, { ptr, ptr, i64, [1 x i64], [1 x i64] } %21, { ptr, ptr, i64, [1 x i64], [1 x i64] } %27
  br i1 %15, label %29, label %31

29:                                               ; preds = %11
  %30 = getelementptr float, ptr %1, i64 %12
  call void @llvm.memcpy.p0.p0.i64(ptr %16, ptr %30, i64 64, i1 false)
  br label %42

31:                                               ; preds = %11
  br label %32

32:                                               ; preds = %35, %31
  %33 = phi i64 [ %40, %35 ], [ 0, %31 ]
  %34 = icmp slt i64 %33, %14
  br i1 %34, label %35, label %41

35:                                               ; preds = %32
  %36 = getelementptr float, ptr %1, i64 %12
  %37 = getelementptr inbounds nuw float, ptr %36, i64 %33
  %38 = load float, ptr %37, align 4
  %39 = getelementptr inbounds nuw float, ptr %22, i64 %33
  store float %38, ptr %39, align 4
  %40 = add i64 %33, 1
  br label %32

41:                                               ; preds = %32
  br label %42

42:                                               ; preds = %29, %41
  %43 = alloca float, i64 16, align 64
  %44 = extractvalue { ptr, ptr, i64, [1 x i64], [1 x i64] } %28, 1
  br label %45

45:                                               ; preds = %48, %42
  %46 = phi i64 [ %56, %48 ], [ 0, %42 ]
  %47 = icmp slt i64 %46, 16
  br i1 %47, label %48, label %57

48:                                               ; preds = %45
  %49 = getelementptr float, ptr %44, i64 %46
  %50 = load <8 x float>, ptr %49, align 4
  %51 = call <8 x float> @llvm.cos.v8f32(<8 x float> %50)
  %52 = call <8 x float> @llvm.sin.v8f32(<8 x float> %50)
  %53 = fmul <8 x float> %51, %51
  %54 = fadd <8 x float> %52, %53
  %55 = getelementptr float, ptr %43, i64 %46
  store <8 x float> %54, ptr %55, align 4
  %56 = add i64 %46, 8
  br label %45

57:                                               ; preds = %45
  br i1 %15, label %58, label %60

58:                                               ; preds = %57
  %59 = getelementptr float, ptr %6, i64 %12
  call void @llvm.memcpy.p0.p0.i64(ptr %59, ptr %43, i64 64, i1 false)
  br label %71

60:                                               ; preds = %57
  br label %61

61:                                               ; preds = %64, %60
  %62 = phi i64 [ %69, %64 ], [ 0, %60 ]
  %63 = icmp slt i64 %62, %14
  br i1 %63, label %64, label %70

64:                                               ; preds = %61
  %65 = getelementptr inbounds nuw float, ptr %43, i64 %62
  %66 = load float, ptr %65, align 4
  %67 = getelementptr float, ptr %6, i64 %12
  %68 = getelementptr inbounds nuw float, ptr %67, i64 %62
  store float %66, ptr %68, align 4
  %69 = add i64 %62, 1
  br label %61

70:                                               ; preds = %61
  br label %71

71:                                               ; preds = %58, %70
  ret void
}

; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare i64 @llvm.smax.i64(i64, i64) #1

; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare i64 @llvm.smin.i64(i64, i64) #1

; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
declare void @llvm.memcpy.p0.p0.i64(ptr noalias writeonly captures(none), ptr noalias readonly captures(none), i64, i1 immarg) #2

; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare <8 x float> @llvm.cos.v8f32(<8 x float>) #1

; Function Attrs: nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none)
declare <8 x float> @llvm.sin.v8f32(<8 x float>) #1

attributes #0 = { alwaysinline }
attributes #1 = { nocallback nocreateundeforpoison nofree nosync nounwind speculatable willreturn memory(none) }
attributes #2 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }

!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__tiled_emitter__hlo_opcode__fusion"}
!2 = !{}

; ModuleID = '__compute_module_wrapped_reduce_kernel_module'
source_filename = "__compute_module_wrapped_reduce_kernel_module"

%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }

; Function Attrs: uwtable
define ptr @wrapped_reduce(ptr %0) #0 {
  %2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
  %3 = load ptr, ptr %2, align 8, !invariant.load !2
  %4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
  %5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
  %6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
  %7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
  %8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
  %9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !4
  %10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
  %11 = load ptr, ptr %10, align 8
  %12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
  %13 = load i64, ptr %12, align 4, !invariant.load !2
  %14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
  %15 = load i64, ptr %14, align 4, !invariant.load !2
  %16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
  %17 = load i64, ptr %16, align 4, !invariant.load !2
  call void @wrapped_reduce_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
  ret ptr null
}

; Function Attrs: alwaysinline
define internal void @wrapped_reduce_wrapped(ptr noalias align 64 dereferenceable(16) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(4) %2, i64 %3, i64 %4, i64 %5) #1 {
  %7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
  %8 = load float, ptr %7, align 4, !invariant.load !2
  br label %9

9:                                                ; preds = %13, %6
  %10 = phi i64 [ %17, %13 ], [ 0, %6 ]
  %11 = phi float [ %16, %13 ], [ %8, %6 ]
  %12 = icmp slt i64 %10, 4
  br i1 %12, label %13, label %18

13:                                               ; preds = %9
  %14 = getelementptr inbounds [4 x float], ptr %0, i32 0, i64 %10
  %15 = load float, ptr %14, align 4, !invariant.load !2
  %16 = fadd reassoc float %11, %15
  %17 = add i64 %10, 1
  br label %9

18:                                               ; preds = %9
  %19 = getelementptr inbounds [1 x float], ptr %2, i32 0, i32 0
  store float %11, ptr %19, align 4
  ret void
}

attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }

!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 16}
!4 = !{i64 4}

; ModuleID = '__compute_module_wrapped_reduce-window_kernel_module'
source_filename = "__compute_module_wrapped_reduce-window_kernel_module"

%XLA_CPU_KernelCallFrame = type { ptr, ptr, i64, ptr }
%XLA_CPU_KernelArg = type { ptr, i64 }
%kernel_dim3 = type { i64, i64, i64 }

; Function Attrs: uwtable
define ptr @wrapped_reduce-window(ptr %0) #0 {
  %2 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 3
  %3 = load ptr, ptr %2, align 8, !invariant.load !2
  %4 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 0, i32 0
  %5 = load ptr, ptr %4, align 8, !invariant.load !2, !dereferenceable !3
  %6 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 1, i32 0
  %7 = load ptr, ptr %6, align 8, !invariant.load !2, !dereferenceable !4
  %8 = getelementptr inbounds %XLA_CPU_KernelArg, ptr %3, i32 2, i32 0
  %9 = load ptr, ptr %8, align 8, !invariant.load !2, !dereferenceable !5
  %10 = getelementptr inbounds %XLA_CPU_KernelCallFrame, ptr %0, i32 0, i32 1
  %11 = load ptr, ptr %10, align 8
  %12 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 0
  %13 = load i64, ptr %12, align 4, !invariant.load !2
  %14 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 1
  %15 = load i64, ptr %14, align 4, !invariant.load !2
  %16 = getelementptr inbounds %kernel_dim3, ptr %11, i32 0, i32 2
  %17 = load i64, ptr %16, align 4, !invariant.load !2
  call void @wrapped_reduce-window_wrapped(ptr %5, ptr %7, ptr %9, i64 %13, i64 %15, i64 %17)
  ret ptr null
}

; Function Attrs: alwaysinline
define internal void @wrapped_reduce-window_wrapped(ptr noalias align 64 dereferenceable(400) %0, ptr noalias align 64 dereferenceable(4) %1, ptr noalias align 64 dereferenceable(16) %2, i64 %3, i64 %4, i64 %5) #1 {
  %7 = getelementptr inbounds [1 x float], ptr %1, i32 0, i32 0
  %8 = load float, ptr %7, align 4, !invariant.load !2
  br label %9

9:                                                ; preds = %33, %6
  %10 = phi i64 [ %35, %33 ], [ 0, %6 ]
  %11 = icmp slt i64 %10, 4
  br i1 %11, label %12, label %36

12:                                               ; preds = %9
  %13 = mul nsw i64 %10, 32
  br label %14

14:                                               ; preds = %31, %12
  %15 = phi i64 [ %32, %31 ], [ 0, %12 ]
  %16 = phi float [ %30, %31 ], [ %8, %12 ]
  %17 = icmp slt i64 %15, 32
  br i1 %17, label %18, label %33

18:                                               ; preds = %14
  %19 = add nsw i64 %13, %15
  %20 = icmp sge i64 %19, 14
  %21 = icmp sle i64 %19, 113
  %22 = and i1 %20, %21
  br i1 %22, label %23, label %28

23:                                               ; preds = %18
  %24 = add nsw i64 %19, -14
  %25 = getelementptr inbounds [100 x float], ptr %0, i32 0, i64 %24
  %26 = load float, ptr %25, align 4, !invariant.load !2
  %27 = fadd reassoc float %16, %26
  br label %29

28:                                               ; preds = %18
  br label %29

29:                                               ; preds = %23, %28
  %30 = phi float [ %16, %28 ], [ %27, %23 ]
  br label %31

31:                                               ; preds = %29
  %32 = add i64 %15, 1
  br label %14

33:                                               ; preds = %14
  %34 = getelementptr inbounds [4 x float], ptr %2, i32 0, i64 %10
  store float %16, ptr %34, align 4
  %35 = add i64 %10, 1
  br label %9, !llvm.loop !6

36:                                               ; preds = %9
  ret void
}

attributes #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" }
attributes #1 = { alwaysinline }

!llvm.module.flags = !{!0}
!xla_cpu_memory_region_name = !{!1}

!0 = !{i32 2, !"Debug Info Version", i32 3}
!1 = !{!"xla_cpu_emitter__loop_fusion_kernel_emitter__hlo_opcode__fusion"}
!2 = !{}
!3 = !{i64 400}
!4 = !{i64 4}
!5 = !{i64 16}
!6 = distinct !{!6, !7}
!7 = !{!"llvm.loop.unroll.disable"}
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants