From ca3342ea7a5acd4707493dc3fe595983d8e4b9b5 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Tue, 20 Apr 2021 18:39:48 -0400 Subject: [PATCH 01/13] Added JLIRBuilder and working on codegen in Julia. --- .gitignore | 2 + Brutus/Project.toml | 1 + Brutus/scratch/juliacodegen.jl | 18 ++ Brutus/src/Brutus.jl | 2 +- Brutus/src/codegen.jl | 105 ----------- Brutus/src/compiler/Compiler.jl | 10 + Brutus/src/compiler/codegen.jl | 6 + Brutus/src/compiler/jlirgen.jl | 20 ++ Brutus/src/compiler/opbuilder.jl | 224 +++++++++++++++++++++++ Brutus/src/interface.jl | 108 ++++++++++- Brutus/test/runtests.jl | 1 + include/brutus/Dialect/Julia/JuliaOps.h | 2 + include/brutus/Dialect/Julia/JuliaOps.td | 2 +- include/brutus/brutus.h | 4 + lib/Codegen/Codegen.cpp | 44 +++-- 15 files changed, 427 insertions(+), 122 deletions(-) create mode 100644 Brutus/scratch/juliacodegen.jl delete mode 100644 Brutus/src/codegen.jl create mode 100644 Brutus/src/compiler/Compiler.jl create mode 100644 Brutus/src/compiler/codegen.jl create mode 100644 Brutus/src/compiler/jlirgen.jl create mode 100644 Brutus/src/compiler/opbuilder.jl diff --git a/.gitignore b/.gitignore index aba1851..238af48 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ build*/ Brutus/Manifest.toml llvm julia +MLIR.jl +Brutus/dev diff --git a/Brutus/Project.toml b/Brutus/Project.toml index 9e5698d..9520d08 100644 --- a/Brutus/Project.toml +++ b/Brutus/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" +MLIR = "bfde9dd4-8f40-4a1e-be09-1475335e1c92" [compat] julia = "1.5" diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl new file mode 100644 index 0000000..5d4279a --- /dev/null +++ b/Brutus/scratch/juliacodegen.jl @@ -0,0 +1,18 @@ +module JuliaCodegen + +using Brutus +using MLIR + +function gauss(N) + acc = 0 + for i in 1:N + acc += i + end + return acc +end + +mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) +ir_code, ret = Brutus.code_ircode(mi) +ft = Brutus.Compiler.create_func_op(ir_code, ret, "gauss") + +end # module diff --git a/Brutus/src/Brutus.jl b/Brutus/src/Brutus.jl index 3469b81..449459d 100644 --- a/Brutus/src/Brutus.jl +++ b/Brutus/src/Brutus.jl @@ -11,7 +11,7 @@ import GPUCompiler: AbstractCompilerTarget, AbstractCompilerParams export emit include("init.jl") -include("codegen.jl") +include("compiler/Compiler.jl") include("reflection.jl") include("interface.jl") diff --git a/Brutus/src/codegen.jl b/Brutus/src/codegen.jl deleted file mode 100644 index 4fd1174..0000000 --- a/Brutus/src/codegen.jl +++ /dev/null @@ -1,105 +0,0 @@ -##### -##### Codegen -##### - -struct BrutusCompilerTarget <: AbstractCompilerTarget end -GPUCompiler.llvm_triple(::BrutusCompilerTarget) = Sys.MACHINE -GPUCompiler.llvm_machine(::BrutusCompilerTarget) = tm[] - -module Runtime - # the runtime library - signal_exception() = return - malloc(sz) = Base.Libc.malloc(sz) - report_oom(sz) = return - report_exception(ex) = return - report_exception_name(ex) = return - report_exception_frame(idx, func, file, line) = return -end - -@enum DumpOption::UInt8 begin - DumpIRCode = 0 - DumpTranslated = 1 - DumpCanonicalized = 2 - DumpLoweredToStd = 4 - DumpLoweredToLLVM = 8 - DumpTranslateToLLVM = 16 -end - -struct BrutusCompilerParams <: AbstractCompilerParams - emit_fptr::Bool - dump_options::Vector{DumpOption} -end - -GPUCompiler.ci_cache(job::CompilerJob{BrutusCompilerTarget}) = GLOBAL_CI_CACHE -GPUCompiler.runtime_module(job::CompilerJob{BrutusCompilerTarget}) = Runtime -GPUCompiler.isintrinsic(::CompilerJob{BrutusCompilerTarget}, fn::String) = true -GPUCompiler.can_throw(::CompilerJob{BrutusCompilerTarget}) = true -GPUCompiler.runtime_slug(job::CompilerJob{BrutusCompilerTarget}) = "brutus" - -function find_invokes(IR) - callees = Core.MethodInstance[] - for stmt in IR.stmts - if stmt isa Expr - if stmt.head == :invoke - mi = stmt.args[1] - push!(callees, mi) - end - end - end - return callees -end - -# Emit MLIR IR to stdout -function emit(job::CompilerJob) - ft = job.source.f - tt = job.source.tt - emit_fptr = job.params.emit_fptr - dump_options = job.params.dump_options - name = (ft <: Function) ? nameof(ft.instance) : nameof(ft) - - # get first method instance matching signature - entry_mi = get_methodinstance(Tuple{ft, tt.parameters...}) - IR, rt = code_ircode(entry_mi) - - if DumpIRCode in dump_options - println("return type: ", rt) - println("IRCode:\n") - println(IR) - end - - worklist = [IR] - methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Any}}( - entry_mi => (IR, rt) - ) - - while !isempty(worklist) - code = pop!(worklist) - callees = find_invokes(code) - for callee in callees - if !haskey(methods, callee) - _code, _rt = code_ircode(callee) - - methods[callee] = (_code, _rt) - push!(worklist, _code) - end - end - end - - # generate LLVM bitcode and load it - dump_flags = reduce(|, map(UInt8, dump_options), init=0) - fptr = ccall((:brutus_codegen, "libbrutus"), - Ptr{Nothing}, - (Any, Any, Cuchar, Cuchar), - methods, entry_mi, emit_fptr, dump_flags) - return (fptr, rt) -end - -function emit(@nospecialize(ft), @nospecialize(tt); - emit_fptr::Bool=true, - dump_options::Vector{DumpOption}=DumpOption[]) - fspec = GPUCompiler.FunctionSpec(ft, Tuple{tt...}, false, nothing) - target = BrutusCompilerTarget() - params = BrutusCompilerParams(emit_fptr, dump_options) - job = CompilerJob(target, fspec, params) - return emit(job) -end diff --git a/Brutus/src/compiler/Compiler.jl b/Brutus/src/compiler/Compiler.jl new file mode 100644 index 0000000..a5640ce --- /dev/null +++ b/Brutus/src/compiler/Compiler.jl @@ -0,0 +1,10 @@ +module Compiler + +using MLIR +import MLIR.IR as JLIR + +include("jlirgen.jl") +include("opbuilder.jl") +include("codegen.jl") + +end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl new file mode 100644 index 0000000..8cc4149 --- /dev/null +++ b/Brutus/src/compiler/codegen.jl @@ -0,0 +1,6 @@ +##### +##### Codegen +##### + +# This is the Julia interface between Julia's IRCode and JLIR. + diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl new file mode 100644 index 0000000..7b19db0 --- /dev/null +++ b/Brutus/src/compiler/jlirgen.jl @@ -0,0 +1,20 @@ +function create_unimplemented_op(loc::JLIR.Location, type) + state = JLIR.create_operation_state("jlir::unimplemented", loc) + JLIR.push_results!(state, 1, type) + return JLIR.Operation(state) +end + +function create_constant_op(loc::JLIR.Location, value, type) + state = JLIR.create_operation_state("jlir::constant", loc) + JLIR.push_operands!(state, 1, value) + JLIR.push_results!(state, 1, type) + return JLIR.Operation(state) +end + +function create_call_op(loc::JLIR.Location, callee, arguments, type) + state = JLIR.create_operation_state("jlir::call", loc) + operands = [callee, arguments...] + JLIR.push_operands!(state, length(operands), operands) + JLIR.push_results!(state, 1, type) + return JLIR.Operation(state) +end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl new file mode 100644 index 0000000..4dd4055 --- /dev/null +++ b/Brutus/src/compiler/opbuilder.jl @@ -0,0 +1,224 @@ +##### +##### Builder +##### + +# High-level version of MLIR's OpBuilder. + +mutable struct JLIRBuilder + ctx::JLIR.Context + values::Vector{JLIR.Value} + arguments::Vector{JLIR.Value} + insertion::Int + blocks::Vector{JLIR.Block} + function JLIRBuilder() + ctx = JLIR.create_context() + ccall((:brutus_register_dialects, "libbrutus"), + Cvoid, + (JLIR.Context, ), + ctx) + new(ctx, JLIR.Value[], JLIR.Value[], 1) + end +end + +set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion = blk + +function push!(b::JLIRBuilder, op::JLIR.Operation) + @assert(isdefined(b, :blocks)) + blk = b.blocks[b.insertion] + push_operation!(blk, op) +end + +##### +##### Utilities +##### + +function convert_type_to_mlir(builder::JLIRBuilder, a) + ctx = builder.ctx + return ccall((:brutus_get_juliatype, "libbrutus"), + JLIR.Type, + (JLIR.Context, Any), + ctx, a) +end + +function get_functype(builder::JLIRBuilder, args::Vector{JLIR.Type}, ret::JLIR.Type) + return MLIR.API.mlirFunctionTypeGet(builder.ctx, length(args), args, 1, [ret]) +end + +function get_functype(builder::JLIRBuilder, args, ret) + return get_functype(builder, length(args), map(args) do a + convert_type_to_mlir(builder, a) + end, 1, [convert_type_to_mlir(builder, ret)]) +end + +function unwrap(mi::Core.MethodInstance) + return mi.def.value +end +unwrap(s) = s + +function extract_linetable_meta(builder::JLIRBuilder, v::Vector{Core.LineInfoNode}) + locations = JLIR.Location[] + for n in v + method = unwrap(n.method) + file = String(n.file) + line = n.line + inlined_at = n.inlined_at + if method isa Method + fname = String(method.name) + end + if method isa Symbol + fname = String(method) + end + current = JLIR.Location(builder.ctx, fname, UInt32(line), UInt32(0)) # TODO: col. + if inlined_at > 0 + current = JLIR.Location(current, locations[inlined_at - 1]) + end + push!(locations, current) + end + return locations +end + +##### +##### High-level version of create +##### + +# In future, should be autogenerated from tablegen. + +struct UnimplementedOp end +struct ConstantOp end +struct GotoOp end +struct GotoIfNotOp end +struct PiOp end +struct CallOp end + +function create!(b::JLIRBuilder, ::UnimplementedOp, loc, type) + @assert(isdefined(b, :blocks)) + op = create_unimplemented_op(loc, type) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::ConstantOp, loc, value, type) + @assert(isdefined(b, :blocks)) + op = create_constant_op(loc, value, type) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, blk::JLIR.Block, v::Vector{JLIR.Value}) + @assert(isdefined(b, :blocks)) + op = create_goto_op(loc, blk, v) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, + cond::JLIR.Value, dest::JLIR.Block, v::Vector{JLIR.Value}, + fall::JLIR.Block, fallv::Vector{JLIR.Value}) + @assert(isdefined(b, :blocks)) + op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::PiOp, loc, value::JLIR.Value, type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_pi_op(loc, value, type) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::CallOp, loc, callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_call_op(loc, callee, arguments, type) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +##### +##### JLIR emission +##### + +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) + name = value.name + v = getproperty(value.mod, value.name) + return create_constant_op(builder, loc, v, type) +end + +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) + @assert(value.id >= 1) + return getindex(builder.values, value.id) +end + +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value, type) + return create_unimplemented_op(builder, loc, type) +end + +function emit_ftype(builder::JLIRBuilder, ir_code::Core.Compiler.IRCode, ret_type) + argtypes = getfield(ir_code, :argtypes) + nargs = length(argtypes) + args = [convert_type_to_mlir(builder, a) for a in argtypes] + ret = convert_type_to_mlir(builder, ret_type) + return get_functype(builder, args, ret) +end + +function process_node!(b::JLIRBuilder) +end + +function walk_cfg_emit_branchargs(builder, + cfg::Core.Compiler.CFG, current_block::Int, + target_block::Int, stmts, types, loc::JLIR.Location) + v = JLIR.Value[] + for stmt in cfg.blocks[target].stmts + handle_node!(builder, v, stmt, loc) + end + return v +end + +function create_func_op(builder::JLIRBuilder, + ir_code::Core.Compiler.IRCode, ret::Type, name::String) + + # Setup. + irstream = ir_code.stmts + location_indices = getfield(irstream, :line) + linetable = getfield(ir_code, :linetable) + locations = extract_linetable_meta(builder, linetable) + argtypes = getfield(ir_code, :argtypes) + args = [convert_type_to_mlir(builder, a) for a in argtypes] + state = JLIR.create_operation_state(name, locations[1]) + entry_blk, reg = JLIR.add_entry_block!(state, args) + cfg = ir_code.cfg + cfg_blocks = cfg.blocks + nblocks = length(cfg_blocks) + blocks = JLIR.Block[JLIR.push_new_block!(reg) for _ in 1 : nblocks] + pushfirst!(blocks, entry_blk) + builder.blocks = blocks + stmts = irstream.inst + types = irstream.type + v = walk_cfg_emit_branchargs(builder, cfg, 1, 2, stmts, types, locations[0]) + goto = create_goto_op(Location(builder.ctx), blocks[2], v) + push!(builder, goto) + set_insertion!(builder, 2) + + # Process. + for (ind, (stmt, type)) in enumerate(zip(stmts, types)) + loc = linetable[ind] == 0 ? JLIR.Location() : locations[ind] + is_terminator = false + process_node!(builder, stmt, loc) + end + + # Create op from state and verify. + op = JLIR.Operation(state) + @assert(JLIR.verify(op)) + return op +end + +function create_func_op(ir_code::Core.Compiler.IRCode, ret::Type, name::String) + b = JLIRBuilder() + return create_func_op(b, ir_code, ret, name) +end diff --git a/Brutus/src/interface.jl b/Brutus/src/interface.jl index 4fe859b..f00e18e 100644 --- a/Brutus/src/interface.jl +++ b/Brutus/src/interface.jl @@ -1,3 +1,109 @@ +##### +##### GPUCompiler codegen +##### + +struct BrutusCompilerTarget <: AbstractCompilerTarget end +GPUCompiler.llvm_triple(::BrutusCompilerTarget) = Sys.MACHINE +GPUCompiler.llvm_machine(::BrutusCompilerTarget) = tm[] + +module Runtime + # the runtime library + signal_exception() = return + malloc(sz) = Base.Libc.malloc(sz) + report_oom(sz) = return + report_exception(ex) = return + report_exception_name(ex) = return + report_exception_frame(idx, func, file, line) = return +end + +@enum DumpOption::UInt8 begin + DumpIRCode = 0 + DumpTranslated = 1 + DumpCanonicalized = 2 + DumpLoweredToStd = 4 + DumpLoweredToLLVM = 8 + DumpTranslateToLLVM = 16 +end + +struct BrutusCompilerParams <: AbstractCompilerParams + emit_fptr::Bool + dump_options::Vector{DumpOption} +end + +GPUCompiler.ci_cache(job::CompilerJob{BrutusCompilerTarget}) = GLOBAL_CI_CACHE +GPUCompiler.runtime_module(job::CompilerJob{BrutusCompilerTarget}) = Runtime +GPUCompiler.isintrinsic(::CompilerJob{BrutusCompilerTarget}, fn::String) = true +GPUCompiler.can_throw(::CompilerJob{BrutusCompilerTarget}) = true +GPUCompiler.runtime_slug(job::CompilerJob{BrutusCompilerTarget}) = "brutus" + +function find_invokes(IR) + callees = Core.MethodInstance[] + for stmt in IR.stmts + if stmt isa Expr + if stmt.head == :invoke + mi = stmt.args[1] + push!(callees, mi) + end + end + end + return callees +end + +# Emit MLIR IR to stdout +function emit(job::CompilerJob) + ft = job.source.f + tt = job.source.tt + emit_fptr = job.params.emit_fptr + dump_options = job.params.dump_options + name = (ft <: Function) ? nameof(ft.instance) : nameof(ft) + + # get first method instance matching signature + entry_mi = get_methodinstance(Tuple{ft, tt.parameters...}) + IR, rt = code_ircode(entry_mi) + + if DumpIRCode in dump_options + println("return type: ", rt) + println("IRCode:\n") + println(IR) + end + + worklist = [IR] + methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Any}}( + entry_mi => (IR, rt) + ) + + while !isempty(worklist) + code = pop!(worklist) + callees = find_invokes(code) + for callee in callees + if !haskey(methods, callee) + _code, _rt = code_ircode(callee) + + methods[callee] = (_code, _rt) + push!(worklist, _code) + end + end + end + + # generate LLVM bitcode and load it + dump_flags = reduce(|, map(UInt8, dump_options), init=0) + fptr = ccall((:brutus_codegen, "libbrutus"), + Ptr{Nothing}, + (Any, Any, Cuchar, Cuchar), + methods, entry_mi, emit_fptr, dump_flags) + return (fptr, rt) +end + +function emit(@nospecialize(ft), @nospecialize(tt); + emit_fptr::Bool=true, + dump_options::Vector{DumpOption}=DumpOption[]) + fspec = GPUCompiler.FunctionSpec(ft, Tuple{tt...}, false, nothing) + target = BrutusCompilerTarget() + params = BrutusCompilerParams(emit_fptr, dump_options) + job = CompilerJob(target, fspec, params) + return emit(job) +end + ##### ##### Call Interface ##### @@ -40,7 +146,7 @@ function link(job::CompilerJob, (fptr, rt)) end function thunk(f::F, tt::TT=Tuple{}; emit_fptr::Bool = true, dump_options::Vector{DumpOption} = DumpOption[]) where {F<:Base.Callable, TT<:Type} - fspec = GPUCompiler.FunctionSpec(f, tt, false, nothing) + fspec = GPUCompiler.FunctionSpec(F, tt, false, nothing) target = BrutusCompilerTarget() params = BrutusCompilerParams(emit_fptr, dump_options) job = CompilerJob(target, fspec, params) diff --git a/Brutus/test/runtests.jl b/Brutus/test/runtests.jl index e89665e..b571c37 100644 --- a/Brutus/test/runtests.jl +++ b/Brutus/test/runtests.jl @@ -68,4 +68,5 @@ for array in [rand(Int64, 2, 3), rand(Int64, 2, 3)] @test Brutus.call(customsum, array) == customsum(array) end end + # TODO: arrays with floating point elements diff --git a/include/brutus/Dialect/Julia/JuliaOps.h b/include/brutus/Dialect/Julia/JuliaOps.h index fa52ea1..878abf3 100644 --- a/include/brutus/Dialect/Julia/JuliaOps.h +++ b/include/brutus/Dialect/Julia/JuliaOps.h @@ -1,6 +1,8 @@ #ifndef JL_DIALECT_JLIR_H #define JL_DIALECT_JLIR_H +#include + #include "mlir/IR/Dialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/include/brutus/Dialect/Julia/JuliaOps.td b/include/brutus/Dialect/Julia/JuliaOps.td index 1690d99..e034351 100644 --- a/include/brutus/Dialect/Julia/JuliaOps.td +++ b/include/brutus/Dialect/Julia/JuliaOps.td @@ -380,4 +380,4 @@ def JLIR_Builtin_ifelse : JLIR_IntrinsicBuiltinOp<"ifelse">; def JLIR_Builtin__typevar : JLIR_IntrinsicBuiltinOp<"_typevar">; // invoke_kwsorter? -#endif // JULIA_MLIR_JLIR_TD \ No newline at end of file +#endif // JULIA_MLIR_JLIR_TD diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index d280adf..886b8c5 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -24,7 +24,11 @@ extern "C" { #endif + void brutus_register_dialects(MlirContext Context); + void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect); + MlirType brutus_get_juliatype(MlirContext context, jl_datatype_t *datatype); + // Export C API for pipeline. typedef void (*ExecutionEngineFPtrResult)(void **); void brutus_init(jl_module_t *brutus); diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index fd15665..e0b317f 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -1,4 +1,3 @@ - #include "brutus/brutus.h" #include "brutus/brutus_internal.h" #include "brutus/Dialect/Julia/JuliaOps.h" @@ -469,32 +468,39 @@ mlir::FuncOp emit_function(jl_mlirctx_t &ctx, extern "C" { - enum DumpOption + void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect) { - // DUMP_IRCODE = 0, - DUMP_TRANSLATED = 1, - DUMP_CANONICALIZED = 2, - DUMP_LOWERED_TO_STD = 4, - DUMP_LOWERED_TO_LLVM = 8, - DUMP_TRANSLATE_TO_LLVM = 16, + return; + } + + void brutus_register_dialects(MlirContext Context) + { + mlir::MLIRContext *ctx = unwrap(Context); + ctx->getOrLoadDialect(); + ctx->getOrLoadDialect(); + ctx->getOrLoadDialect(); + }; + + MlirType brutus_get_juliatype(MlirContext Context, + jl_datatype_t *datatype) + { + mlir::MLIRContext *ctx = unwrap(Context); + mlir::Type type = JuliaType::get(ctx, datatype); + return wrap(type); }; - // TODO: enum with ERROR codes for failures. void brutus_codegen_jlir(MlirContext Context, MlirModule Module, jl_value_t *methods, jl_method_instance_t *entry_mi, char dump_flags) { - mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); + brutus_register_dialects(Context); + mlir::MLIRContext *context = unwrap(Context); jl_mlirctx_t ctx(context); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - jl_value_t *entry = jl_call2(getindex_func, methods, (jl_value_t *)entry_mi); jl_value_t *ir_code = jl_fieldref(entry, 0); jl_value_t *ret_type = jl_fieldref(entry, 1); @@ -629,6 +635,16 @@ extern "C" return expectedFPtr.get(); } + enum DumpOption + { + // DUMP_IRCODE = 0, + DUMP_TRANSLATED = 1, + DUMP_CANONICALIZED = 2, + DUMP_LOWERED_TO_STD = 4, + DUMP_LOWERED_TO_LLVM = 8, + DUMP_TRANSLATE_TO_LLVM = 16, + }; + ExecutionEngineFPtrResult brutus_codegen(jl_value_t *methods, jl_method_instance_t *entry_mi, char emit_fptr, char dump_flags) { MlirContext Context = mlirContextCreate(); From a342181b748dd21bd37f2c6417dcf5910105b177 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Tue, 20 Apr 2021 23:22:39 -0400 Subject: [PATCH 02/13] Filled out codegen further. --- Brutus/src/compiler/Compiler.jl | 2 +- Brutus/src/compiler/codegen.jl | 78 +++++++++++++++ Brutus/src/compiler/jlirgen.jl | 149 +++++++++++++++++++++++++++- Brutus/src/compiler/opbuilder.jl | 160 +++---------------------------- include/brutus/brutus.h | 1 + lib/Codegen/Codegen.cpp | 8 ++ 6 files changed, 245 insertions(+), 153 deletions(-) diff --git a/Brutus/src/compiler/Compiler.jl b/Brutus/src/compiler/Compiler.jl index a5640ce..1a27bc3 100644 --- a/Brutus/src/compiler/Compiler.jl +++ b/Brutus/src/compiler/Compiler.jl @@ -3,8 +3,8 @@ module Compiler using MLIR import MLIR.IR as JLIR -include("jlirgen.jl") include("opbuilder.jl") +include("jlirgen.jl") include("codegen.jl") end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 8cc4149..3daac3d 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -4,3 +4,81 @@ # This is the Julia interface between Julia's IRCode and JLIR. +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) + name = value.name + v = getproperty(value.mod, value.name) + return create_constant_op(builder, loc, v, type) +end + +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) + @assert(value.id >= 1) + return getindex(builder.values, value.id) +end + +function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value, type) + return create_unimplemented_op(builder, loc, type) +end + +function emit_ftype(builder::JLIRBuilder, ir_code::Core.Compiler.IRCode, ret_type) + argtypes = getfield(ir_code, :argtypes) + nargs = length(argtypes) + args = [convert_type_to_mlir(builder, a) for a in argtypes] + ret = convert_type_to_mlir(builder, ret_type) + return get_functype(builder, args, ret) +end + +function process_node!(b::JLIRBuilder) +end + +function walk_cfg_emit_branchargs(builder, + cfg::Core.Compiler.CFG, current_block::Int, + target_block::Int, stmts, types, loc::JLIR.Location) + v = JLIR.Value[] + for stmt in cfg.blocks[target].stmts + handle_node!(builder, v, stmt, loc) + end + return v +end + +function emit_jlir(builder::JLIRBuilder, + ir_code::Core.Compiler.IRCode, ret::Type, name::String) + + # Setup. + irstream = ir_code.stmts + location_indices = getfield(irstream, :line) + linetable = getfield(ir_code, :linetable) + locations = extract_linetable_meta(builder, linetable) + argtypes = getfield(ir_code, :argtypes) + args = [convert_type_to_mlir(builder, a) for a in argtypes] + state = JLIR.create_operation_state(name, locations[1]) + entry_blk, reg = JLIR.add_entry_block!(state, args) + cfg = ir_code.cfg + cfg_blocks = cfg.blocks + nblocks = length(cfg_blocks) + blocks = JLIR.Block[JLIR.push_new_block!(reg) for _ in 1 : nblocks] + pushfirst!(blocks, entry_blk) + builder.blocks = blocks + stmts = irstream.inst + types = irstream.type + v = walk_cfg_emit_branchargs(builder, cfg, 1, 2, stmts, types, locations[0]) + goto = create_goto_op(Location(builder.ctx), blocks[2], v) + push!(builder, goto) + set_insertion!(builder, 2) + + # Process. + for (ind, (stmt, type)) in enumerate(zip(stmts, types)) + loc = linetable[ind] == 0 ? JLIR.Location() : locations[ind] + is_terminator = false + process_node!(builder, stmt, loc) + end + + # Create op from state and verify. + op = JLIR.Operation(state) + @assert(JLIR.verify(op)) + return op +end + +function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) + b = JLIRBuilder() + return create_func_op(b, ir_code, ret, name) +end diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index 7b19db0..42257b1 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -1,3 +1,5 @@ +# TODO: In future, should be autogenerated from tablegen. + function create_unimplemented_op(loc::JLIR.Location, type) state = JLIR.create_operation_state("jlir::unimplemented", loc) JLIR.push_results!(state, 1, type) @@ -11,10 +13,151 @@ function create_constant_op(loc::JLIR.Location, value, type) return JLIR.Operation(state) end -function create_call_op(loc::JLIR.Location, callee, arguments, type) +function create_goto_op(loc::JLIR.Location, blk::JLIR.Block, + v::Vector{JLIR.Value}) + state = JLIR.create_operation_state("jlir::goto", loc) + JLIR.push_operands!(state, v) + JLIR.push_successors!(state, JLIR.Block[blk]) + return JLIR.Operation(state) +end + +function create_gotoifnot_op(loc::JLIR.Location, cond::JLIR.Value, + dest::JLIR.Block, v::Vector{JLIR.Value}, + fall::JLIR.Block, fallv::Vector{JLIR.Value}) + state = JLIR.create_operation_state("jlir::gotoifnot", loc) + JLIR.push_operands!(state, JLIR.Value[cond]) + JLIR.push_operands!(state, v) + JLIR.push_operands!(state, fallv) + JLIR.push_successors!(state, JLIR.Block[blk]) + JLIR.push_successors!(state, JLIR.Block[fall]) + return JLIR.Operation(state) +end + +function create_pi_op(loc::JLIR.Location, input::JLIR.Type, + type::JLIR.Type) + state = JLIR.create_operation_state("jlir::pi", loc) + JLIR.push_operands!(state, JLIR.Type[value]) + JLIR.push_results!(state, JLIR.Type[type]) + return JLIR.Operation(state) +end + +function create_return_op(loc::JLIR.Location, input::JLIR.Type) + state = JLIR.create_operation_state("jlir::return", loc) + JLIR.push_operands!(state, JLIR.Type[input]) + return JLIR.Operation(state) +end + +function create_call_op(loc::JLIR.Location, callee::JLIR.Type, + arguments::Vector{JLIR.Type}, type::JLIR.Type) state = JLIR.create_operation_state("jlir::call", loc) operands = [callee, arguments...] - JLIR.push_operands!(state, length(operands), operands) - JLIR.push_results!(state, 1, type) + JLIR.push_operands!(state, operands) + JLIR.push_results!(state, JLIR.Value[type]) return JLIR.Operation(state) end + +function create_invoke_op(loc::JLIR.Location, mi::JLIR.Value, + callee::JLIR.Value, arguments::Vector{JLIR.Value}, + type::JLIR.Type) + state = JLIR.create_operation_state("jlir::invoke", loc) + JLIR.push_operands!(state, JLIR.Value[mi, callee, arguments...]) + JLIR.push_results!(state, JLIR.Value[type]) + return JLIR.Operation(state) +end + +##### +##### High-level version of create +##### + +struct UnimplementedOp end +struct ConstantOp end +struct GotoOp end +struct GotoIfNotOp end +struct PiOp end +struct ReturnOp end +struct CallOp end +struct InvokeOp end + +function create!(b::JLIRBuilder, ::UnimplementedOp, loc::JLIR.Location, + type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_unimplemented_op(loc, type) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, + value::JLIR.Value, type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_constant_op(loc, value, type) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, + blk::JLIR.Block, v::Vector{JLIR.Value}) + @assert(isdefined(b, :blocks)) + op = create_goto_op(loc, blk, v) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, + cond::JLIR.Value, dest::JLIR.Block, v::Vector{JLIR.Value}, + fall::JLIR.Block, fallv::Vector{JLIR.Value}) + @assert(isdefined(b, :blocks)) + op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::PiOp, loc::JLIR.Location, + value::JLIR.Value, type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_pi_op(loc, value, type) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::ReturnOp, loc::JLIR.Location, + input::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_return_op(loc, input) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::CallOp, loc::JLIR.Location, + callee::JLIR.Value, arguments::Vector{JLIR.Value}, + type::JLIR.Type) + @assert(isdefined(b, :blocks)) + op = create_call_op(loc, callee, arguments, type) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end + +function create!(b::JLIRBuilder, ::InvokeOp, loc::JLIR.Location, + mi::Core.MethodInstance, callee::JLIR.Value, + arguments::Vector{JLIR.Value}, type::JLIR.Type) + @assert(isdefined(b, :blocks)) + jlir_mi = convert_value_to_jlirattr(b, mi) + op = create_invoke_op(loc, jlir_mi, callee, arguments, type) + JLIR.verify(op) + blk = getindex(b.blocks, b.insertion) + push!(blk, op) + return op +end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 4dd4055..09943ac 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -32,7 +32,7 @@ end ##### Utilities ##### -function convert_type_to_mlir(builder::JLIRBuilder, a) +function convert_type_to_jlirtype(builder::JLIRBuilder, a) ctx = builder.ctx return ccall((:brutus_get_juliatype, "libbrutus"), JLIR.Type, @@ -40,14 +40,22 @@ function convert_type_to_mlir(builder::JLIRBuilder, a) ctx, a) end +function convert_value_to_jlirattr(builder::JLIRBuilder, a) + ctx = builder.ctx + return ccall((:brutus_get_juliavalueattr, "libbrutus"), + JLIR.Value, + (JLIR.Context, Any), + ctx, a) +end + function get_functype(builder::JLIRBuilder, args::Vector{JLIR.Type}, ret::JLIR.Type) return MLIR.API.mlirFunctionTypeGet(builder.ctx, length(args), args, 1, [ret]) end function get_functype(builder::JLIRBuilder, args, ret) return get_functype(builder, length(args), map(args) do a - convert_type_to_mlir(builder, a) - end, 1, [convert_type_to_mlir(builder, ret)]) + convert_type_to_jlirtype(builder, a) + end, 1, [convert_type_to_jlirtype(builder, ret)]) end function unwrap(mi::Core.MethodInstance) @@ -76,149 +84,3 @@ function extract_linetable_meta(builder::JLIRBuilder, v::Vector{Core.LineInfoNod end return locations end - -##### -##### High-level version of create -##### - -# In future, should be autogenerated from tablegen. - -struct UnimplementedOp end -struct ConstantOp end -struct GotoOp end -struct GotoIfNotOp end -struct PiOp end -struct CallOp end - -function create!(b::JLIRBuilder, ::UnimplementedOp, loc, type) - @assert(isdefined(b, :blocks)) - op = create_unimplemented_op(loc, type) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -function create!(b::JLIRBuilder, ::ConstantOp, loc, value, type) - @assert(isdefined(b, :blocks)) - op = create_constant_op(loc, value, type) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, blk::JLIR.Block, v::Vector{JLIR.Value}) - @assert(isdefined(b, :blocks)) - op = create_goto_op(loc, blk, v) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, - cond::JLIR.Value, dest::JLIR.Block, v::Vector{JLIR.Value}, - fall::JLIR.Block, fallv::Vector{JLIR.Value}) - @assert(isdefined(b, :blocks)) - op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -function create!(b::JLIRBuilder, ::PiOp, loc, value::JLIR.Value, type::JLIR.Type) - @assert(isdefined(b, :blocks)) - op = create_pi_op(loc, value, type) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -function create!(b::JLIRBuilder, ::CallOp, loc, callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) - @assert(isdefined(b, :blocks)) - op = create_call_op(loc, callee, arguments, type) - blk = getindex(b.blocks, b.insertion) - push!(blk, op) - return op -end - -##### -##### JLIR emission -##### - -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) - name = value.name - v = getproperty(value.mod, value.name) - return create_constant_op(builder, loc, v, type) -end - -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) - @assert(value.id >= 1) - return getindex(builder.values, value.id) -end - -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value, type) - return create_unimplemented_op(builder, loc, type) -end - -function emit_ftype(builder::JLIRBuilder, ir_code::Core.Compiler.IRCode, ret_type) - argtypes = getfield(ir_code, :argtypes) - nargs = length(argtypes) - args = [convert_type_to_mlir(builder, a) for a in argtypes] - ret = convert_type_to_mlir(builder, ret_type) - return get_functype(builder, args, ret) -end - -function process_node!(b::JLIRBuilder) -end - -function walk_cfg_emit_branchargs(builder, - cfg::Core.Compiler.CFG, current_block::Int, - target_block::Int, stmts, types, loc::JLIR.Location) - v = JLIR.Value[] - for stmt in cfg.blocks[target].stmts - handle_node!(builder, v, stmt, loc) - end - return v -end - -function create_func_op(builder::JLIRBuilder, - ir_code::Core.Compiler.IRCode, ret::Type, name::String) - - # Setup. - irstream = ir_code.stmts - location_indices = getfield(irstream, :line) - linetable = getfield(ir_code, :linetable) - locations = extract_linetable_meta(builder, linetable) - argtypes = getfield(ir_code, :argtypes) - args = [convert_type_to_mlir(builder, a) for a in argtypes] - state = JLIR.create_operation_state(name, locations[1]) - entry_blk, reg = JLIR.add_entry_block!(state, args) - cfg = ir_code.cfg - cfg_blocks = cfg.blocks - nblocks = length(cfg_blocks) - blocks = JLIR.Block[JLIR.push_new_block!(reg) for _ in 1 : nblocks] - pushfirst!(blocks, entry_blk) - builder.blocks = blocks - stmts = irstream.inst - types = irstream.type - v = walk_cfg_emit_branchargs(builder, cfg, 1, 2, stmts, types, locations[0]) - goto = create_goto_op(Location(builder.ctx), blocks[2], v) - push!(builder, goto) - set_insertion!(builder, 2) - - # Process. - for (ind, (stmt, type)) in enumerate(zip(stmts, types)) - loc = linetable[ind] == 0 ? JLIR.Location() : locations[ind] - is_terminator = false - process_node!(builder, stmt, loc) - end - - # Create op from state and verify. - op = JLIR.Operation(state) - @assert(JLIR.verify(op)) - return op -end - -function create_func_op(ir_code::Core.Compiler.IRCode, ret::Type, name::String) - b = JLIRBuilder() - return create_func_op(b, ir_code, ret, name) -end diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index 886b8c5..dd832ca 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -27,6 +27,7 @@ extern "C" void brutus_register_dialects(MlirContext Context); void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect); MlirType brutus_get_juliatype(MlirContext context, jl_datatype_t *datatype); + MlirValue brutus_get_juliavalueattr(MlirContext context, jl_value_t *value); // Export C API for pipeline. typedef void (*ExecutionEngineFPtrResult)(void **); diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index e0b317f..5fb7b06 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -488,6 +488,14 @@ extern "C" mlir::Type type = JuliaType::get(ctx, datatype); return wrap(type); }; + + MlirValue brutus_get_juliavalueattr(MlirContext Context, + jl_value_t *value) + { + mlir::MLIRContext *ctx = unwrap(Context); + mlir::Value val = JuliaValueAttr::get(ctx, value); + return wrap(val); + }; void brutus_codegen_jlir(MlirContext Context, MlirModule Module, From 58703c769875ad6b3cea3d41164582c08b7a4b63 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Tue, 20 Apr 2021 23:31:27 -0400 Subject: [PATCH 03/13] Added JLIRBuilder and working on codegen in Julia. --- Brutus/src/compiler/opbuilder.jl | 6 +++--- include/brutus/brutus.h | 4 ++-- lib/Codegen/Codegen.cpp | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 09943ac..4681c26 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -34,7 +34,7 @@ end function convert_type_to_jlirtype(builder::JLIRBuilder, a) ctx = builder.ctx - return ccall((:brutus_get_juliatype, "libbrutus"), + return ccall((:brutus_get_jlirtype, "libbrutus"), JLIR.Type, (JLIR.Context, Any), ctx, a) @@ -42,8 +42,8 @@ end function convert_value_to_jlirattr(builder::JLIRBuilder, a) ctx = builder.ctx - return ccall((:brutus_get_juliavalueattr, "libbrutus"), - JLIR.Value, + return ccall((:brutus_get_jlirattr, "libbrutus"), + JLIR.Attribute, (JLIR.Context, Any), ctx, a) end diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index dd832ca..8ccbaa2 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -26,8 +26,8 @@ extern "C" #endif void brutus_register_dialects(MlirContext Context); void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect); - MlirType brutus_get_juliatype(MlirContext context, jl_datatype_t *datatype); - MlirValue brutus_get_juliavalueattr(MlirContext context, jl_value_t *value); + MlirType brutus_get_jlirtype(MlirContext context, jl_datatype_t *datatype); + MlirAttribute brutus_get_jlirattr(MlirContext context, jl_value_t *value); // Export C API for pipeline. typedef void (*ExecutionEngineFPtrResult)(void **); diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index 5fb7b06..ab8fac5 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -481,7 +481,7 @@ extern "C" ctx->getOrLoadDialect(); }; - MlirType brutus_get_juliatype(MlirContext Context, + MlirType brutus_get_jlirtype(MlirContext Context, jl_datatype_t *datatype) { mlir::MLIRContext *ctx = unwrap(Context); @@ -489,11 +489,11 @@ extern "C" return wrap(type); }; - MlirValue brutus_get_juliavalueattr(MlirContext Context, + MlirAttribute brutus_get_jlirattr(MlirContext Context, jl_value_t *value) { mlir::MLIRContext *ctx = unwrap(Context); - mlir::Value val = JuliaValueAttr::get(ctx, value); + mlir::Attribute val = JuliaValueAttr::get(ctx, value); return wrap(val); }; From 7ffce892ca75b81d929268b3fd177d24d90d48ae Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Wed, 21 Apr 2021 19:39:30 -0400 Subject: [PATCH 04/13] Creating Goto ops segfaults for any number of rreasons. Why Julia, why. --- Brutus/Project.toml | 1 + Brutus/scratch/juliacodegen.jl | 2 +- Brutus/src/compiler/Compiler.jl | 1 + Brutus/src/compiler/codegen.jl | 111 ++++++++++++++++++------------- Brutus/src/compiler/jlirgen.jl | 78 +++++++++++----------- Brutus/src/compiler/opbuilder.jl | 53 ++++++++++++++- Brutus/src/reflection.jl | 1 - lib/Codegen/Codegen.cpp | 10 +-- 8 files changed, 161 insertions(+), 96 deletions(-) diff --git a/Brutus/Project.toml b/Brutus/Project.toml index 9520d08..2068af9 100644 --- a/Brutus/Project.toml +++ b/Brutus/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" [deps] GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" MLIR = "bfde9dd4-8f40-4a1e-be09-1475335e1c92" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] julia = "1.5" diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 5d4279a..3d91620 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -13,6 +13,6 @@ end mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) ir_code, ret = Brutus.code_ircode(mi) -ft = Brutus.Compiler.create_func_op(ir_code, ret, "gauss") +mod = Brutus.Compiler.emit_jlir(ir_code, ret, "gauss") end # module diff --git a/Brutus/src/compiler/Compiler.jl b/Brutus/src/compiler/Compiler.jl index 1a27bc3..ec58fc7 100644 --- a/Brutus/src/compiler/Compiler.jl +++ b/Brutus/src/compiler/Compiler.jl @@ -2,6 +2,7 @@ module Compiler using MLIR import MLIR.IR as JLIR +import Base: push! include("opbuilder.jl") include("jlirgen.jl") diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 3daac3d..080f4e8 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -4,81 +4,98 @@ # This is the Julia interface between Julia's IRCode and JLIR. -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) name = value.name v = getproperty(value.mod, value.name) - return create_constant_op(builder, loc, v, type) + return create_constant_op(b, loc, v, type) end -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) @assert(value.id >= 1) - return getindex(builder.values, value.id) + return getindex(b.values, value.id) end -function emit_value(builder::JLIRBuilder, loc::JLIR.Location, value, type) - return create_unimplemented_op(builder, loc, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, value, type) + return create_unimplemented_op(b, loc, type) end -function emit_ftype(builder::JLIRBuilder, ir_code::Core.Compiler.IRCode, ret_type) +function emit_ftype(b::JLIRBuilder, ret_type) argtypes = getfield(ir_code, :argtypes) nargs = length(argtypes) - args = [convert_type_to_mlir(builder, a) for a in argtypes] - ret = convert_type_to_mlir(builder, ret_type) - return get_functype(builder, args, ret) + args = [convert_type_to_jlirtype(b, a) for a in argtypes] + ret = convert_type_to_jlirtype(b, ret_type) + return get_functype(b, args, ret) end -function process_node!(b::JLIRBuilder) +function handle_node!(b::JLIRBuilder, current::Int, + v::Vector{JLIR.Value}, stmt::Core.PhiNode, + type::Type, loc::JLIR.Location) + edges = stmt.edges + values = stmt.values + found = false + for (v, e) in zip(edges, values) + if e == current + val = emit_value(b, loc, v) + push!(v, maybe_widen_type(b, loc, val, type)) + found = true + end + end + if !found + op = create!(b, UndefOp(), loc, convert_type_to_jlirtype(b, type)) + push!(v, JLIR.get_result(op, 0)) + end end -function walk_cfg_emit_branchargs(builder, - cfg::Core.Compiler.CFG, current_block::Int, - target_block::Int, stmts, types, loc::JLIR.Location) +function walk_cfg_emit_branchargs(b::JLIRBuilder, current::Int, + target::Int, loc::JLIR.Location) v = JLIR.Value[] - for stmt in cfg.blocks[target].stmts - handle_node!(builder, v, stmt, loc) + cfg = get_cfg(b) + for ind in cfg.blocks[target].stmts + stmt = get_stmt(b, ind) + stmt isa Core.PhiNode || break + type = get_type(b, ind) + handle_node!(b, v, current, stmt, type, loc) end return v end -function emit_jlir(builder::JLIRBuilder, - ir_code::Core.Compiler.IRCode, ret::Type, name::String) +function emit_op!(b::JLIRBuilder, code::Core.Compiler.IRCode, + stmt::Core.GotoIfNot, loc::JLIR.Location, ret::Type) + label = stmt.label + v = walk_cfg_emit_branchargs(b, b.insertion, label, loc) + create!(b, GotoOp(), loc, b.blocks[label], v) + return true +end - # Setup. - irstream = ir_code.stmts - location_indices = getfield(irstream, :line) - linetable = getfield(ir_code, :linetable) - locations = extract_linetable_meta(builder, linetable) - argtypes = getfield(ir_code, :argtypes) - args = [convert_type_to_mlir(builder, a) for a in argtypes] - state = JLIR.create_operation_state(name, locations[1]) - entry_blk, reg = JLIR.add_entry_block!(state, args) - cfg = ir_code.cfg - cfg_blocks = cfg.blocks - nblocks = length(cfg_blocks) - blocks = JLIR.Block[JLIR.push_new_block!(reg) for _ in 1 : nblocks] - pushfirst!(blocks, entry_blk) - builder.blocks = blocks - stmts = irstream.inst - types = irstream.type - v = walk_cfg_emit_branchargs(builder, cfg, 1, 2, stmts, types, locations[0]) - goto = create_goto_op(Location(builder.ctx), blocks[2], v) - push!(builder, goto) - set_insertion!(builder, 2) +function emit_op!(b::JLIRBuilder, stmt::Core.ReturnNode, + loc::JLIR.Location, ret::Type) + if isdefined(stmt, :val) + value = maybe_widen_type(b, loc, emit_value(b, loc, stmt.val), ret) + else + value = create!(b, UndefOp(), loc) + end + create!(b, ReturnOp(), loc, value) + return true +end +function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) + + # Create builder. + b = JLIRBuilder(ir_code, name) + stmts = get_stmts(b) + types = get_types(b) + # Process. for (ind, (stmt, type)) in enumerate(zip(stmts, types)) - loc = linetable[ind] == 0 ? JLIR.Location() : locations[ind] + @assert(b.insertion <= nblocks) + lt_ind = location_indices[ind] + loc = lt_ind == 0 ? JLIR.Location() : locations[lt_ind] is_terminator = false - process_node!(builder, stmt, loc) + is_terminator = emit_op!(b, stmt, loc, ret) end # Create op from state and verify. - op = JLIR.Operation(state) + op = finish(b) @assert(JLIR.verify(op)) return op end - -function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) - b = JLIRBuilder() - return create_func_op(b, ir_code, ret, name) -end diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index 42257b1..e0201e0 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -1,67 +1,67 @@ # TODO: In future, should be autogenerated from tablegen. -function create_unimplemented_op(loc::JLIR.Location, type) - state = JLIR.create_operation_state("jlir::unimplemented", loc) - JLIR.push_results!(state, 1, type) +function create_unimplemented_op(loc::JLIR.Location, type::JLIR.Type) + state = JLIR.create_operation_state("jlir.unimplemented", loc) + JLIR.push_results!(state, type) return JLIR.Operation(state) end -function create_constant_op(loc::JLIR.Location, value, type) - state = JLIR.create_operation_state("jlir::constant", loc) - JLIR.push_operands!(state, 1, value) - JLIR.push_results!(state, 1, type) +function create_constant_op(loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) + state = JLIR.create_operation_state("jlir.constant", loc) + JLIR.push_operands!(state, value) + JLIR.push_results!(state, type) return JLIR.Operation(state) end function create_goto_op(loc::JLIR.Location, blk::JLIR.Block, v::Vector{JLIR.Value}) - state = JLIR.create_operation_state("jlir::goto", loc) + state = JLIR.create_operation_state("jlir.goto", loc) JLIR.push_operands!(state, v) - JLIR.push_successors!(state, JLIR.Block[blk]) + JLIR.push_successors!(state, blk) return JLIR.Operation(state) end function create_gotoifnot_op(loc::JLIR.Location, cond::JLIR.Value, dest::JLIR.Block, v::Vector{JLIR.Value}, fall::JLIR.Block, fallv::Vector{JLIR.Value}) - state = JLIR.create_operation_state("jlir::gotoifnot", loc) - JLIR.push_operands!(state, JLIR.Value[cond]) + state = JLIR.create_operation_state("jlir.gotoifnot", loc) + JLIR.push_operands!(state, cond) JLIR.push_operands!(state, v) JLIR.push_operands!(state, fallv) - JLIR.push_successors!(state, JLIR.Block[blk]) - JLIR.push_successors!(state, JLIR.Block[fall]) + JLIR.push_successors!(state, blk) + JLIR.push_successors!(state, fall) return JLIR.Operation(state) end function create_pi_op(loc::JLIR.Location, input::JLIR.Type, type::JLIR.Type) - state = JLIR.create_operation_state("jlir::pi", loc) - JLIR.push_operands!(state, JLIR.Type[value]) - JLIR.push_results!(state, JLIR.Type[type]) + state = JLIR.create_operation_state("jlir.pi", loc) + JLIR.push_operands!(state, value) + JLIR.push_results!(state, type) return JLIR.Operation(state) end function create_return_op(loc::JLIR.Location, input::JLIR.Type) - state = JLIR.create_operation_state("jlir::return", loc) - JLIR.push_operands!(state, JLIR.Type[input]) + state = JLIR.create_operation_state("jlir.return", loc) + JLIR.push_operands!(state, input) return JLIR.Operation(state) end function create_call_op(loc::JLIR.Location, callee::JLIR.Type, arguments::Vector{JLIR.Type}, type::JLIR.Type) - state = JLIR.create_operation_state("jlir::call", loc) - operands = [callee, arguments...] + state = JLIR.create_operation_state("jlir.call", loc) + operands = JLIR.Value[callee, arguments...] JLIR.push_operands!(state, operands) - JLIR.push_results!(state, JLIR.Value[type]) + JLIR.push_results!(state, type) return JLIR.Operation(state) end function create_invoke_op(loc::JLIR.Location, mi::JLIR.Value, callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) - state = JLIR.create_operation_state("jlir::invoke", loc) + state = JLIR.create_operation_state("jlir.invoke", loc) JLIR.push_operands!(state, JLIR.Value[mi, callee, arguments...]) - JLIR.push_results!(state, JLIR.Value[type]) + JLIR.push_results!(state, type) return JLIR.Operation(state) end @@ -82,9 +82,9 @@ function create!(b::JLIRBuilder, ::UnimplementedOp, loc::JLIR.Location, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_unimplemented_op(loc, type) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -92,9 +92,9 @@ function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_constant_op(loc, value, type) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -102,9 +102,9 @@ function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, blk::JLIR.Block, v::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) op = create_goto_op(loc, blk, v) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + JLIR.push_operation!(blk, op) return op end @@ -113,9 +113,9 @@ function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, fall::JLIR.Block, fallv::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -123,9 +123,9 @@ function create!(b::JLIRBuilder, ::PiOp, loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_pi_op(loc, value, type) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -133,9 +133,9 @@ function create!(b::JLIRBuilder, ::ReturnOp, loc::JLIR.Location, input::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_return_op(loc, input) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -144,9 +144,9 @@ function create!(b::JLIRBuilder, ::CallOp, loc::JLIR.Location, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_call_op(loc, callee, arguments, type) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end @@ -156,8 +156,8 @@ function create!(b::JLIRBuilder, ::InvokeOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) jlir_mi = convert_value_to_jlirattr(b, mi) op = create_invoke_op(loc, jlir_mi, callee, arguments, type) - JLIR.verify(op) + @assert(JLIR.verify(op)) blk = getindex(b.blocks, b.insertion) - push!(blk, op) + push_operation!(blk, op) return op end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 4681c26..e8187b9 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -6,28 +6,75 @@ mutable struct JLIRBuilder ctx::JLIR.Context - values::Vector{JLIR.Value} - arguments::Vector{JLIR.Value} insertion::Int + state::JLIR.OperationState + values::Vector{JLIR.Value} + arguments::Vector{JLIR.Type} + locations::Vector{JLIR.Location} blocks::Vector{JLIR.Block} + code::Core.Compiler.IRCode function JLIRBuilder() ctx = JLIR.create_context() ccall((:brutus_register_dialects, "libbrutus"), Cvoid, (JLIR.Context, ), ctx) - new(ctx, JLIR.Value[], JLIR.Value[], 1) + new(ctx, 1) + end +end + +function JLIRBuilder(code::Core.Compiler.IRCode, name::String) + b = JLIRBuilder() + irstream = code.stmts + stmts = irstream.inst + types = irstream.type + location_indices = getfield(irstream, :line) + linetable = getfield(code, :linetable) + locations = extract_linetable_meta(b, linetable) + argtypes = getfield(code, :argtypes) + args = [convert_type_to_jlirtype(b, a) for a in argtypes] + state = JLIR.create_operation_state(name, locations[1]) + entry_blk, reg = JLIR.add_entry_block!(state, args) + tr = JLIR.get_first_block(reg) + nblocks = length(code.cfg.blocks) + blocks = JLIR.Block[entry_blk] + for i in 1 : nblocks + blk = JLIR.Block() + JLIR.insertafter!(reg, entry_blk, blk) + push!(blocks, blk) end + b.state = state + b.arguments = args + b.locations = locations + b.blocks = blocks + b.state = state + b.code = code + b.arguments = args + v = walk_cfg_emit_branchargs(b, 1, 2, locations[1]) + goto = create!(b, GotoOp(), JLIR.Location(b.ctx), blocks[2], v) + set_insertion!(b, 2) + return b end set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion = blk +get_stmts(b::JLIRBuilder) = b.code.stmts.inst +get_types(b::JLIRBuilder) = b.code.stmts.type +get_stmt(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.inst, ind) +get_type(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.type, ind) +function get_cfg(b::JLIRBuilder) + @assert(isdefined(b, :code)) + return b.code.cfg +end + function push!(b::JLIRBuilder, op::JLIR.Operation) @assert(isdefined(b, :blocks)) blk = b.blocks[b.insertion] push_operation!(blk, op) end +finish(b::JLIRBuilder) = JLIR.Operation(b.state) + ##### ##### Utilities ##### diff --git a/Brutus/src/reflection.jl b/Brutus/src/reflection.jl index 17870df..b311993 100644 --- a/Brutus/src/reflection.jl +++ b/Brutus/src/reflection.jl @@ -4,7 +4,6 @@ function get_methodinstance(@nospecialize(sig); ms = Base._methods_by_ftype(sig, 1, Base.get_world_counter()) @assert length(ms) == 1 m = ms[1] - display(m) mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), m[3], m[1], m[2]) diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index ab8fac5..e7e6ce6 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -645,11 +645,11 @@ extern "C" enum DumpOption { - // DUMP_IRCODE = 0, - DUMP_TRANSLATED = 1, - DUMP_CANONICALIZED = 2, - DUMP_LOWERED_TO_STD = 4, - DUMP_LOWERED_TO_LLVM = 8, + // DUMP_IRCODE = 0, + DUMP_TRANSLATED = 1, + DUMP_CANONICALIZED = 2, + DUMP_LOWERED_TO_STD = 4, + DUMP_LOWERED_TO_LLVM = 8, DUMP_TRANSLATE_TO_LLVM = 16, }; From 0132e2a8ae253d96a14c33d5687ad69448b892b0 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Wed, 21 Apr 2021 23:17:28 -0400 Subject: [PATCH 05/13] Running into Goto verify issues. --- Brutus/src/compiler/codegen.jl | 22 ++++++++++++---------- Brutus/src/compiler/jlirgen.jl | 30 +++++++++++++++++++++++------- Brutus/src/compiler/opbuilder.jl | 9 +++------ 3 files changed, 38 insertions(+), 23 deletions(-) diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 080f4e8..6072438 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -4,18 +4,21 @@ # This is the Julia interface between Julia's IRCode and JLIR. -function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::GlobalRef, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, + value::GlobalRef, type) name = value.name v = getproperty(value.mod, value.name) return create_constant_op(b, loc, v, type) end -function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::Core.SSAValue, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, + value::Core.SSAValue, type) @assert(value.id >= 1) return getindex(b.values, value.id) end -function emit_value(b::JLIRBuilder, loc::JLIR.Location, value, type) +function emit_value(b::JLIRBuilder, loc::JLIR.Location, + value, type) return create_unimplemented_op(b, loc, type) end @@ -86,13 +89,12 @@ function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) types = get_types(b) # Process. - for (ind, (stmt, type)) in enumerate(zip(stmts, types)) - @assert(b.insertion <= nblocks) - lt_ind = location_indices[ind] - loc = lt_ind == 0 ? JLIR.Location() : locations[lt_ind] - is_terminator = false - is_terminator = emit_op!(b, stmt, loc, ret) - end + #for (ind, (stmt, type)) in enumerate(zip(stmts, types)) + # lt_ind = location_indices[ind] + # loc = lt_ind == 0 ? JLIR.Location() : locations[lt_ind] + # is_terminator = false + # is_terminator = emit_op!(b, stmt, loc, ret) + #end # Create op from state and verify. op = finish(b) diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index e0201e0..1a4fc42 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -6,6 +6,11 @@ function create_unimplemented_op(loc::JLIR.Location, type::JLIR.Type) return JLIR.Operation(state) end +function create_undef_op(loc::JLIR.Location, type::JLIR.Type) + state = JLIR.create_operation_state("jlir.undef", loc) + return JLIR.Operation(state) +end + function create_constant_op(loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) state = JLIR.create_operation_state("jlir.constant", loc) JLIR.push_operands!(state, value) @@ -13,11 +18,11 @@ function create_constant_op(loc::JLIR.Location, value::JLIR.Value, type::JLIR.Ty return JLIR.Operation(state) end -function create_goto_op(loc::JLIR.Location, blk::JLIR.Block, - v::Vector{JLIR.Value}) +function create_goto_op(loc::JLIR.Location, from::JLIR.Block, + to::JLIR.Block, v::Vector{JLIR.Value}) state = JLIR.create_operation_state("jlir.goto", loc) JLIR.push_operands!(state, v) - JLIR.push_successors!(state, blk) + JLIR.push_successors!(state, to) return JLIR.Operation(state) end @@ -70,6 +75,7 @@ end ##### struct UnimplementedOp end +struct UndefOp end struct ConstantOp end struct GotoOp end struct GotoIfNotOp end @@ -88,6 +94,15 @@ function create!(b::JLIRBuilder, ::UnimplementedOp, loc::JLIR.Location, return op end +function create!(b::JLIRBuilder, ::UndefOp, loc::JLIR.Location) + @assert(isdefined(b, :blocks)) + op = create_undef_op(loc, type) + @assert(JLIR.verify(op)) + blk = getindex(b.blocks, b.insertion) + push_operation!(blk, op) + return op +end + function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) @assert(isdefined(b, :blocks)) @@ -99,12 +114,13 @@ function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, end function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, - blk::JLIR.Block, v::Vector{JLIR.Value}) + to::JLIR.Block, v::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) - op = create_goto_op(loc, blk, v) + from = get_insertion_block(b) + op = create_goto_op(loc, from, to, v) + JLIR.dump(op) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) - JLIR.push_operation!(blk, op) + JLIR.push_operation!(from, op) return op end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index e8187b9..1e4d007 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -56,19 +56,16 @@ function JLIRBuilder(code::Core.Compiler.IRCode, name::String) return b end -set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion = blk +set_insertion!(b::JLIRBuilder, blk::Int) = setfield!(b, :insertion, blk) +get_insertion_block(b::JLIRBuilder) = b.blocks[b.insertion] get_stmts(b::JLIRBuilder) = b.code.stmts.inst get_types(b::JLIRBuilder) = b.code.stmts.type get_stmt(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.inst, ind) get_type(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.type, ind) -function get_cfg(b::JLIRBuilder) - @assert(isdefined(b, :code)) - return b.code.cfg -end +get_cfg(b::JLIRBuilder) = b.code.cfg function push!(b::JLIRBuilder, op::JLIR.Operation) - @assert(isdefined(b, :blocks)) blk = b.blocks[b.insertion] push_operation!(blk, op) end From bf0197692ab47237ee89a69685bf738677109e2a Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Thu, 22 Apr 2021 17:18:30 -0400 Subject: [PATCH 06/13] Switched to immutable builder. --- Brutus/scratch/juliacodegen.jl | 8 ++--- Brutus/src/compiler/codegen.jl | 17 +++++++--- Brutus/src/compiler/jlirgen.jl | 21 +++++------- Brutus/src/compiler/opbuilder.jl | 56 ++++++++++++-------------------- 4 files changed, 45 insertions(+), 57 deletions(-) diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 3d91620..0a8ab37 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -4,11 +4,11 @@ using Brutus using MLIR function gauss(N) - acc = 0 - for i in 1:N - acc += i + k = 0 + for i in 1 : N + k += i end - return acc + return k end mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 6072438..6daae6b 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -82,13 +82,20 @@ function emit_op!(b::JLIRBuilder, stmt::Core.ReturnNode, end function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) - + GC.enable(false) + # Create builder. b = JLIRBuilder(ir_code, name) - stmts = get_stmts(b) - types = get_types(b) + + # Create branch from entry block. + v = walk_cfg_emit_branchargs(b, 1, 2, b.locations[1]) + goto = create_goto_op(JLIR.Location(b.ctx), b.blocks[1], b.blocks[2], v) + push!(b.blocks[1], goto) + JLIR.dump(goto) # Process. + #stmts = get_stmts(b) + #types = get_types(b) #for (ind, (stmt, type)) in enumerate(zip(stmts, types)) # lt_ind = location_indices[ind] # loc = lt_ind == 0 ? JLIR.Location() : locations[lt_ind] @@ -97,7 +104,7 @@ function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) #end # Create op from state and verify. - op = finish(b) - @assert(JLIR.verify(op)) + op = JLIR.Operation(b.state) + GC.enable(true) return op end diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index 1a4fc42..b009b0b 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -20,7 +20,7 @@ end function create_goto_op(loc::JLIR.Location, from::JLIR.Block, to::JLIR.Block, v::Vector{JLIR.Value}) - state = JLIR.create_operation_state("jlir.goto", loc) + state = JLIR.create_operation_state("std.br", loc) JLIR.push_operands!(state, v) JLIR.push_successors!(state, to) return JLIR.Operation(state) @@ -89,7 +89,7 @@ function create!(b::JLIRBuilder, ::UnimplementedOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) op = create_unimplemented_op(loc, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -98,7 +98,7 @@ function create!(b::JLIRBuilder, ::UndefOp, loc::JLIR.Location) @assert(isdefined(b, :blocks)) op = create_undef_op(loc, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -108,7 +108,7 @@ function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) op = create_constant_op(loc, value, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -118,8 +118,6 @@ function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) from = get_insertion_block(b) op = create_goto_op(loc, from, to, v) - JLIR.dump(op) - @assert(JLIR.verify(op)) JLIR.push_operation!(from, op) return op end @@ -129,8 +127,7 @@ function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, fall::JLIR.Block, fallv::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) - @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -140,7 +137,7 @@ function create!(b::JLIRBuilder, ::PiOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) op = create_pi_op(loc, value, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -150,7 +147,7 @@ function create!(b::JLIRBuilder, ::ReturnOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) op = create_return_op(loc, input) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -161,7 +158,7 @@ function create!(b::JLIRBuilder, ::CallOp, loc::JLIR.Location, @assert(isdefined(b, :blocks)) op = create_call_op(loc, callee, arguments, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end @@ -173,7 +170,7 @@ function create!(b::JLIRBuilder, ::InvokeOp, loc::JLIR.Location, jlir_mi = convert_value_to_jlirattr(b, mi) op = create_invoke_op(loc, jlir_mi, callee, arguments, type) @assert(JLIR.verify(op)) - blk = getindex(b.blocks, b.insertion) + blk = get_insertion_block(b) push_operation!(blk, op) return op end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 1e4d007..37f5ebf 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -4,35 +4,31 @@ # High-level version of MLIR's OpBuilder. -mutable struct JLIRBuilder +struct JLIRBuilder ctx::JLIR.Context - insertion::Int - state::JLIR.OperationState + insertion::Ref{Int} values::Vector{JLIR.Value} arguments::Vector{JLIR.Type} locations::Vector{JLIR.Location} blocks::Vector{JLIR.Block} code::Core.Compiler.IRCode - function JLIRBuilder() - ctx = JLIR.create_context() - ccall((:brutus_register_dialects, "libbrutus"), - Cvoid, - (JLIR.Context, ), - ctx) - new(ctx, 1) - end + state::JLIR.OperationState end function JLIRBuilder(code::Core.Compiler.IRCode, name::String) - b = JLIRBuilder() + ctx = JLIR.create_context() + ccall((:brutus_register_dialects, "libbrutus"), + Cvoid, + (JLIR.Context, ), + ctx) irstream = code.stmts stmts = irstream.inst types = irstream.type location_indices = getfield(irstream, :line) linetable = getfield(code, :linetable) - locations = extract_linetable_meta(b, linetable) + locations = extract_linetable_locations(ctx, linetable) argtypes = getfield(code, :argtypes) - args = [convert_type_to_jlirtype(b, a) for a in argtypes] + args = [convert_type_to_jlirtype(ctx, a) for a in argtypes] state = JLIR.create_operation_state(name, locations[1]) entry_blk, reg = JLIR.add_entry_block!(state, args) tr = JLIR.get_first_block(reg) @@ -40,24 +36,14 @@ function JLIRBuilder(code::Core.Compiler.IRCode, name::String) blocks = JLIR.Block[entry_blk] for i in 1 : nblocks blk = JLIR.Block() - JLIR.insertafter!(reg, entry_blk, blk) + JLIR.push!(reg, blk) push!(blocks, blk) end - b.state = state - b.arguments = args - b.locations = locations - b.blocks = blocks - b.state = state - b.code = code - b.arguments = args - v = walk_cfg_emit_branchargs(b, 1, 2, locations[1]) - goto = create!(b, GotoOp(), JLIR.Location(b.ctx), blocks[2], v) - set_insertion!(b, 2) - return b + return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, code, state) end -set_insertion!(b::JLIRBuilder, blk::Int) = setfield!(b, :insertion, blk) -get_insertion_block(b::JLIRBuilder) = b.blocks[b.insertion] +set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion[] = blk +get_insertion_block(b::JLIRBuilder) = b.blocks[b.insertion[]] get_stmts(b::JLIRBuilder) = b.code.stmts.inst get_types(b::JLIRBuilder) = b.code.stmts.type @@ -76,16 +62,14 @@ finish(b::JLIRBuilder) = JLIR.Operation(b.state) ##### Utilities ##### -function convert_type_to_jlirtype(builder::JLIRBuilder, a) - ctx = builder.ctx +function convert_type_to_jlirtype(ctx::JLIR.Context, a) return ccall((:brutus_get_jlirtype, "libbrutus"), JLIR.Type, (JLIR.Context, Any), ctx, a) end -function convert_value_to_jlirattr(builder::JLIRBuilder, a) - ctx = builder.ctx +function convert_value_to_jlirattr(ctx::JLIR.Context, a) return ccall((:brutus_get_jlirattr, "libbrutus"), JLIR.Attribute, (JLIR.Context, Any), @@ -98,8 +82,8 @@ end function get_functype(builder::JLIRBuilder, args, ret) return get_functype(builder, length(args), map(args) do a - convert_type_to_jlirtype(builder, a) - end, 1, [convert_type_to_jlirtype(builder, ret)]) + convert_type_to_jlirtype(builder.ctx, a) + end, 1, [convert_type_to_jlirtype(builder.ctx, ret)]) end function unwrap(mi::Core.MethodInstance) @@ -107,7 +91,7 @@ function unwrap(mi::Core.MethodInstance) end unwrap(s) = s -function extract_linetable_meta(builder::JLIRBuilder, v::Vector{Core.LineInfoNode}) +function extract_linetable_locations(ctx::JLIR.Context, v::Vector{Core.LineInfoNode}) locations = JLIR.Location[] for n in v method = unwrap(n.method) @@ -120,7 +104,7 @@ function extract_linetable_meta(builder::JLIRBuilder, v::Vector{Core.LineInfoNod if method isa Symbol fname = String(method) end - current = JLIR.Location(builder.ctx, fname, UInt32(line), UInt32(0)) # TODO: col. + current = JLIR.Location(ctx, fname, UInt32(line), UInt32(0)) # TODO: col. if inlined_at > 0 current = JLIR.Location(current, locations[inlined_at - 1]) end From 0569afc95313b621ba54cd4ff63f145a1267d6c8 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Fri, 23 Apr 2021 18:12:49 -0400 Subject: [PATCH 07/13] Updated pipeline -- it's actually almost working. There's a few things left to verify correctness with simple tests. --- Brutus/scratch/juliacodegen.jl | 26 +++-- Brutus/src/compiler/codegen.jl | 184 +++++++++++++++++++++++-------- Brutus/src/compiler/jlirgen.jl | 82 +++++++------- Brutus/src/compiler/opbuilder.jl | 36 ++++-- include/brutus/brutus.h | 8 +- lib/Codegen/Codegen.cpp | 15 ++- 6 files changed, 246 insertions(+), 105 deletions(-) diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 0a8ab37..3bfc89d 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -3,16 +3,22 @@ module JuliaCodegen using Brutus using MLIR -function gauss(N) - k = 0 - for i in 1 : N - k += i - end - return k -end +brutus_id(N) = N -mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) -ir_code, ret = Brutus.code_ircode(mi) -mod = Brutus.Compiler.emit_jlir(ir_code, ret, "gauss") +mi = Brutus.get_methodinstance(Tuple{typeof(brutus_id), Int}) +ir_code, rt = Brutus.code_ircode(mi) +mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_id") + +#function gauss(N) +# k = 0 +# for i in 1 : N +# k += i +# end +# return k +#end +# +#mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) +#ir_code, rt = Brutus.code_ircode(mi) +#mod = Brutus.Compiler.emit_jlir(ir_code, rt, "gauss") end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 6daae6b..0837ebd 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -4,30 +4,53 @@ # This is the Julia interface between Julia's IRCode and JLIR. +function maybe_widen_type(b::JLIRBuilder, loc::JLIR.Location, + value::JLIR.Value, expected_type::Type) + type = convert_jlirvalue_to_type(value) + if (type != expected_type && type <: expected_type) + op = create!(b, PiOp(), loc, value, expected_type) + return JLIR.get_result(op, 0) + else + return value + end +end + function emit_value(b::JLIRBuilder, loc::JLIR.Location, - value::GlobalRef, type) - name = value.name - v = getproperty(value.mod, value.name) - return create_constant_op(b, loc, v, type) + value, type::Type) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + op = create!(b, UnimplementedOp(), loc, jlir_type) + return JLIR.get_result(op, 0) +end + +function emit_value(b::JLIRBuilder, loc::JLIR.Location, + value::Core.Argument, type::Type) + idx = value.n + return b.arguments[idx]; end function emit_value(b::JLIRBuilder, loc::JLIR.Location, - value::Core.SSAValue, type) + value::Core.SSAValue, type::Type) @assert(value.id >= 1) return getindex(b.values, value.id) end function emit_value(b::JLIRBuilder, loc::JLIR.Location, - value, type) - return create_unimplemented_op(b, loc, type) + value::GlobalRef, type::Type) + name = value.name + v = getproperty(value.mod, value.name) + jlir_attr = convert_value_to_jlirattr(b.ctx, v) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + op = create!(b, ConstantOp(), loc, jlir_attr, jlir_type) + return JLIR.get_result(op, 0) end -function emit_ftype(b::JLIRBuilder, ret_type) - argtypes = getfield(ir_code, :argtypes) +function emit_ftype(ctx::JLIR.Context, code::Core.Compiler.IRCode, ret_type::Type) + argtypes = getfield(code, :argtypes) nargs = length(argtypes) - args = [convert_type_to_jlirtype(b, a) for a in argtypes] - ret = convert_type_to_jlirtype(b, ret_type) - return get_functype(b, args, ret) + args = [convert_type_to_jlirtype(ctx, a) for a in argtypes] + ret = convert_type_to_jlirtype(ctx, ret_type) + jlir_func_type = get_functype(ctx, args, ret) + return jlir_func_type end function handle_node!(b::JLIRBuilder, current::Int, @@ -38,13 +61,14 @@ function handle_node!(b::JLIRBuilder, current::Int, found = false for (v, e) in zip(edges, values) if e == current - val = emit_value(b, loc, v) + val = emit_value(b, loc, v, Any) push!(v, maybe_widen_type(b, loc, val, type)) found = true end end if !found - op = create!(b, UndefOp(), loc, convert_type_to_jlirtype(b, type)) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + op = create!(b, UndefOp(), loc, jlir_type) push!(v, JLIR.get_result(op, 0)) end end @@ -53,58 +77,132 @@ function walk_cfg_emit_branchargs(b::JLIRBuilder, current::Int, target::Int, loc::JLIR.Location) v = JLIR.Value[] cfg = get_cfg(b) - for ind in cfg.blocks[target].stmts - stmt = get_stmt(b, ind) - stmt isa Core.PhiNode || break + for ind in cfg.blocks[target - 1].stmts + node = get_stmt(b, ind) + node isa Core.PhiNode || break type = get_type(b, ind) - handle_node!(b, v, current, stmt, type, loc) + handle_node!(b, current, v, node, type, loc) end return v end -function emit_op!(b::JLIRBuilder, code::Core.Compiler.IRCode, - stmt::Core.GotoIfNot, loc::JLIR.Location, ret::Type) +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Nothing, loc::JLIR.Location, type::Type) + return false +end + +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt, loc::JLIR.Location, type::Type) + push!(b.values, emit_value(b, loc, stmt, type)) + return false +end + +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Core.GotoNode, loc::JLIR.Location, type::Type) label = stmt.label - v = walk_cfg_emit_branchargs(b, b.insertion, label, loc) + v = walk_cfg_emit_branchargs(b, b.insertion[], label, loc) create!(b, GotoOp(), loc, b.blocks[label], v) return true end -function emit_op!(b::JLIRBuilder, stmt::Core.ReturnNode, - loc::JLIR.Location, ret::Type) +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Core.GotoIfNot, loc::JLIR.Location, type::Type) + cond = emit_value(b, loc, stmt.cond, Any) + dest = stmt.dest + op = create!(b, GotoIfNotOp(), loc, + cond, b.blocks[dest], + walk_cfg_emit_branchargs(b, b.insertion[], dest, loc), + b.blocks[b.insertion[] + 1], + walk_cfg_emit_branchargs(b, b.insertion[], + b.insertion[] + 1, loc)) + return true +end + +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Core.PhiNode, loc::JLIR.Location, type::Type) + t = convert_type_to_jlirtype(b.ctx, type) + blk = get_insertion_block(b) + arg = ccall((:brutusBlockAddArgument, "libbrutus"), + JLIR.Value, + (JLIR.Block, JLIR.Type), + blk, t) + push!(b.values, arg) + return false +end + +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Core.PiNode, loc::JLIR.Location, type::Type) + val = stmt.val + @assert(type == stmt.type) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + op = create!(b, PiOp(), loc, emit_value(b, loc, val, Any), jlir_type) + ctx.values[ind] = JLIR.get_result(op, 0) + return false +end + +function process_stmt!(b::JLIRBuilder, ind::Int, + stmt::Core.ReturnNode, loc::JLIR.Location, type::Type) if isdefined(stmt, :val) - value = maybe_widen_type(b, loc, emit_value(b, loc, stmt.val), ret) + v = emit_value(b, loc, stmt.val, Any) + value = maybe_widen_type(b, loc, v, type) else - value = create!(b, UndefOp(), loc) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + value = create!(b, UndefOp(), loc, jlir_type) end create!(b, ReturnOp(), loc, value) return true end -function emit_jlir(ir_code::Core.Compiler.IRCode, ret::Type, name::String) - GC.enable(false) - +function process_stmt!(b::JLIRBuilder, ind::Int, + expr::Expr, loc::JLIR.Location, type::Type) + head = expr.head + args = expr.args + jlir_type = convert_type_to_jlirtype(b.ctx, type) + if head == :invoke + @assert(args[1] isa Core.MethodInstance) + mi = args[1] + callee = emit_value(b, loc, args[2], Any) + args = JLIR.Value[emit_value(b, loc, a, Any) for a in args[2 : end]] + op = create!(b, InvokeOp, loc, mi, callee, args, jlir_type) + elseif head == :call + callee = emit_value(b, loc, args[1], Any) + args = JLIR.Value[emit_value(b, loc, a, Any) for a in args] + op = create!(b, CallOp(), loc, callee, args, jlir_type) + else + display(head) + op = create!(b, UnimplementedOp(), loc, jlir_type) + end + res = JLIR.get_result(op, 0) + push!(b.values, res) + return false +end + +function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) # Create builder. - b = JLIRBuilder(ir_code, name) + b = JLIRBuilder(ir_code, rt, name) # Create branch from entry block. v = walk_cfg_emit_branchargs(b, 1, 2, b.locations[1]) - goto = create_goto_op(JLIR.Location(b.ctx), b.blocks[1], b.blocks[2], v) + goto = create_goto_op(JLIR.Location(b.ctx), b.blocks[2], v) push!(b.blocks[1], goto) - JLIR.dump(goto) - + # Process. - #stmts = get_stmts(b) - #types = get_types(b) - #for (ind, (stmt, type)) in enumerate(zip(stmts, types)) - # lt_ind = location_indices[ind] - # loc = lt_ind == 0 ? JLIR.Location() : locations[lt_ind] - # is_terminator = false - # is_terminator = emit_op!(b, stmt, loc, ret) - #end - + location_indices = get_locindices(b) + stmts = get_stmts(b) + types = get_types(b) + for (ind, (stmt, type)) in enumerate(zip(stmts, types)) + lt_ind = location_indices[ind] + loc = lt_ind == 0 ? JLIR.Location() : b.locations[lt_ind] + is_terminator = false + is_terminator = process_stmt!(b, ind, stmt, loc, type) + if is_terminator + b.insertion[] += 1 + end + end + # Create op from state and verify. - op = JLIR.Operation(b.state) - GC.enable(true) + op = finish(b) + JLIR.verify(op) + JLIR.dump(op) return op end diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index b009b0b..e031eee 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -8,19 +8,21 @@ end function create_undef_op(loc::JLIR.Location, type::JLIR.Type) state = JLIR.create_operation_state("jlir.undef", loc) + JLIR.push_results!(state, type) return JLIR.Operation(state) end -function create_constant_op(loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) +function create_constant_op(loc::JLIR.Location, named_attr::JLIR.NamedAttribute, + type::JLIR.Type) state = JLIR.create_operation_state("jlir.constant", loc) - JLIR.push_operands!(state, value) + JLIR.push_attributes!(state, named_attr) JLIR.push_results!(state, type) return JLIR.Operation(state) end -function create_goto_op(loc::JLIR.Location, from::JLIR.Block, - to::JLIR.Block, v::Vector{JLIR.Value}) - state = JLIR.create_operation_state("std.br", loc) +function create_goto_op(loc::JLIR.Location, to::JLIR.Block, + v::Vector{JLIR.Value}) + state = JLIR.create_operation_state("jlir.goto", loc) JLIR.push_operands!(state, v) JLIR.push_successors!(state, to) return JLIR.Operation(state) @@ -33,7 +35,7 @@ function create_gotoifnot_op(loc::JLIR.Location, cond::JLIR.Value, JLIR.push_operands!(state, cond) JLIR.push_operands!(state, v) JLIR.push_operands!(state, fallv) - JLIR.push_successors!(state, blk) + JLIR.push_successors!(state, dest) JLIR.push_successors!(state, fall) return JLIR.Operation(state) end @@ -46,14 +48,14 @@ function create_pi_op(loc::JLIR.Location, input::JLIR.Type, return JLIR.Operation(state) end -function create_return_op(loc::JLIR.Location, input::JLIR.Type) +function create_return_op(loc::JLIR.Location, input::JLIR.Value) state = JLIR.create_operation_state("jlir.return", loc) JLIR.push_operands!(state, input) return JLIR.Operation(state) end -function create_call_op(loc::JLIR.Location, callee::JLIR.Type, - arguments::Vector{JLIR.Type}, type::JLIR.Type) +function create_call_op(loc::JLIR.Location, callee::JLIR.Value, + arguments::Vector{JLIR.Value}, type::JLIR.Type) state = JLIR.create_operation_state("jlir.call", loc) operands = JLIR.Value[callee, arguments...] JLIR.push_operands!(state, operands) @@ -62,8 +64,7 @@ function create_call_op(loc::JLIR.Location, callee::JLIR.Type, end function create_invoke_op(loc::JLIR.Location, mi::JLIR.Value, - callee::JLIR.Value, arguments::Vector{JLIR.Value}, - type::JLIR.Type) + callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) state = JLIR.create_operation_state("jlir.invoke", loc) JLIR.push_operands!(state, JLIR.Value[mi, callee, arguments...]) JLIR.push_results!(state, type) @@ -84,93 +85,94 @@ struct ReturnOp end struct CallOp end struct InvokeOp end -function create!(b::JLIRBuilder, ::UnimplementedOp, loc::JLIR.Location, - type::JLIR.Type) +function create!(b::JLIRBuilder, ::UnimplementedOp, + loc::JLIR.Location, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_unimplemented_op(loc, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::UndefOp, loc::JLIR.Location) +function create!(b::JLIRBuilder, ::UndefOp, + loc::JLIR.Location, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_undef_op(loc, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::ConstantOp, loc::JLIR.Location, - value::JLIR.Value, type::JLIR.Type) +function create!(b::JLIRBuilder, ::ConstantOp, + loc::JLIR.Location, value::JLIR.Attribute, type::JLIR.Type) @assert(isdefined(b, :blocks)) - op = create_constant_op(loc, value, type) + named_attr = JLIR.NamedAttribute(b.ctx, "value", value) + op = create_constant_op(loc, named_attr, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::GotoOp, loc::JLIR.Location, - to::JLIR.Block, v::Vector{JLIR.Value}) +function create!(b::JLIRBuilder, ::GotoOp, + loc::JLIR.Location, to::JLIR.Block, v::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) from = get_insertion_block(b) - op = create_goto_op(loc, from, to, v) + op = create_goto_op(loc, to, v) JLIR.push_operation!(from, op) return op end -function create!(b::JLIRBuilder, ::GotoIfNotOp, loc::JLIR.Location, - cond::JLIR.Value, dest::JLIR.Block, v::Vector{JLIR.Value}, - fall::JLIR.Block, fallv::Vector{JLIR.Value}) +function create!(b::JLIRBuilder, ::GotoIfNotOp, + loc::JLIR.Location, cond::JLIR.Value, dest::JLIR.Block, + v::Vector{JLIR.Value}, fall::JLIR.Block, fallv::Vector{JLIR.Value}) @assert(isdefined(b, :blocks)) op = create_gotoifnot_op(loc, cond, dest, v, fall, fallv) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::PiOp, loc::JLIR.Location, - value::JLIR.Value, type::JLIR.Type) +function create!(b::JLIRBuilder, ::PiOp, + loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_pi_op(loc, value, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::ReturnOp, loc::JLIR.Location, - input::JLIR.Type) +function create!(b::JLIRBuilder, ::ReturnOp, + loc::JLIR.Location, input::JLIR.Value) @assert(isdefined(b, :blocks)) op = create_return_op(loc, input) - @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::CallOp, loc::JLIR.Location, - callee::JLIR.Value, arguments::Vector{JLIR.Value}, +function create!(b::JLIRBuilder, ::CallOp, + loc::JLIR.Location, callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) @assert(isdefined(b, :blocks)) op = create_call_op(loc, callee, arguments, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end -function create!(b::JLIRBuilder, ::InvokeOp, loc::JLIR.Location, - mi::Core.MethodInstance, callee::JLIR.Value, +function create!(b::JLIRBuilder, ::InvokeOp, + loc::JLIR.Location, mi::Core.MethodInstance, callee::JLIR.Value, arguments::Vector{JLIR.Value}, type::JLIR.Type) @assert(isdefined(b, :blocks)) jlir_mi = convert_value_to_jlirattr(b, mi) op = create_invoke_op(loc, jlir_mi, callee, arguments, type) @assert(JLIR.verify(op)) blk = get_insertion_block(b) - push_operation!(blk, op) + JLIR.push_operation!(blk, op) return op end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 37f5ebf..fd34483 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -11,11 +11,12 @@ struct JLIRBuilder arguments::Vector{JLIR.Type} locations::Vector{JLIR.Location} blocks::Vector{JLIR.Block} + reg::JLIR.Region code::Core.Compiler.IRCode state::JLIR.OperationState end -function JLIRBuilder(code::Core.Compiler.IRCode, name::String) +function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) ctx = JLIR.create_context() ccall((:brutus_register_dialects, "libbrutus"), Cvoid, @@ -29,7 +30,14 @@ function JLIRBuilder(code::Core.Compiler.IRCode, name::String) locations = extract_linetable_locations(ctx, linetable) argtypes = getfield(code, :argtypes) args = [convert_type_to_jlirtype(ctx, a) for a in argtypes] - state = JLIR.create_operation_state(name, locations[1]) + ftype = emit_ftype(ctx, code, rt) + state = JLIR.create_operation_state("func", locations[1]) + type_attr = JLIR.get_type_attribute(ftype) + named_type_attr = JLIR.NamedAttribute(ctx, "type", type_attr) + string_attr = JLIR.get_string_attribute(ctx, name) + symbol_name_attr = JLIR.NamedAttribute(ctx, "sym_name", string_attr) + JLIR.push_attributes!(state, named_type_attr) + JLIR.push_attributes!(state, symbol_name_attr) entry_blk, reg = JLIR.add_entry_block!(state, args) tr = JLIR.get_first_block(reg) nblocks = length(code.cfg.blocks) @@ -39,12 +47,13 @@ function JLIRBuilder(code::Core.Compiler.IRCode, name::String) JLIR.push!(reg, blk) push!(blocks, blk) end - return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, code, state) + return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, reg, code, state) end set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion[] = blk get_insertion_block(b::JLIRBuilder) = b.blocks[b.insertion[]] +get_locindices(b::JLIRBuilder) = b.code.stmts.line get_stmts(b::JLIRBuilder) = b.code.stmts.inst get_types(b::JLIRBuilder) = b.code.stmts.type get_stmt(b::JLIRBuilder, ind::Int) = getindex(b.code.stmts.inst, ind) @@ -62,6 +71,8 @@ finish(b::JLIRBuilder) = JLIR.Operation(b.state) ##### Utilities ##### +# Explicitly exposed as part of extern C in codegen.cpp. + function convert_type_to_jlirtype(ctx::JLIR.Context, a) return ccall((:brutus_get_jlirtype, "libbrutus"), JLIR.Type, @@ -76,14 +87,21 @@ function convert_value_to_jlirattr(ctx::JLIR.Context, a) ctx, a) end -function get_functype(builder::JLIRBuilder, args::Vector{JLIR.Type}, ret::JLIR.Type) - return MLIR.API.mlirFunctionTypeGet(builder.ctx, length(args), args, 1, [ret]) +function convert_jlirvalue_to_type(v::JLIR.Value) + return ccall((:brutus_get_julia_type, "libbrutus"), + Any, + (JLIR.Value, ), + v) +end + +function get_functype(ctx::JLIR.Context, args::Vector{JLIR.Type}, ret::JLIR.Type) + return MLIR.API.mlirFunctionTypeGet(ctx, length(args), args, 1, [ret]) end -function get_functype(builder::JLIRBuilder, args, ret) - return get_functype(builder, length(args), map(args) do a - convert_type_to_jlirtype(builder.ctx, a) - end, 1, [convert_type_to_jlirtype(builder.ctx, ret)]) +function get_functype(ctx::JLIR.Context, args, ret) + return get_functype(ctx, length(args), map(args) do a + convert_type_to_jlirtype(ctx, a) + end, 1, [convert_type_to_jlirtype(ctx, ret)]) end function unwrap(mi::Core.MethodInstance) diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index 8ccbaa2..fcda8d5 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -24,10 +24,14 @@ extern "C" { #endif - void brutus_register_dialects(MlirContext Context); - void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect); + void brutus_register_dialects(MlirContext context); MlirType brutus_get_jlirtype(MlirContext context, jl_datatype_t *datatype); + jl_value_t *brutus_get_julia_type(MlirValue v); MlirAttribute brutus_get_jlirattr(MlirContext context, jl_value_t *value); + + // TODO: deprecate -- should be available in MLIR C API. + void brutus_register_extern_dialect(MlirContext context, MlirDialect dialect); + MlirValue brutusBlockAddArgument(MlirBlock block, MlirType type); // Export C API for pipeline. typedef void (*ExecutionEngineFPtrResult)(void **); diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index e7e6ce6..b896ca4 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -468,6 +468,7 @@ mlir::FuncOp emit_function(jl_mlirctx_t &ctx, extern "C" { + // TODO: deprecate -- available in MLIR C API. void brutus_register_extern_dialect(MlirContext Context, MlirDialect Dialect) { return; @@ -488,7 +489,13 @@ extern "C" mlir::Type type = JuliaType::get(ctx, datatype); return wrap(type); }; - + + jl_value_t *brutus_get_julia_type(MlirValue v) { + mlir::Value value = unwrap(v); + jl_value_t *value_type = (jl_value_t *)value.getType().cast().getDatatype(); + return value_type; + } + MlirAttribute brutus_get_jlirattr(MlirContext Context, jl_value_t *value) { @@ -497,6 +504,12 @@ extern "C" return wrap(val); }; + // TODO: deprecate -- available in MLIR C API. + MlirValue brutusBlockAddArgument(MlirBlock block, MlirType type) + { + return wrap(unwrap(block)->addArgument(unwrap(type))); + } + void brutus_codegen_jlir(MlirContext Context, MlirModule Module, jl_value_t *methods, From 48a2e9c83a03c85645ce89ffe89de159e516c940 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sat, 24 Apr 2021 12:03:02 -0400 Subject: [PATCH 08/13] Julia codegen working for simple functions (identity). --- Brutus/src/compiler/codegen.jl | 24 ++++++++++++++---------- Brutus/src/compiler/jlirgen.jl | 4 ++-- Brutus/src/compiler/opbuilder.jl | 9 +++++---- include/brutus/brutus.h | 2 +- lib/Codegen/Codegen.cpp | 7 +++---- 5 files changed, 25 insertions(+), 21 deletions(-) diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 0837ebd..2c66936 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -5,13 +5,15 @@ # This is the Julia interface between Julia's IRCode and JLIR. function maybe_widen_type(b::JLIRBuilder, loc::JLIR.Location, - value::JLIR.Value, expected_type::Type) - type = convert_jlirvalue_to_type(value) + jlir_value::JLIR.Value, expected_type::Type) + jlir_type = JLIR.get_type(jlir_value) + type = convert_jlirtype_to_type(jlir_type) if (type != expected_type && type <: expected_type) - op = create!(b, PiOp(), loc, value, expected_type) + jlir_expected_type = convert_type_to_jlirtype(b.ctx, expected_type) + op = create!(b, PiOp(), loc, jlir_value, jlir_expected_type) return JLIR.get_result(op, 0) else - return value + return jlir_value end end @@ -25,7 +27,8 @@ end function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::Core.Argument, type::Type) idx = value.n - return b.arguments[idx]; + arg = JLIR.get_arg(b.blocks[1], idx - 1) + return arg end function emit_value(b::JLIRBuilder, loc::JLIR.Location, @@ -143,8 +146,8 @@ end function process_stmt!(b::JLIRBuilder, ind::Int, stmt::Core.ReturnNode, loc::JLIR.Location, type::Type) if isdefined(stmt, :val) - v = emit_value(b, loc, stmt.val, Any) - value = maybe_widen_type(b, loc, v, type) + jlir_v = emit_value(b, loc, stmt.val, Any) + value = maybe_widen_type(b, loc, jlir_v, b.rt) else jlir_type = convert_type_to_jlirtype(b.ctx, type) value = create!(b, UndefOp(), loc, jlir_type) @@ -169,7 +172,6 @@ function process_stmt!(b::JLIRBuilder, ind::Int, args = JLIR.Value[emit_value(b, loc, a, Any) for a in args] op = create!(b, CallOp(), loc, callee, args, jlir_type) else - display(head) op = create!(b, UnimplementedOp(), loc, jlir_type) end res = JLIR.get_result(op, 0) @@ -180,6 +182,7 @@ end function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) # Create builder. b = JLIRBuilder(ir_code, rt, name) + m = JLIR.Module(JLIR.Location(b.ctx)) # Create branch from entry block. v = walk_cfg_emit_branchargs(b, 1, 2, b.locations[1]) @@ -200,8 +203,9 @@ function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) end end - # Create op from state and verify. - op = finish(b) + # Create op from module and verify. + JLIR.push_operation!(m, finish(b)) + op = JLIR.get_operation(m) JLIR.verify(op) JLIR.dump(op) return op diff --git a/Brutus/src/compiler/jlirgen.jl b/Brutus/src/compiler/jlirgen.jl index e031eee..d9350f6 100644 --- a/Brutus/src/compiler/jlirgen.jl +++ b/Brutus/src/compiler/jlirgen.jl @@ -40,11 +40,11 @@ function create_gotoifnot_op(loc::JLIR.Location, cond::JLIR.Value, return JLIR.Operation(state) end -function create_pi_op(loc::JLIR.Location, input::JLIR.Type, +function create_pi_op(loc::JLIR.Location, value::JLIR.Value, type::JLIR.Type) state = JLIR.create_operation_state("jlir.pi", loc) - JLIR.push_operands!(state, value) JLIR.push_results!(state, type) + JLIR.push_operands!(state, value) return JLIR.Operation(state) end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index fd34483..ff0ba29 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -13,6 +13,7 @@ struct JLIRBuilder blocks::Vector{JLIR.Block} reg::JLIR.Region code::Core.Compiler.IRCode + rt::Type state::JLIR.OperationState end @@ -47,7 +48,7 @@ function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) JLIR.push!(reg, blk) push!(blocks, blk) end - return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, reg, code, state) + return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, reg, code, rt, state) end set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion[] = blk @@ -87,10 +88,10 @@ function convert_value_to_jlirattr(ctx::JLIR.Context, a) ctx, a) end -function convert_jlirvalue_to_type(v::JLIR.Value) +function convert_jlirtype_to_type(v::JLIR.Type) return ccall((:brutus_get_julia_type, "libbrutus"), - Any, - (JLIR.Value, ), + Type, + (JLIR.Type, ), v) end diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index fcda8d5..d6661ba 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -26,7 +26,7 @@ extern "C" #endif void brutus_register_dialects(MlirContext context); MlirType brutus_get_jlirtype(MlirContext context, jl_datatype_t *datatype); - jl_value_t *brutus_get_julia_type(MlirValue v); + jl_datatype_t *brutus_get_julia_type(MlirType v); MlirAttribute brutus_get_jlirattr(MlirContext context, jl_value_t *value); // TODO: deprecate -- should be available in MLIR C API. diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index b896ca4..722e032 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -490,10 +490,9 @@ extern "C" return wrap(type); }; - jl_value_t *brutus_get_julia_type(MlirValue v) { - mlir::Value value = unwrap(v); - jl_value_t *value_type = (jl_value_t *)value.getType().cast().getDatatype(); - return value_type; + jl_datatype_t *brutus_get_julia_type(MlirType v) { + mlir::Type type = unwrap(v); + return (jl_datatype_t *)type.cast().getDatatype(); } MlirAttribute brutus_get_jlirattr(MlirContext Context, From a228ff7b5ac37ad054044c6bde628c15bdbc684e Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sat, 24 Apr 2021 23:03:05 -0400 Subject: [PATCH 09/13] Fixed value indexing on builder. --- Brutus/scratch/juliacodegen.jl | 18 +++++++++++++++++- Brutus/src/compiler/codegen.jl | 15 ++++++++------- Brutus/src/compiler/opbuilder.jl | 6 +++--- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 3bfc89d..70ab9cd 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -3,11 +3,25 @@ module JuliaCodegen using Brutus using MLIR -brutus_id(N) = N +function brutus_id(N) + return N +end mi = Brutus.get_methodinstance(Tuple{typeof(brutus_id), Int}) ir_code, rt = Brutus.code_ircode(mi) +display(ir_code) mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_id") +MLIR.IR.dump(mod) + +function switch(N) + N > 10 ? 5 : 10 +end + +mi = Brutus.get_methodinstance(Tuple{typeof(switch), Int}) +ir_code, rt = Brutus.code_ircode(mi) +display(ir_code) +mod = Brutus.Compiler.emit_jlir(ir_code, rt, "switch") +MLIR.IR.dump(mod) #function gauss(N) # k = 0 @@ -19,6 +33,8 @@ mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_id") # #mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) #ir_code, rt = Brutus.code_ircode(mi) +#display(ir_code) #mod = Brutus.Compiler.emit_jlir(ir_code, rt, "gauss") +#MLIR.IR.dump(mod) end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 2c66936..1448e59 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -96,7 +96,7 @@ end function process_stmt!(b::JLIRBuilder, ind::Int, stmt, loc::JLIR.Location, type::Type) - push!(b.values, emit_value(b, loc, stmt, type)) + setindex!(b.values, emit_value(b, loc, stmt, type), ind) return false end @@ -114,7 +114,8 @@ function process_stmt!(b::JLIRBuilder, ind::Int, dest = stmt.dest op = create!(b, GotoIfNotOp(), loc, cond, b.blocks[dest], - walk_cfg_emit_branchargs(b, b.insertion[], dest, loc), + walk_cfg_emit_branchargs(b, b.insertion[], + dest, loc), b.blocks[b.insertion[] + 1], walk_cfg_emit_branchargs(b, b.insertion[], b.insertion[] + 1, loc)) @@ -129,7 +130,7 @@ function process_stmt!(b::JLIRBuilder, ind::Int, JLIR.Value, (JLIR.Block, JLIR.Type), blk, t) - push!(b.values, arg) + setindex!(b.values, arg, ind) return false end @@ -138,8 +139,9 @@ function process_stmt!(b::JLIRBuilder, ind::Int, val = stmt.val @assert(type == stmt.type) jlir_type = convert_type_to_jlirtype(b.ctx, type) - op = create!(b, PiOp(), loc, emit_value(b, loc, val, Any), jlir_type) - ctx.values[ind] = JLIR.get_result(op, 0) + op = create!(b, PiOp(), loc, + emit_value(b, loc, val, Any), jlir_type) + setindex!(b.values, JLIR.get_result(op, 0), ind) return false end @@ -175,7 +177,7 @@ function process_stmt!(b::JLIRBuilder, ind::Int, op = create!(b, UnimplementedOp(), loc, jlir_type) end res = JLIR.get_result(op, 0) - push!(b.values, res) + setindex!(b.values, res, ind) return false end @@ -207,6 +209,5 @@ function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) JLIR.push_operation!(m, finish(b)) op = JLIR.get_operation(m) JLIR.verify(op) - JLIR.dump(op) return op end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index ff0ba29..8080c34 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -7,7 +7,7 @@ struct JLIRBuilder ctx::JLIR.Context insertion::Ref{Int} - values::Vector{JLIR.Value} + values::Dict{Int, JLIR.Value} arguments::Vector{JLIR.Type} locations::Vector{JLIR.Location} blocks::Vector{JLIR.Block} @@ -48,7 +48,7 @@ function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) JLIR.push!(reg, blk) push!(blocks, blk) end - return JLIRBuilder(ctx, Ref(2), JLIR.Value[], args, locations, blocks, reg, code, rt, state) + return JLIRBuilder(ctx, Ref(2), Dict{Int, JLIR.Value}(), args, locations, blocks, reg, code, rt, state) end set_insertion!(b::JLIRBuilder, blk::Int) = b.insertion[] = blk @@ -124,7 +124,7 @@ function extract_linetable_locations(ctx::JLIR.Context, v::Vector{Core.LineInfoN fname = String(method) end current = JLIR.Location(ctx, fname, UInt32(line), UInt32(0)) # TODO: col. - if inlined_at > 0 + if inlined_at > 1 current = JLIR.Location(current, locations[inlined_at - 1]) end push!(locations, current) From 4a1eb2afa8d4ec71b0a715b24461c99c71121673 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sat, 24 Apr 2021 23:50:04 -0400 Subject: [PATCH 10/13] Fixed value indexing on builder. --- Brutus/scratch/juliacodegen.jl | 43 ++++++++++++++++++++++++---------- Brutus/src/compiler/codegen.jl | 19 +++++++++------ 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 70ab9cd..8686f69 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -3,6 +3,7 @@ module JuliaCodegen using Brutus using MLIR +println("\n---- brutus_id ----\n") function brutus_id(N) return N end @@ -13,6 +14,20 @@ display(ir_code) mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_id") MLIR.IR.dump(mod) +println("\n---- brutus_add ----\n") + +function brutus_add(N1, N2) + return N1 + N2 +end + +mi = Brutus.get_methodinstance(Tuple{typeof(brutus_add), Int, Int}) +ir_code, rt = Brutus.code_ircode(mi) +display(ir_code) +mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_add") +MLIR.IR.dump(mod) + +println("\n---- switch ----\n") + function switch(N) N > 10 ? 5 : 10 end @@ -23,18 +38,20 @@ display(ir_code) mod = Brutus.Compiler.emit_jlir(ir_code, rt, "switch") MLIR.IR.dump(mod) -#function gauss(N) -# k = 0 -# for i in 1 : N -# k += i -# end -# return k -#end -# -#mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) -#ir_code, rt = Brutus.code_ircode(mi) -#display(ir_code) -#mod = Brutus.Compiler.emit_jlir(ir_code, rt, "gauss") -#MLIR.IR.dump(mod) +println("\n---- gauss ----\n") + +function gauss(N) + k = 0 + for i in 1 : N + k += i + end + return k +end + +mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) +ir_code, rt = Brutus.code_ircode(mi) +display(ir_code) +mod = Brutus.Compiler.emit_jlir(ir_code, rt, "gauss") +MLIR.IR.dump(mod) end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 1448e59..46e96f3 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -18,9 +18,11 @@ function maybe_widen_type(b::JLIRBuilder, loc::JLIR.Location, end function emit_value(b::JLIRBuilder, loc::JLIR.Location, - value, type::Type) + value, ::Type) + type = typeof(value) jlir_type = convert_type_to_jlirtype(b.ctx, type) - op = create!(b, UnimplementedOp(), loc, jlir_type) + jlir_value = convert_value_to_jlirattr(b.ctx, value) + op = create!(b, ConstantOp(), loc, jlir_value, jlir_type) return JLIR.get_result(op, 0) end @@ -111,14 +113,15 @@ end function process_stmt!(b::JLIRBuilder, ind::Int, stmt::Core.GotoIfNot, loc::JLIR.Location, type::Type) cond = emit_value(b, loc, stmt.cond, Any) - dest = stmt.dest + dest = stmt.dest + 1 # Accounts for entry block. + fallthrough = b.insertion[] + 1 op = create!(b, GotoIfNotOp(), loc, cond, b.blocks[dest], walk_cfg_emit_branchargs(b, b.insertion[], dest, loc), - b.blocks[b.insertion[] + 1], + b.blocks[fallthrough], walk_cfg_emit_branchargs(b, b.insertion[], - b.insertion[] + 1, loc)) + fallthrough, loc)) return true end @@ -167,11 +170,13 @@ function process_stmt!(b::JLIRBuilder, ind::Int, @assert(args[1] isa Core.MethodInstance) mi = args[1] callee = emit_value(b, loc, args[2], Any) - args = JLIR.Value[emit_value(b, loc, a, Any) for a in args[2 : end]] + args = JLIR.Value[emit_value(b, loc, a, Any) + for a in args[2 : end]] op = create!(b, InvokeOp, loc, mi, callee, args, jlir_type) elseif head == :call callee = emit_value(b, loc, args[1], Any) - args = JLIR.Value[emit_value(b, loc, a, Any) for a in args] + args = JLIR.Value[emit_value(b, loc, a, Any) + for a in args[2 : end]] op = create!(b, CallOp(), loc, callee, args, jlir_type) else op = create!(b, UnimplementedOp(), loc, jlir_type) From 1fba8fa61c79bb93714a15df7a0d390a8d3c4a94 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sun, 25 Apr 2021 11:52:03 -0400 Subject: [PATCH 11/13] Executing code generated from the C pipeline. --- Brutus/Project.toml | 1 - Brutus/scratch/juliacodegen.jl | 72 +++++++++++++++----------------- Brutus/src/compiler/codegen.jl | 57 +++++++++++++++++++++++-- Brutus/src/compiler/opbuilder.jl | 13 ++++++ Brutus/src/interface.jl | 60 +++++++++++++------------- include/brutus/brutus.h | 9 ++-- lib/Codegen/Codegen.cpp | 30 +++++++------ 7 files changed, 154 insertions(+), 88 deletions(-) diff --git a/Brutus/Project.toml b/Brutus/Project.toml index 2068af9..9520d08 100644 --- a/Brutus/Project.toml +++ b/Brutus/Project.toml @@ -6,7 +6,6 @@ version = "0.1.0" [deps] GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" MLIR = "bfde9dd4-8f40-4a1e-be09-1475335e1c92" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] julia = "1.5" diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index 8686f69..e05ffe0 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -8,11 +8,8 @@ function brutus_id(N) return N end -mi = Brutus.get_methodinstance(Tuple{typeof(brutus_id), Int}) -ir_code, rt = Brutus.code_ircode(mi) -display(ir_code) -mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_id") -MLIR.IR.dump(mod) +v = Brutus.call(brutus_id, 5) +display(v) println("\n---- brutus_add ----\n") @@ -20,38 +17,37 @@ function brutus_add(N1, N2) return N1 + N2 end -mi = Brutus.get_methodinstance(Tuple{typeof(brutus_add), Int, Int}) -ir_code, rt = Brutus.code_ircode(mi) -display(ir_code) -mod = Brutus.Compiler.emit_jlir(ir_code, rt, "brutus_add") -MLIR.IR.dump(mod) - -println("\n---- switch ----\n") - -function switch(N) - N > 10 ? 5 : 10 -end - -mi = Brutus.get_methodinstance(Tuple{typeof(switch), Int}) -ir_code, rt = Brutus.code_ircode(mi) -display(ir_code) -mod = Brutus.Compiler.emit_jlir(ir_code, rt, "switch") -MLIR.IR.dump(mod) - -println("\n---- gauss ----\n") - -function gauss(N) - k = 0 - for i in 1 : N - k += i - end - return k -end - -mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) -ir_code, rt = Brutus.code_ircode(mi) -display(ir_code) -mod = Brutus.Compiler.emit_jlir(ir_code, rt, "gauss") -MLIR.IR.dump(mod) +v = Brutus.call(brutus_add, 5.0, 10.0) +display(v) + +#println("\n---- switch ----\n") +# +#function switch(N) +# N > 10 ? 5 : 10 +#end +# +#mi = Brutus.get_methodinstance(Tuple{typeof(switch), Int}) +#ir_code, rt = Brutus.code_ircode(mi) +#display(ir_code) +#jlir = Brutus.Compiler.codegen_jlir(ir_code, rt, "switch") +#display(jlir) +#Brutus.Compiler.canonicalize!(jlir) +#display(jlir) +# +#println("\n---- gauss ----\n") +# +#function gauss(N) +# k = 0 +# for i in 1 : N +# k += i +# end +# return k +#end +# +#mi = Brutus.get_methodinstance(Tuple{typeof(gauss), Int}) +#ir_code, rt = Brutus.code_ircode(mi) +#display(ir_code) +#mod = Brutus.Compiler.codegen_jlir(ir_code, rt, "gauss") +#MLIR.IR.dump(mod) end # module diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 46e96f3..86a52d1 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -186,7 +186,19 @@ function process_stmt!(b::JLIRBuilder, ind::Int, return false end -function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) +##### +##### JLIR generation +##### + +mutable struct CompiledJLIRModule + ctx::JLIR.Context + mod::JLIR.Module + name::String +end + +Base.display(jlir::CompiledJLIRModule) = JLIR.dump(JLIR.get_operation(jlir.mod)) + +function codegen_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) # Create builder. b = JLIRBuilder(ir_code, rt, name) m = JLIR.Module(JLIR.Location(b.ctx)) @@ -212,7 +224,44 @@ function emit_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) # Create op from module and verify. JLIR.push_operation!(m, finish(b)) - op = JLIR.get_operation(m) - JLIR.verify(op) - return op + @assert(JLIR.verify(JLIR.get_operation(m))) + return CompiledJLIRModule(b.ctx, m, name) +end + +function canonicalize!(jlir::CompiledJLIRModule) + ccall((:brutus_canonicalize, "libbrutus"), + Cvoid, + (JLIR.Context, JLIR.Module), + jlir.ctx, jlir.mod) + op = JLIR.get_operation(jlir.mod) + @assert(JLIR.verify(op)) + return +end + +function dialect_lower_to_std!(jlir::CompiledJLIRModule) + ccall((:brutus_lower_to_standard, "libbrutus"), + Cvoid, + (JLIR.Context, JLIR.Module), + jlir.ctx, jlir.mod) + op = JLIR.get_operation(jlir.mod) + @assert(JLIR.verify(op)) + return +end + +function dialect_lower_to_llvm!(jlir::CompiledJLIRModule) + ccall((:brutus_lower_to_llvm, "libbrutus"), + Cvoid, + (JLIR.Context, JLIR.Module), + jlir.ctx, jlir.mod) + op = JLIR.get_operation(jlir.mod) + @assert(JLIR.verify(op)) + return +end + +function thunk(jlir::CompiledJLIRModule) + fptr = ccall((:c_brutus_create_execution_engine, "libbrutus"), + Ptr{Nothing}, + (JLIR.Context, JLIR.Module, Cstring), + jlir.ctx, jlir.mod, jlir.name) + return fptr end diff --git a/Brutus/src/compiler/opbuilder.jl b/Brutus/src/compiler/opbuilder.jl index 8080c34..bb0af23 100644 --- a/Brutus/src/compiler/opbuilder.jl +++ b/Brutus/src/compiler/opbuilder.jl @@ -18,17 +18,23 @@ struct JLIRBuilder end function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) + + # Create a context and register dialects required by Brutus. ctx = JLIR.create_context() ccall((:brutus_register_dialects, "libbrutus"), Cvoid, (JLIR.Context, ), ctx) + + # IRCode metadata -> JLIR metadata (locations). irstream = code.stmts stmts = irstream.inst types = irstream.type location_indices = getfield(irstream, :line) linetable = getfield(code, :linetable) locations = extract_linetable_locations(ctx, linetable) + + # Create toplevel FuncOp. argtypes = getfield(code, :argtypes) args = [convert_type_to_jlirtype(ctx, a) for a in argtypes] ftype = emit_ftype(ctx, code, rt) @@ -37,8 +43,13 @@ function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) named_type_attr = JLIR.NamedAttribute(ctx, "type", type_attr) string_attr = JLIR.get_string_attribute(ctx, name) symbol_name_attr = JLIR.NamedAttribute(ctx, "sym_name", string_attr) + viz_attr = JLIR.get_string_attribute(ctx, "nested") + named_viz_attr = JLIR.NamedAttribute(ctx, "sym_visibility", viz_attr) + unit_attr = JLIR.get_unit_attribute(ctx) JLIR.push_attributes!(state, named_type_attr) JLIR.push_attributes!(state, symbol_name_attr) + JLIR.push_attributes!(state, named_viz_attr) + JLIR.push_attributes!(state, JLIR.NamedAttribute(ctx, "llvm.emit_c_interface", unit_attr)) entry_blk, reg = JLIR.add_entry_block!(state, args) tr = JLIR.get_first_block(reg) nblocks = length(code.cfg.blocks) @@ -48,6 +59,8 @@ function JLIRBuilder(code::Core.Compiler.IRCode, rt::Type, name::String) JLIR.push!(reg, blk) push!(blocks, blk) end + + # Pass FuncOp state in builder. return JLIRBuilder(ctx, Ref(2), Dict{Int, JLIR.Value}(), args, locations, blocks, reg, code, rt, state) end diff --git a/Brutus/src/interface.jl b/Brutus/src/interface.jl index f00e18e..216b433 100644 --- a/Brutus/src/interface.jl +++ b/Brutus/src/interface.jl @@ -51,7 +51,7 @@ end # Emit MLIR IR to stdout function emit(job::CompilerJob) - ft = job.source.f + ft = typeof(job.source.f) tt = job.source.tt emit_fptr = job.params.emit_fptr dump_options = job.params.dump_options @@ -67,36 +67,37 @@ function emit(job::CompilerJob) println(IR) end - worklist = [IR] - methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Any}}( - entry_mi => (IR, rt) - ) + #worklist = [IR] + #methods = Dict{Core.MethodInstance, Tuple{Core.Compiler.IRCode, Any}}( + # entry_mi => (IR, rt) + #) - while !isempty(worklist) - code = pop!(worklist) - callees = find_invokes(code) - for callee in callees - if !haskey(methods, callee) - _code, _rt = code_ircode(callee) + #while !isempty(worklist) + # code = pop!(worklist) + # callees = find_invokes(code) + # for callee in callees + # if !haskey(methods, callee) + # _code, _rt = code_ircode(callee) - methods[callee] = (_code, _rt) - push!(worklist, _code) - end - end - end + # methods[callee] = (_code, _rt) + # push!(worklist, _code) + # end + # end + #end # generate LLVM bitcode and load it - dump_flags = reduce(|, map(UInt8, dump_options), init=0) - fptr = ccall((:brutus_codegen, "libbrutus"), - Ptr{Nothing}, - (Any, Any, Cuchar, Cuchar), - methods, entry_mi, emit_fptr, dump_flags) + jlir = Brutus.Compiler.codegen_jlir(IR, rt, String(name)) + Brutus.Compiler.canonicalize!(jlir) + Brutus.Compiler.canonicalize!(jlir) + Brutus.Compiler.dialect_lower_to_std!(jlir) + Brutus.Compiler.dialect_lower_to_llvm!(jlir) + fptr = Brutus.Compiler.thunk(jlir) return (fptr, rt) end function emit(@nospecialize(ft), @nospecialize(tt); - emit_fptr::Bool=true, - dump_options::Vector{DumpOption}=DumpOption[]) + emit_fptr::Bool=true, + dump_options::Vector{DumpOption}=DumpOption[]) fspec = GPUCompiler.FunctionSpec(ft, Tuple{tt...}, false, nothing) target = BrutusCompilerTarget() params = BrutusCompilerParams(emit_fptr, dump_options) @@ -135,25 +136,26 @@ struct Thunk{F, RT, TT} ptr::Ptr{Cvoid} end -const brutus_cache = Dict{UInt,Any}() - function link(job::CompilerJob, (fptr, rt)) @assert fptr != C_NULL - fptr, rt = result f = job.source.f tt = job.source.tt return Thunk{typeof(f), rt, tt}(f, fptr) end -function thunk(f::F, tt::TT=Tuple{}; emit_fptr::Bool = true, dump_options::Vector{DumpOption} = DumpOption[]) where {F<:Base.Callable, TT<:Type} - fspec = GPUCompiler.FunctionSpec(F, tt, false, nothing) +const brutus_cache = Dict{UInt,Any}() + +function thunk(f::F, tt::TT=Tuple{}; + emit_fptr::Bool = true, + dump_options::Vector{DumpOption} = DumpOption[]) where {F <: Base.Callable, TT <: Type} + fspec = GPUCompiler.FunctionSpec(f, tt, false, nothing) target = BrutusCompilerTarget() params = BrutusCompilerParams(emit_fptr, dump_options) job = CompilerJob(target, fspec, params) return GPUCompiler.cached_compilation(brutus_cache, job, emit, link) end -# Need to pass struct as pointer, to match cifacme ABI +# Need to pass struct as pointer, to match ciface ABI abi(::Type{<:Array{T, N}}) where {T, N} = Ref{MemrefDescriptor{T, N}} function abi(T::DataType) if isprimitivetype(T) diff --git a/include/brutus/brutus.h b/include/brutus/brutus.h index d6661ba..7c9b1f4 100644 --- a/include/brutus/brutus.h +++ b/include/brutus/brutus.h @@ -37,11 +37,12 @@ extern "C" typedef void (*ExecutionEngineFPtrResult)(void **); void brutus_init(jl_module_t *brutus); - void brutus_codegen_jlir(MlirContext context, MlirModule module, jl_value_t *methods, jl_method_instance_t *entry_mi, char dump_flags); - void brutus_canonicalize(MlirContext context, MlirModule module, char dump_flags); - void brutus_lower_to_standard(MlirContext context, MlirModule module, char dump_flags); - void brutus_lower_to_llvm(MlirContext context, MlirModule module, char dump_flags); + void brutus_codegen_jlir(MlirContext context, MlirModule module, jl_value_t *methods, jl_method_instance_t *entry_mi); + void brutus_canonicalize(MlirContext context, MlirModule module); + void brutus_lower_to_standard(MlirContext context, MlirModule module); + void brutus_lower_to_llvm(MlirContext context, MlirModule module); ExecutionEngineFPtrResult brutus_create_execution_engine(MlirContext context, MlirModule module, std::string name); + ExecutionEngineFPtrResult c_brutus_create_execution_engine(MlirContext context, MlirModule module, const char *name); ExecutionEngineFPtrResult brutus_codegen(jl_value_t *methods, jl_method_instance_t *entry_mi, char emit_fptr, char dump_flags); #ifdef __cplusplus diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index 722e032..ea7640a 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -1,3 +1,4 @@ +#include #include "brutus/brutus.h" #include "brutus/brutus_internal.h" #include "brutus/Dialect/Julia/JuliaOps.h" @@ -16,6 +17,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #include "mlir/Target/LLVMIR.h" +#include "llvm-c/Core.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Wrap.h" #include "mlir/CAPI/IR.h" @@ -512,8 +514,7 @@ extern "C" void brutus_codegen_jlir(MlirContext Context, MlirModule Module, jl_value_t *methods, - jl_method_instance_t *entry_mi, - char dump_flags) + jl_method_instance_t *entry_mi) { mlir::ModuleOp module = unwrap(Module); @@ -543,8 +544,7 @@ extern "C" // canonicalize void brutus_canonicalize(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -568,8 +568,7 @@ extern "C" // lower to Standard dialect void brutus_lower_to_standard(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -589,8 +588,7 @@ extern "C" // lower to LLVM dialect void brutus_lower_to_llvm(MlirContext Context, - MlirModule Module, - char dump_flags) + MlirModule Module) { mlir::MLIRContext *context = unwrap(Context); mlir::ModuleOp module = unwrap(Module); @@ -670,7 +668,7 @@ extern "C" MlirContext Context = mlirContextCreate(); MlirModule Module = mlirModuleCreateEmpty(mlirLocationUnknownGet(Context)); - brutus_codegen_jlir(Context, Module, methods, entry_mi, dump_flags); + brutus_codegen_jlir(Context, Module, methods, entry_mi); if (dump_flags && DUMP_TRANSLATED) { mlir::ModuleOp module = unwrap(Module); @@ -679,7 +677,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_canonicalize(Context, Module, dump_flags); + brutus_canonicalize(Context, Module); if (dump_flags & DUMP_CANONICALIZED) { mlir::ModuleOp module = unwrap(Module); @@ -688,7 +686,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_lower_to_standard(Context, Module, dump_flags); + brutus_lower_to_standard(Context, Module); if (dump_flags & DUMP_LOWERED_TO_STD) { mlir::ModuleOp module = unwrap(Module); @@ -697,7 +695,7 @@ extern "C" llvm::dbgs() << "\n\n"; } - brutus_lower_to_llvm(Context, Module, dump_flags); + brutus_lower_to_llvm(Context, Module); if (dump_flags & DUMP_LOWERED_TO_LLVM) { mlir::ModuleOp module = unwrap(Module); @@ -729,4 +727,12 @@ extern "C" return engine_ptr; } + + ExecutionEngineFPtrResult c_brutus_create_execution_engine(MlirContext Context, + MlirModule Module, + const char *name) + { + std::string str(name); + return brutus_create_execution_engine(Context, Module, str); + } } From 39363ebd225351b015b389549978ffea0b9bf175 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sun, 25 Apr 2021 11:55:57 -0400 Subject: [PATCH 12/13] Remove iostream debug p. --- lib/Codegen/Codegen.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Codegen/Codegen.cpp b/lib/Codegen/Codegen.cpp index ea7640a..a8f073c 100644 --- a/lib/Codegen/Codegen.cpp +++ b/lib/Codegen/Codegen.cpp @@ -1,4 +1,3 @@ -#include #include "brutus/brutus.h" #include "brutus/brutus_internal.h" #include "brutus/Dialect/Julia/JuliaOps.h" From e29937a74cf8b82ba0db7f552dc5ab6dd8ade218 Mon Sep 17 00:00:00 2001 From: "McCoy R. Becker" Date: Sun, 25 Apr 2021 13:50:26 -0400 Subject: [PATCH 13/13] Modified interface to include dump options. Included cleanup! call to remove modules and contexts in memory. --- Brutus/scratch/juliacodegen.jl | 28 +++++++++++++++++------- Brutus/src/compiler/codegen.jl | 15 +++++++++++++ Brutus/src/interface.jl | 40 ++++++++++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/Brutus/scratch/juliacodegen.jl b/Brutus/scratch/juliacodegen.jl index e05ffe0..e864376 100644 --- a/Brutus/scratch/juliacodegen.jl +++ b/Brutus/scratch/juliacodegen.jl @@ -9,6 +9,7 @@ function brutus_id(N) end v = Brutus.call(brutus_id, 5) +@time v = Brutus.call(brutus_id, 5) display(v) println("\n---- brutus_add ----\n") @@ -18,6 +19,22 @@ function brutus_add(N1, N2) end v = Brutus.call(brutus_add, 5.0, 10.0) +@time v = Brutus.call(brutus_add, 5.0, 10.0) +display(v) + +println("\n---- structs ----\n") + +struct Foo + x +end + +function bar() + f = Foo(5.0) + b = Foo(f.x + 10.0) + return f +end + +v = Brutus.call(bar; dump_options = Brutus.DumpAll) display(v) #println("\n---- switch ----\n") @@ -26,14 +43,9 @@ display(v) # N > 10 ? 5 : 10 #end # -#mi = Brutus.get_methodinstance(Tuple{typeof(switch), Int}) -#ir_code, rt = Brutus.code_ircode(mi) -#display(ir_code) -#jlir = Brutus.Compiler.codegen_jlir(ir_code, rt, "switch") -#display(jlir) -#Brutus.Compiler.canonicalize!(jlir) -#display(jlir) -# +#v = Brutus.call(switch, 15) +#display(v) + #println("\n---- gauss ----\n") # #function gauss(N) diff --git a/Brutus/src/compiler/codegen.jl b/Brutus/src/compiler/codegen.jl index 86a52d1..29b88f4 100644 --- a/Brutus/src/compiler/codegen.jl +++ b/Brutus/src/compiler/codegen.jl @@ -26,6 +26,16 @@ function emit_value(b::JLIRBuilder, loc::JLIR.Location, return JLIR.get_result(op, 0) end +function emit_value(b::JLIRBuilder, loc::JLIR.Location, + value::QuoteNode, ::Type) + value = getfield(value, :value) + type = typeof(value) + jlir_type = convert_type_to_jlirtype(b.ctx, type) + jlir_value = convert_value_to_jlirattr(b.ctx, value) + op = create!(b, ConstantOp(), loc, jlir_value, jlir_type) + return JLIR.get_result(op, 0) +end + function emit_value(b::JLIRBuilder, loc::JLIR.Location, value::Core.Argument, type::Type) idx = value.n @@ -196,6 +206,11 @@ mutable struct CompiledJLIRModule name::String end +function cleanup!(jlir::CompiledJLIRModule) + JLIR.destroy!(jlir.ctx) + JLIR.destroy!(jlir.mod) +end + Base.display(jlir::CompiledJLIRModule) = JLIR.dump(JLIR.get_operation(jlir.mod)) function codegen_jlir(ir_code::Core.Compiler.IRCode, rt::Type, name::String) diff --git a/Brutus/src/interface.jl b/Brutus/src/interface.jl index 216b433..4ba9194 100644 --- a/Brutus/src/interface.jl +++ b/Brutus/src/interface.jl @@ -25,6 +25,12 @@ end DumpTranslateToLLVM = 16 end +const DumpAll = DumpOption[DumpIRCode, + DumpTranslated, + DumpCanonicalized, + DumpLoweredToStd, + DumpLoweredToLLVM] + struct BrutusCompilerParams <: AbstractCompilerParams emit_fptr::Bool dump_options::Vector{DumpOption} @@ -65,6 +71,7 @@ function emit(job::CompilerJob) println("return type: ", rt) println("IRCode:\n") println(IR) + println() end #worklist = [IR] @@ -84,14 +91,38 @@ function emit(job::CompilerJob) # end # end #end - + # generate LLVM bitcode and load it jlir = Brutus.Compiler.codegen_jlir(IR, rt, String(name)) + if DumpTranslated in dump_options + println("JLIR:") + display(jlir) + println() + end + Brutus.Compiler.canonicalize!(jlir) - Brutus.Compiler.canonicalize!(jlir) + if DumpCanonicalized in dump_options + println("After canonicalization:") + display(jlir) + println() + end + Brutus.Compiler.dialect_lower_to_std!(jlir) + if DumpLoweredToStd in dump_options + println("Standard:") + display(jlir) + println() + end + Brutus.Compiler.dialect_lower_to_llvm!(jlir) + if DumpLoweredToLLVM in dump_options + println("LLVM dialect:") + display(jlir) + println() + end + fptr = Brutus.Compiler.thunk(jlir) + Brutus.Compiler.cleanup!(jlir) return (fptr, rt) end @@ -180,7 +211,8 @@ end return expr end -function call(f::F, args...) where F +function call(f::F, args...; + dump_options::Vector{DumpOption} = DumpOption[]) where F TT = Tuple{map(Core.Typeof, args)...} - return thunk(f, TT)(args...) + return thunk(f, TT; dump_options = dump_options)(args...) end