Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Brutus/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ authors = ["Valentin Churavy <v.churavy@gmail.com>", "Leon Shen <leon@yhls.org>"
version = "0.1.0"

[deps]
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
MLIR = "bfde9dd4-8f40-4a1e-be09-1475335e1c92"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"

[compat]
julia = "1.8"
GPUCompiler = "0.13"
61 changes: 50 additions & 11 deletions Brutus/src/Brutus.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,57 @@
module Brutus

using GPUCompiler: GPUCompiler, CompilerJob
import GPUCompiler
import GPUCompiler: AbstractCompilerTarget, AbstractCompilerParams
using Preferences
using MLIR

#####
##### Exports
#####
const libbrutus = @load_preference("libbrutus")
const libbrutus_c = @load_preference("libbrutus_c")

export emit
module BrutusAPI
import ..Brutus: libbrutus, libbrutus_c
import MLIR.API: MlirDialectHandle
function mlirGetDialectHandle__jlir__()
@ccall libbrutus_c.mlirGetDialectHandle__jlir__()::MlirDialectHandle
end
end

include("init.jl")
include("codegen.jl")
include("reflection.jl")
include("interface.jl")
import MLIR: API, IR, Dialects
module BrutusDialects
include(joinpath("Dialects", string(Base.libllvm_version.major), "JuliaOps.jl"))
end

function load_dialect(ctx)
dialect = IR.DialectHandle(BrutusAPI.mlirGetDialectHandle__jlir__())
API.mlirDialectHandleRegisterDialect(dialect, ctx)
API.mlirDialectHandleLoadDialect(dialect, ctx)
end

function code_mlir(f, types)
ctx = IR.context()

src, rt = only(Base.code_ircode(f, types))

for dialect in ("func", "cf")
IR.get_or_load_dialect!(dialect)
end
load_dialect(ctx)
IR.get_or_load_dialect!(IR.DialectHandle(BrutusAPI.mlirGetDialectHandle__jlir__()))

values = Vector{Value}(undef, length(ir.stmts))
end

# using GPUCompiler: GPUCompiler, CompilerJob
# import GPUCompiler
# import GPUCompiler: AbstractCompilerTarget, AbstractCompilerParams

# #####
# ##### Exports
# #####

# export emit

# include("init.jl")
# include("codegen.jl")
# include("reflection.jl")
# include("interface.jl")

end # module
Loading