diff --git a/src/Dialects/Utils.jl b/src/Dialects/Utils.jl index 45885494..bcba0810 100644 --- a/src/Dialects/Utils.jl +++ b/src/Dialects/Utils.jl @@ -1,13 +1,4 @@ -import ..IR: IR, Attribute, NamedAttribute, context -import ..API +import ..IR: NamedAttribute -namedattribute(name, val) = namedattribute(name, Attribute(val)) -namedattribute(name, val::Attribute) = NamedAttribute(name, val) -function namedattribute(name, val::NamedAttribute) - @assert true # TODO(jm): check whether name of attribute is correct, getting the name might need to be added to IR.jl? - return val -end - -function operandsegmentsizes(segments) - return namedattribute("operand_segment_sizes", Attribute(Int32.(segments))) -end +operandsegmentsizes(segments) = NamedAttribute("operand_segment_sizes", Int32.(segments)) +resultsegmentsizes(segments) = NamedAttribute("result_segment_sizes", Int32.(segments)) diff --git a/src/IR/AffineExpr.jl b/src/IR/AffineExpr.jl index 7a0f4931..f4a92b36 100644 --- a/src/IR/AffineExpr.jl +++ b/src/IR/AffineExpr.jl @@ -1,13 +1,9 @@ -struct AffineExpr - expr::API.MlirAffineExpr - - function AffineExpr(expr) - @assert !mlirIsNull(expr) "cannot create AffineExpr with null MlirAffineExpr" - return new(expr) - end +@checked struct AffineExpr + ref::API.MlirAffineExpr end -Base.convert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr.expr +Base.cconvert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr +Base.unsafe_convert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr.ref """ ==(a, b) @@ -56,8 +52,9 @@ ismultipleof(expr::AffineExpr, factor) = API.mlirAffineExprIsMultipleOf(expr, fa Checks whether the given affine expression involves AffineDimExpr 'position'. """ -isfunctionofdimexpr(expr::AffineExpr, position) = - API.mlirAffineExprIsFunctionOfDim(expr, position) +function isfunctionofdimexpr(expr::AffineExpr, position) + return API.mlirAffineExprIsFunctionOfDim(expr, position) +end """ isdimexpr(affineExpr) @@ -71,8 +68,9 @@ isdimexpr(expr::AffineExpr) = API.mlirAffineExprIsADim(expr) Creates an affine dimension expression with 'position' in the context. """ -AffineDimensionExpr(position; context::Context=context()) = - AffineExpr(API.mlirAffineDimExprGet(context, position)) +function AffineDimensionExpr(position; context::Context=current_context()) + return AffineExpr(API.mlirAffineDimExprGet(context, position)) +end """ issymbolexpr(affineExpr) @@ -82,12 +80,13 @@ Checks whether the given affine expression is a symbol expression. issymbolexpr(expr::AffineExpr) = API.mlirAffineExprIsASymbol(expr) """ - SymbolExpr(position; context=context()) + SymbolExpr(position; context=current_context()) Creates an affine symbol expression with 'position' in the context. """ -SymbolExpr(position; context::Context=context()) = - AffineExpr(API.mlirAffineSymbolExprGet(context, position)) +function SymbolExpr(position; context::Context=current_context()) + return AffineExpr(API.mlirAffineSymbolExprGet(context, position)) +end """ position(affineExpr) @@ -116,12 +115,13 @@ Checks whether the given affine expression is a constant expression. isconstantexpr(expr::AffineExpr) = API.mlirAffineExprIsAConstant(expr) """ - ConstantExpr(constant::Int; context=context()) + ConstantExpr(constant::Int; context=current_context()) Creates an affine constant expression with 'constant' in the context. """ -ConstantExpr(constant; context::Context=context()) = - AffineExpr(API.mlirAffineConstantExprGet(context, constant)) +function ConstantExpr(constant; context::Context=current_context()) + return AffineExpr(API.mlirAffineConstantExprGet(context, constant)) +end """ value(affineExpr) @@ -189,8 +189,10 @@ isfloordiv(expr::AffineExpr) = API.mlirAffineExprIsAFloorDiv(expr) Creates an affine floordiv expression with 'lhs' and 'rhs'. """ -Base.div(lhs::AffineExpr, rhs::AffineExpr) = - AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs)) +function Base.div(lhs::AffineExpr, rhs::AffineExpr) + return AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs)) +end + Base.fld(lhs::AffineExpr, rhs::AffineExpr) = div(lhs, rhs) """ @@ -205,8 +207,9 @@ isceildiv(expr::AffineExpr) = API.mlirAffineExprIsACeilDiv(expr) Creates an affine ceildiv expression with 'lhs' and 'rhs'. """ -Base.cld(lhs::AffineExpr, rhs::AffineExpr) = - AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs)) +function Base.cld(lhs::AffineExpr, rhs::AffineExpr) + return AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs)) +end """ isbinary(affineExpr) diff --git a/src/IR/AffineMap.jl b/src/IR/AffineMap.jl index f337e696..4e7e70b1 100644 --- a/src/IR/AffineMap.jl +++ b/src/IR/AffineMap.jl @@ -1,21 +1,19 @@ -struct AffineMap - map::API.MlirAffineMap - - function AffineMap(map::API.MlirAffineMap) - @assert !mlirIsNull(map) "cannot create AffineMap with null MlirAffineMap" - return new(map) - end +@checked struct AffineMap + ref::API.MlirAffineMap end """ - AffineMap(; context=context()) + AffineMap(; context=current_context()) Creates a zero result affine map with no dimensions or symbols in the context. The affine map is owned by the context. """ -AffineMap(; context::Context=context()) = AffineMap(API.mlirAffineMapEmptyGet(context)) +function AffineMap(; context::Context=current_context()) + return AffineMap(API.mlirAffineMapEmptyGet(context)) +end -Base.convert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map.map +Base.cconvert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map +Base.unsafe_convert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map.ref """ ==(a, b) @@ -25,78 +23,90 @@ Checks if two affine maps are equal. Base.:(==)(a::AffineMap, b::AffineMap) = API.mlirAffineMapEqual(a, b) """ - compose(affineExpr, affineMap) + context(affineMap) -Composes the given map with the given expression. +Gets the context that the given affine map was created with. """ -compose(expr::AffineExpr, map::AffineMap) = AffineExpr(API.mlirAffineExprCompose(expr, map)) +context(map::AffineMap) = Context(API.mlirAffineMapGetContext(map)) + +function Base.show(io::IO, map::AffineMap) + print(io, "AffineMap(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) + ref = Ref(io) + API.mlirAffineMapPrint(map, c_print_callback, ref) + return print(io, " =#)") +end """ - context(affineMap) + compose(affineExpr, affineMap) -Gets the context that the given affine map was created with. +Composes the given map with the given expression. """ -context(map::AffineMap) = API.mlirAffineMapGetContext(map) +compose(expr::AffineExpr, map::AffineMap) = AffineExpr(API.mlirAffineExprCompose(expr, map)) """ - AffineMap(ndims, nsymbols; context=context()) + AffineMap(ndims, nsymbols; context=current_context()) Creates a zero result affine map of the given dimensions and symbols in the context. The affine map is owned by the context. """ -AffineMap(ndims, nsymbols; context::Context=context()) = - AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols)) +function AffineMap(ndims, nsymbols; context::Context=current_context()) + return AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols)) +end """ - AffineMap(ndims, nsymbols, affineExprs; context=context()) + AffineMap(ndims, nsymbols, affineExprs; context=current_context()) Creates an affine map with results defined by the given list of affine expressions. The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ -AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) = - AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), pointer(exprs))) +function AffineMap( + ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=current_context() +) + return AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs)) +end """ - ConstantAffineMap(val; context=context()) + ConstantAffineMap(val; context=current_context()) Creates a single constant result affine map in the context. The affine map is owned by the context. """ -ConstantAffineMap(val; context::Context=context()) = - AffineMap(API.mlirAffineMapConstantGet(context, val)) +function ConstantAffineMap(val; context::Context=current_context()) + return AffineMap(API.mlirAffineMapConstantGet(context, val)) +end """ - IdentityAffineMap(ndims; context=context()) + IdentityAffineMap(ndims; context=current_context()) Creates an affine map with 'ndims' identity in the context. The affine map is owned by the context. """ -IdentityAffineMap(ndims; context::Context=context()) = - AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims)) +function IdentityAffineMap(ndims; context::Context=current_context()) + return AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims)) +end """ - MinorIdentityAffineMap(ndims, nresults; context=context()) + MinorIdentityAffineMap(ndims, nresults; context=current_context()) Creates an identity affine map on the most minor dimensions in the context. The affine map is owned by the context. The function asserts that the number of dimensions is greater or equal to the number of results. """ -function MinorIdentityAffineMap(ndims, nresults; context::Context=context()) +function MinorIdentityAffineMap(ndims, nresults; context::Context=current_context()) @assert ndims >= nresults "number of dimensions must be greater or equal to the number of results" return AffineMap(API.mlirAffineMapMinorIdentityGet(context, ndims, nresults)) end """ - PermutationAffineMap(permutation; context=context()) + PermutationAffineMap(permutation; context=current_context()) Creates an affine map with a permutation expression and its size in the context. The permutation expression is a non-empty vector of integers. The elements of the permutation vector must be continuous from 0 and cannot be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is an invalid invalid permutation). The affine map is owned by the context. """ -function PermutationAffineMap(permutation; context::Context=context()) +function PermutationAffineMap(permutation; context::Context=current_context()) @assert Base.isperm(permutation) "$permutation must be a valid permutation" zero_perm = permutation .- 1 - return AffineMap( - API.mlirAffineMapPermutationGet(context, length(zero_perm), pointer(zero_perm)) - ) + return AffineMap(API.mlirAffineMapPermutationGet(context, length(zero_perm), zero_perm)) end """ @@ -191,8 +201,9 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map) Returns the affine map consisting of the `positions` subset. """ -submap(map::AffineMap, pos::Vector{Int}) = - AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pointer(pos))) +function submap(map::AffineMap, pos::Vector{Int}) + return AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos)) +end """ majorsubmap(affineMap, nresults) @@ -201,8 +212,9 @@ Returns the affine map consisting of the most major `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -majorsubmap(map::AffineMap, nresults) = - AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults)) +function majorsubmap(map::AffineMap, nresults) + return AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults)) +end """ minorsubmap(affineMap, nresults) @@ -210,19 +222,24 @@ majorsubmap(map::AffineMap, nresults) = Returns the affine map consisting of the most minor `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -minorsubmap(map::AffineMap, nresults) = - AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults)) +function minorsubmap(map::AffineMap, nresults) + return AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults)) +end """ mlirAffineMapReplace(affineMap, expression => replacement, numResultDims, numResultSyms) Apply `AffineExpr::replace(map)` to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols. """ -Base.replace( +function Base.replace( map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms -) = AffineMap( - API.mlirAffineMapReplace(map, old_new.first, old_new.second, nresultdims, nresultsyms), ) + return AffineMap( + API.mlirAffineMapReplace( + map, old_new.first, old_new.second, nresultdims, nresultsyms + ), + ) +end """ simplify(affineMaps, size, result, populateResult) @@ -232,15 +249,7 @@ Asserts that all maps in `affineMaps` are normalized to the same number of dims Takes a callback `populateResult` to fill the `res` container with value `m` at entry `idx`. This allows returning without worrying about ownership considerations. """ -# TODO simplify(map::AffineMap, ...) = AffineMap(API.mlirAffineMapCompressUnusedSymbols(map, ...)) - -function Base.show(io::IO, map::AffineMap) - print(io, "AffineMap(#= ") - c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - ref = Ref(io) - API.mlirAffineMapPrint(map, c_print_callback, ref) - return print(io, " =#)") -end +# TODO(#2246) simplify(map::AffineMap, ...) = AffineMap(API.mlirAffineMapCompressUnusedSymbols(map, ...)) walk(f, other) = f(other) function walk(f, expr::Expr) @@ -258,12 +267,8 @@ On the right hand side are allowed the following function calls: The rhs can only contains dimensions and symbols present on the left hand side or integer literals. -```juliadoctest -julia> using MLIR: IR, AffineUtils - -julia> IR.context!(IR.Context()) do - IR.@affinemap (d1, d2)[s0] -> (d1 + s0, d2 % 10) - end +```julia +julia> @affinemap (d1, d2)[s0] -> (d1 + s0, d2 % 10) MLIR.IR.AffineMap(#= (d0, d1)[s0] -> (d0 + s0, d1 mod 10) =#) ``` """ @@ -287,11 +292,11 @@ macro affinemap(expr) @assert all(x -> x isa Symbol, syms) "invalid symbols $syms" dimexprs = map(enumerate(dims)) do (i, dim) - :($dim = AffineDimensionExpr($(i - 1))) + return :($dim = AffineDimensionExpr($(i - 1))) end symexprs = map(enumerate(syms)) do (i, sym) - :($sym = SymbolExpr($(i - 1))) + return :($sym = SymbolExpr($(i - 1))) end known_binops = [:+, :-, :*, :÷, :%, :fld, :cld] diff --git a/src/IR/Attribute.jl b/src/IR/Attribute.jl index e8195e5d..22e89719 100644 --- a/src/IR/Attribute.jl +++ b/src/IR/Attribute.jl @@ -1,5 +1,6 @@ +# ref is allowed to be null struct Attribute - attribute::API.MlirAttribute + ref::API.MlirAttribute end """ @@ -9,15 +10,19 @@ Returns an empty attribute. """ Attribute() = Attribute(API.mlirAttributeGetNull()) -Base.convert(::Core.Type{API.MlirAttribute}, attribute::Attribute) = attribute.attribute +Attribute(attr::Attribute) = attr + +Base.cconvert(::Core.Type{API.MlirAttribute}, attr::Attribute) = attr +Base.unsafe_convert(::Core.Type{API.MlirAttribute}, attr::Attribute) = attr.ref """ - parse(::Core.Type{Attribute}, str; context=context()) + parse(::Core.Type{Attribute}, str; context=current_context()) Parses an attribute. The attribute is owned by the context. """ -Base.parse(::Core.Type{Attribute}, str; context::Context=context()) = - Attribute(API.mlirAttributeParseGet(context, str)) +function Base.parse(::Core.Type{Attribute}, str; context::Context=current_context()) + return Attribute(API.mlirAttributeParseGet(context, str)) +end """ ==(a1, a2) @@ -76,12 +81,13 @@ Checks whether the given attribute is an array attribute. isarray(attr::Attribute) = API.mlirAttributeIsAArray(attr) """ - Attribute(elements; context=context()) + Attribute(elements; context=current_context()) Creates an array element containing the given list of elements in the given context. """ -Attribute(attrs::Vector{Attribute}; context::Context=context()) = - Attribute(API.mlirArrayAttrGet(context, length(attrs), pointer(attrs))) +function Attribute(attrs::Vector{Attribute}; context::Context=current_context()) + return Attribute(API.mlirArrayAttrGet(context, length(attrs), attrs)) +end """ isdict(attr) @@ -91,13 +97,13 @@ Checks whether the given attribute is a dictionary attribute. isdict(attr::Attribute) = API.mlirAttributeIsADictionary(attr) """ - Attribute(elements; context=context()) + Attribute(elements; context=current_context()) Creates a dictionary attribute containing the given list of elements in the provided context. """ -function Attribute(attrs::Dict; context::Context=context()) - attrs = map(splat(NamedAttribute), attrs) - return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), pointer(attrs))) +function Attribute(attrs::Dict; context::Context=current_context()) + attrs = [NamedAttribute(k, Attribute(v); context) for (k, v) in attrs] + return Attribute(API.mlirDictionaryAttrGet(context, length(attrs), attrs)) end """ @@ -108,13 +114,16 @@ Checks whether the given attribute is a floating point attribute. isfloat(attr::Attribute) = API.mlirAttributeIsAFloat(attr) """ - Attribute(float; context=context(), location=Location(), check=false) + Attribute(float; context=current_context(), location=Location(), check=false) Creates a floating point attribute in the given context with the given double value and double-precision FP semantics. If `check=true`, emits appropriate diagnostics on illegal arguments. """ function Attribute( - f::T; context::Context=context(), location::Location=Location(), check::Bool=false + f::T; + context::Context=current_context(), + location::Location=Location(), + check::Bool=false, ) where {T<:AbstractFloat} if check Attribute(API.mlirFloatAttrDoubleGetChecked(location, Type(T), Float64(f))) @@ -133,6 +142,32 @@ function Base.Float64(attr::Attribute) return API.mlirFloatAttrGetValueDouble(attr) end +""" + Attribute(complex; context=current_context(), location=Location(), check=false) + +Creates a complex attribute in the given context with the given complex value and double-precision FP semantics. +""" +function Attribute( + c::T; + context::Context=current_context(), + location::Location=Location(), + check::Bool=false, +) where {T<:Complex} + if check + Attribute( + API.mlirComplexAttrDoubleGetChecked( + location, Type(T), Float64(real(c)), Float64(imag(c)) + ), + ) + else + Attribute( + API.mlirComplexAttrDoubleGet( + context, Type(T), Float64(real(c)), Float64(imag(c)) + ), + ) + end +end + """ isinteger(attr) @@ -145,8 +180,9 @@ isinteger(attr::Attribute) = API.mlirAttributeIsAInteger(attr) Creates an integer attribute of the given type with the given integer value. """ -Attribute(i::T, type=Type(T)) where {T<:Integer} = - Attribute(API.mlirIntegerAttrGet(type, Int64(i))) +function Attribute(i::T, type=Type(T)) where {T<:Integer} + return Attribute(API.mlirIntegerAttrGet(type, Int64(i))) +end """ Int64(attr) @@ -158,7 +194,7 @@ function Base.Int64(attr::Attribute) return API.mlirIntegerAttrGetValueInt(attr) end -# TODO mlirIntegerAttrGetValueSInt +# TODO(#2244) mlirIntegerAttrGetValueSInt """ UInt64(attr) @@ -178,11 +214,13 @@ Checks whether the given attribute is a bool attribute. isbool(attr::Attribute) = API.mlirAttributeIsABool(attr) """ - Attribute(value; context=context()) + Attribute(value; context=current_context()) Creates a bool attribute in the given context with the given value. """ -Attribute(b::Bool; context::Context=context()) = Attribute(API.mlirBoolAttrGet(context, b)) +function Attribute(b::Bool; context::Context=current_context()) + return Attribute(API.mlirBoolAttrGet(context, b)) +end """ Bool(attr) @@ -209,13 +247,14 @@ Checks whether the given attribute is an opaque attribute. isopaque(attr::Attribute) = API.mlirAttributeIsAOpaque(attr) """ - OpaqueAttribute(dialectNamespace, dataLength, data, type; context=context()) + OpaqueAttribute(dialectNamespace, dataLength, data, type; context=current_context()) Creates an opaque attribute in the given context associated with the dialect identified by its namespace. The attribute contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueAttribute(namespace, data, type; context::Context=context) = - Attribute(API.mlirOpaqueAttrGet(context, namespace, length(data), data, type)) +function OpaqueAttribute(namespace, data, type; context::Context=context) + return Attribute(API.mlirOpaqueAttrGet(context, namespace, length(data), data, type)) +end """ mlirOpaqueAttrGetDialectNamespace(attr) @@ -234,7 +273,7 @@ Returns the raw data as a string reference. The data remains live as long as the """ function data(attr::Attribute) @assert isopaque(attr) "attribute $(attr) is not an opaque attribute" - return String(API.mlirOpaqueAttrGetData(attr)) # TODO return as Base.CodeUnits{Int8,String}? or as a Vector{Int8}? or Pointer? + return String(API.mlirOpaqueAttrGetData(attr)) # TODO(#2244) return as Base.CodeUnits{Int8,String}? or as a Vector{Int8}? or Pointer? end """ @@ -245,12 +284,13 @@ Checks whether the given attribute is a string attribute. isstring(attr::Attribute) = API.mlirAttributeIsAString(attr) """ - Attribute(str; context=context()) + Attribute(str; context=current_context()) Creates a string attribute in the given context containing the given string. """ -Attribute(str::AbstractString; context::Context=context()) = - Attribute(API.mlirStringAttrGet(context, str)) +function Attribute(str::AbstractString; context::Context=current_context()) + return Attribute(API.mlirStringAttrGet(context, str)) +end """ Attribute(type, str) @@ -279,16 +319,18 @@ Checks whether the given attribute is a symbol reference attribute. issymbolref(attr::Attribute) = API.mlirAttributeIsASymbolRef(attr) """ - SymbolRefAttribute(symbol, references; context=context()) + SymbolRefAttribute(symbol, references; context=current_context()) Creates a symbol reference attribute in the given context referencing a symbol identified by the given string inside a list of nested references. Each of the references in the list must not be nested. """ -SymbolRefAttribute( - symbol::String, references::Vector{Attribute}; context::Context=context() -) = Attribute( - API.mlirSymbolRefAttrGet(context, symbol, length(references), pointer(references)) +function SymbolRefAttribute( + symbol::String, references::Vector{Attribute}; context::Context=current_context() ) + return Attribute( + API.mlirSymbolRefAttrGet(context, symbol, length(references), references) + ) +end """ rootref(attr) @@ -332,8 +374,9 @@ isflatsymbolref(attr::Attribute) = API.mlirAttributeIsAFlatSymbolRef(attr) Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. """ -FlatSymbolRefAttribute(symbol::String; context::Context=context()) = - Attribute(API.mlirFlatSymbolRefAttrGet(context, symbol)) +function FlatSymbolRefAttribute(symbol::String; context::Context=current_context()) + return Attribute(API.mlirFlatSymbolRefAttrGet(context, symbol)) +end """ flatsymbol(attr) @@ -374,11 +417,13 @@ Checks whether the given attribute is a unit attribute. isunit(attr::Attribute) = API.mlirAttributeIsAUnit(attr) """ - UnitAttribute(; context=context()) + UnitAttribute(; context=current_context()) Creates a unit attribute in the given context. """ -UnitAttribute(; context::Context=context()) = Attribute(API.mlirUnitAttrGet(context)) +function UnitAttribute(; context::Context=current_context()) + return Attribute(API.mlirUnitAttrGet(context)) +end """ iselements(attr) @@ -387,8 +432,8 @@ Checks whether the given attribute is an elements attribute. """ iselements(attr::Attribute) = API.mlirAttributeIsAElements(attr) -# TODO mlirElementsAttrGetValue -# TODO mlirElementsAttrIsValidIndex +# TODO(#2244) mlirElementsAttrGetValue +# TODO(#2244) mlirElementsAttrIsValidIndex """ isdenseelements(attr) @@ -406,12 +451,10 @@ Creates a dense elements attribute with the given Shaped type and elements in th """ function DenseElementsAttribute(shaped_type::Type, elements::AbstractArray) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute( - API.mlirDenseElementsAttrGet(shaped_type, length(elements), pointer(elements)) - ) + return Attribute(API.mlirDenseElementsAttrGet(shaped_type, length(elements), elements)) end -# TODO mlirDenseElementsAttrRawBufferGet +# TODO(#2244) mlirDenseElementsAttrRawBufferGet """ fill(attr, shapedType) @@ -420,147 +463,181 @@ Creates a dense elements attribute with the given Shaped type containing a singl """ function Base.fill(attr::Attribute, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return Attribute(API.mlirDenseElementsAttrSplatGet(attr, shaped_type)) + return Attribute(API.mlirDenseElementsAttrSplatGet(shaped_type, attr)) end function Base.fill(value::Bool, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrBoolSplatGet(value, shaped_type) + return API.mlirDenseElementsAttrBoolSplatGet(shaped_type, value) end function Base.fill(value::UInt8, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrUInt8SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrUInt8SplatGet(shaped_type, value) end function Base.fill(value::Int8, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrInt8SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrInt8SplatGet(shaped_type, value) end function Base.fill(value::UInt32, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrUInt32SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrUInt32SplatGet(shaped_type, value) end function Base.fill(value::Int32, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrInt32SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrInt32SplatGet(shaped_type, value) end function Base.fill(value::UInt64, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrUInt64SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrUInt64SplatGet(shaped_type, value) end function Base.fill(value::Int64, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrInt64SplatGet(value, shaped_type) + return API.mlirDenseElementsAttrInt64SplatGet(shaped_type, value) end function Base.fill(value::Float32, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrFloatSplatGet(value, shaped_type) + return API.mlirDenseElementsAttrFloatSplatGet(shaped_type, value) end function Base.fill(value::Float64, shaped_type::Type) @assert isshaped(shaped_type) "type $(shaped_type) is not a shaped type" - return API.mlirDenseElementsAttrDoubleSplatGet(value, shaped_type) + return API.mlirDenseElementsAttrDoubleSplatGet(shaped_type, value) end -function Base.fill(::Core.Type{Attribute}, value, shape) +function Base.fill(::Core.Type{Attribute}, value, shape::Vector{Int}) shaped_type = TensorType(shape, Type(typeof(value))) return Base.fill(value, shaped_type) end +to_row_major(x) = permutedims(x, ndims(x):-1:1) +to_row_major(x::AbstractVector) = x +to_row_major(x::AbstractArray{T,0}) where {T} = x + """ DenseElementsAttribute(array::AbstractArray) Creates a dense elements attribute with the given shaped type from elements of a specific type. Expects the element type of the shaped type to match the data element type. """ -function DenseElementsAttribute(values::AbstractVector{Bool}) - shaped_type = TensorType(size(values), Type(Bool)) +function DenseElementsAttribute(values::AbstractArray{Bool}) + shaped_type = TensorType(collect(Int, size(values)), Type(Bool)) return Attribute( - API.mlirDenseElementsAttrBoolGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrBoolGet( + shaped_type, length(values), AbstractArray{Cint}(to_row_major(values)) + ), ) end function DenseElementsAttribute(values::AbstractArray{UInt8}) - shaped_type = TensorType(size(values), Type(UInt8)) + shaped_type = TensorType(collect(Int, size(values)), Type(UInt8)) return Attribute( - API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt8Get(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{Int8}) - shaped_type = TensorType(size(values), Type(Int8)) + shaped_type = TensorType(collect(Int, size(values)), Type(Int8)) return Attribute( - API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrInt8Get(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{UInt16}) - shaped_type = TensorType(size(values), Type(UInt16)) + shaped_type = TensorType(collect(Int, size(values)), Type(UInt16)) return Attribute( - API.mlirDenseElementsAttrUInt16Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt16Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int16}) - shaped_type = TensorType(size(values), Type(Int16)) + shaped_type = TensorType(collect(Int, size(values)), Type(Int16)) return Attribute( - API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrInt16Get(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{UInt32}) - shaped_type = TensorType(size(values), Type(UInt32)) + shaped_type = TensorType(collect(Int, size(values)), Type(UInt32)) return Attribute( - API.mlirDenseElementsAttrUInt32Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt32Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int32}) - shaped_type = TensorType(size(values), Type(Int32)) + shaped_type = TensorType(collect(Int, size(values)), Type(Int32)) return Attribute( - API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrInt32Get(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{UInt64}) - shaped_type = TensorType(size(values), Type(UInt64)) + shaped_type = TensorType(collect(Int, size(values)), Type(UInt64)) return Attribute( - API.mlirDenseElementsAttrUInt64Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrUInt64Get( + shaped_type, length(values), to_row_major(values) + ), ) end function DenseElementsAttribute(values::AbstractArray{Int64}) - shaped_type = TensorType(size(values), Type(Int64)) + shaped_type = TensorType(collect(Int, size(values)), Type(Int64)) return Attribute( - API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrInt64Get(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{Float32}) - shaped_type = TensorType(size(values), Type(Float32)) + shaped_type = TensorType(collect(Int, size(values)), Type(Float32)) return Attribute( - API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrFloatGet(shaped_type, length(values), to_row_major(values)) ) end function DenseElementsAttribute(values::AbstractArray{Float64}) - shaped_type = TensorType(size(values), Type(Float64)) + shaped_type = TensorType(collect(Int, size(values)), Type(Float64)) return Attribute( - API.mlirDenseElementsAttrDoubleGet(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrDoubleGet( + shaped_type, length(values), to_row_major(values) + ), ) end -# TODO mlirDenseElementsAttrBFloat16Get +if isdefined(Core, :BFloat16) + function DenseElementsAttribute(values::AbstractArray{Core.BFloat16}) + shaped_type = TensorType(collect(Int, size(values)), Type(Core.BFloat16)) + return Attribute( + API.mlirDenseElementsAttrBFloat16Get( + shaped_type, length(values), to_row_major(values) + ), + ) + end +end function DenseElementsAttribute(values::AbstractArray{Float16}) - shaped_type = TensorType(size(values), Type(Float16)) + shaped_type = TensorType(collect(Int, size(values)), Type(Float16)) return Attribute( - API.mlirDenseElementsAttrFloat16Get(shaped_type, length(values), pointer(values)) + API.mlirDenseElementsAttrFloat16Get( + shaped_type, length(values), to_row_major(values) + ), + ) +end + +function DenseElementsAttribute(values::AbstractArray) + shaped_type = TensorType(collect(Int, size(values)), Type(eltype(values))) + return Attribute( + API.mlirDenseElementsAttrRawBufferGet( + shaped_type, length(values) * Base.elsize(values), to_row_major(values) + ), ) end @@ -570,22 +647,15 @@ end Creates a dense elements attribute with the given shaped type from string elements. """ function DenseElementsAttribute(values::AbstractArray{String}) - # TODO may fail because `Type(String)` is not defined - shaped_type = TensorType(size(values), Type(String)) - return Attribute( - API.mlirDenseElementsAttrStringGet(shaped_type, length(values), pointer(values)) - ) -end - -function Attribute(values::AbstractArray) - MLIR_VERSION[] >= v"15" || throw( - MLIRException("`Attribute(::AbstractArray)` requires MLIR version 15 or later") - ) - if MLIR_VERSION[] < v"16" - DenseElementsAttribute(values) - else - DenseArrayAttribute(values) - end + # Builtin dialect doesn't have string type. If we want to support this, + # we need to add this in our dialect. + throw(MethodError(DenseElementsAttribute, (values,))) + # shaped_type = TensorType(collect(Int, size(values)), Type(String)) + # return Attribute( + # API.mlirDenseElementsAttrStringGet( + # shaped_type, length(values), to_row_major(values) + # ), + # ) end """ @@ -593,12 +663,12 @@ end Creates a dense elements attribute that has the same data as the given dense elements attribute and a different shaped type. The new type must have the same total number of elements. """ -function Base.reshape(attr::Attribute, shape) +function Base.reshape(attr::Attribute, shape::Vector{Int}) @assert isdenseelements(attr) "attribute $(attr) is not a dense elements attribute" @assert length(attr) == prod(shape) "new shape $(shape) has a different number of elements than the original attribute" element_type = eltype(type(attr)) shaped_type = TensorType(shape, element_type) - return Attribute(API.mlirDenseElementsAttrReshape(attr, shaped_type)) + return Attribute(API.mlirDenseElementsAttrReshapeGet(attr, shaped_type)) end """ @@ -608,22 +678,10 @@ Checks whether the given dense elements attribute contains a single replicated v """ function issplat(attr::Attribute) @assert isdenseelements(attr) "attribute $(attr) is not a dense elements attribute" - return API.mlirDenseElementsAttrIsSplat(attr) # TODO Base.allequal? + return API.mlirDenseElementsAttrIsSplat(attr) # TODO(#2244) Base.allequal? end -# TODO mlirDenseElementsAttrGetRawData - -""" - isopaqueelements(attr) - -Checks whether the given attribute is an opaque elements attribute. -""" -function isopaqueelements(attr::Attribute) - MLIR_VERSION[] >= v"15" || throw( - MLIRException("`isopaqueelements(::Attribute)` requires MLIR version 15 or later"), - ) - return API.mlirAttributeIsAOpaqueElements(attr) -end +# TODO(#2244) mlirDenseElementsAttrGetRawData """ issparseelements(attr) @@ -632,150 +690,100 @@ Checks whether the given attribute is a sparse elements attribute. """ issparseelements(attr::Attribute) = API.mlirAttributeIsASparseElements(attr) -# TODO mlirSparseElementsAttribute -# TODO mlirSparseElementsAttrGetIndices -# TODO mlirSparseElementsAttrGetValues +# TODO(#2244) mlirSparseElementsAttribute +# TODO(#2244) mlirSparseElementsAttrGetIndices +# TODO(#2244) mlirSparseElementsAttrGetValues """ - isdensearray(attr, ::Core.Type{T}) + isdensearray(attr, ::Core.Type{T}) -Checks whether the given attribute is a dense array attribute. -""" + Checks whether the given attribute is a dense array attribute. + """ function isdensearray end -function isdensearray(attr::Attribute, ::Core.Type{Bool}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Bool})` requires MLIR version 16 or later", - ), - ) - return API.mlirAttributeIsADenseBoolArray(attr) -end - -function isdensearray(attr::Attribute, ::Core.Type{Int8}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Int8})` requires MLIR version 16 or later", - ), - ) - return API.mlirAttributeIsADenseI8Array(attr) -end - -function isdensearray(attr::Attribute, ::Core.Type{Int16}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Int16})` requires MLIR version 16 or later", - ), - ) - return API.mlirAttributeIsADenseI16Array(attr) -end - -function isdensearray(attr::Attribute, ::Core.Type{Int32}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Int32})` requires MLIR version 16 or later", - ), - ) - return API.mlirAttributeIsADenseI32Array(attr) -end - -function isdensearray(attr::Attribute, ::Core.Type{Int64}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Int64})` requires MLIR version 16 or later", - ), - ) - return API.mlirAttributeIsADenseI64Array(attr) -end +isdensearray(attr::Attribute, ::Core.Type{Bool}) = API.mlirAttributeIsADenseBoolArray(attr) +isdensearray(attr::Attribute, ::Core.Type{Int8}) = API.mlirAttributeIsADenseI8Array(attr) +isdensearray(attr::Attribute, ::Core.Type{Int16}) = API.mlirAttributeIsADenseI16Array(attr) +isdensearray(attr::Attribute, ::Core.Type{Int32}) = API.mlirAttributeIsADenseI32Array(attr) +isdensearray(attr::Attribute, ::Core.Type{Int64}) = API.mlirAttributeIsADenseI64Array(attr) function isdensearray(attr::Attribute, ::Core.Type{Float32}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Float32})` requires MLIR version 16 or later", - ), - ) return API.mlirAttributeIsADenseF32Array(attr) end function isdensearray(attr::Attribute, ::Core.Type{Float64}) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`isdensearray(::Attribute, ::Core.Type{Float64})` requires MLIR version 16 or later", - ), - ) return API.mlirAttributeIsADenseF64Array(attr) end """ - DenseArrayAttribute(array; context=context()) + DenseArrayAttribute(array; context=current_context()) -Create a dense array attribute with the given elements. -""" + Create a dense array attribute with the given elements. + """ function DenseArrayAttribute end -function DenseArrayAttribute(values::AbstractArray{Bool}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Bool})` requires MLIR version 16 or later", +function DenseArrayAttribute( + values::AbstractArray{Bool}; context::Context=current_context() +) + return Attribute( + API.mlirDenseBoolArrayGet( + context, length(values), AbstractArray{Cint}(to_row_major(values)) ), ) - return Attribute(API.mlirDenseBoolArrayGet(context, length(values), pointer(values))) end -function DenseArrayAttribute(values::AbstractArray{Int8}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Int8})` requires MLIR version 16 or later", - ), - ) - return Attribute(API.mlirDenseI8ArrayGet(context, length(values), pointer(values))) +function DenseArrayAttribute( + values::AbstractArray{Int8}; context::Context=current_context() +) + return Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) end -function DenseArrayAttribute(values::AbstractArray{Int16}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Int16})` requires MLIR version 16 or later", - ), +# function DenseArrayAttribute(values::AbstractArray{UInt8}; context::Context=current_context()) +# return Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) +# end + +function DenseArrayAttribute( + values::AbstractArray{Int16}; context::Context=current_context() +) + return Attribute( + API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values)) ) - return Attribute(API.mlirDenseI16ArrayGet(context, length(values), pointer(values))) end -function DenseArrayAttribute(values::AbstractArray{Int32}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Int32})` requires MLIR version 16 or later", - ), +function DenseArrayAttribute( + values::AbstractArray{Int32}; context::Context=current_context() +) + return Attribute( + API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values)) ) - return Attribute(API.mlirDenseI32ArrayGet(context, length(values), pointer(values))) end -function DenseArrayAttribute(values::AbstractArray{Int64}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Int64})` requires MLIR version 16 or later", - ), +function DenseArrayAttribute( + values::AbstractArray{Int64}; context::Context=current_context() +) + return Attribute( + API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values)) ) - return Attribute(API.mlirDenseI64ArrayGet(context, length(values), pointer(values))) end -function DenseArrayAttribute(values::AbstractArray{Float32}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Float32})` requires MLIR version 16 or later", - ), +function DenseArrayAttribute( + values::AbstractArray{Float32}; context::Context=current_context() +) + return Attribute( + API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values)) ) - return Attribute(API.mlirDenseF32ArrayGet(context, length(values), pointer(values))) end -function DenseArrayAttribute(values::AbstractArray{Float64}; context::Context=context()) - MLIR_VERSION[] >= v"16" || throw( - MLIRException( - "`DenseArrayAttribute(::AbstractArray{Float64})` requires MLIR version 16 or later", - ), +function DenseArrayAttribute( + values::AbstractArray{Float64}; context::Context=current_context() +) + return Attribute( + API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values)) ) - return Attribute(API.mlirDenseF64ArrayGet(context, length(values), pointer(values))) end +Attribute(values::AbstractArray) = DenseArrayAttribute(values) + function Base.length(attr::Attribute) if isarray(attr) API.mlirArrayAttrGetNumElements(attr) @@ -783,12 +791,12 @@ function Base.length(attr::Attribute) API.mlirDictionaryAttrGetNumElements(attr) elseif iselements(attr) API.mlirElementsAttrGetNumElements(attr) - elseif MLIR_VERSION[] >= v"16" + else _isdensearray = any( T -> isdensearray(attr, T), [Bool, Int8, Int16, Int32, Int64, Float32, Float64] ) if _isdensearray - API.mlirDenseBoolArrayGetNumElements(attr) + API.mlirDenseArrayGetNumElements(attr) end end end @@ -826,12 +834,12 @@ function Base.getindex(attr::Attribute, i) API.mlirDenseElementsAttrGetFloatValue(attr, i) elseif elem_type isa Float64 API.mlirDenseElementsAttrGetDoubleValue(attr, i) - elseif elem_type isa String # TODO does this case work? + elseif elem_type isa String # TODO(#2244) does this case work? String(API.mlirDenseElementsAttrGetStringValue(attr, i)) else throw("unsupported element type $(elem_type)") end - elseif MLIR_VERSION[] >= v"16" + else if isdensearray(attr, Bool) API.mlirDenseBoolArrayGetElement(attr, i) elseif isdensearray(attr, Int8) @@ -850,6 +858,14 @@ function Base.getindex(attr::Attribute, i) end end +function Base.iterate(attr::Attribute, state=1) + if state > length(attr) + nothing + else + (attr[state], state + 1) + end +end + function Base.getindex(attr::Attribute) @assert isdenseelements(attr) "attribute $(attr) is not a dense elements attribute" @assert issplat(attr) "attribute $(attr) is not splatted (more than one different elements)" @@ -876,7 +892,7 @@ function Base.getindex(attr::Attribute) API.mlirDenseElementsAttrGetFloatSplatValue(attr) elseif elem_type isa Float64 API.mlirDenseElementsAttrGetDoubleSplatValue(attr) - elseif elem_type isa String # TODO does this case work? + elseif elem_type isa String # TODO(#2244) does this case work? String(API.mlirDenseElementsAttrGetStringSplatValue(attr)) else throw("unsupported element type $(elem_type)") @@ -891,21 +907,29 @@ function Base.show(io::IO, attribute::Attribute) return print(io, " =#)") end -struct NamedAttribute - named_attribute::API.MlirNamedAttribute -end - """ NamedAttribute(name, attr) Associates an attribute with the name. Takes ownership of neither. """ -function NamedAttribute(name, attribute; context=context(attribute)) - @assert !mlirIsNull(attribute.attribute) - name = API.mlirIdentifierGet(context, name) - return NamedAttribute(API.mlirNamedAttributeGet(name, attribute)) +struct NamedAttribute + ref::API.MlirNamedAttribute +end + +function NamedAttribute(name, attribute; kwargs...) + attr = Attribute(attribute; kwargs...) + return NamedAttribute(name, attr; kwargs...) end -function Base.convert(::Core.Type{API.MlirAttribute}, named_attribute::NamedAttribute) - return named_attribute.named_attribute +function NamedAttribute(name, attr::Attribute; context=context(attr)) + nameid = Identifier(name; context) + return NamedAttribute(nameid, attr; context) end + +function NamedAttribute(name::Identifier, attr::Attribute; context=context(attr)) + refcheck(attr.ref) + return NamedAttribute(API.mlirNamedAttributeGet(name, attr)) +end + +Base.cconvert(::Core.Type{API.MlirAttribute}, attr::NamedAttribute) = attr +Base.unsafe_convert(::Core.Type{API.MlirAttribute}, attr::NamedAttribute) = attr.ref diff --git a/src/IR/Block.jl b/src/IR/Block.jl index 2fcdadef..7ca82b47 100644 --- a/src/IR/Block.jl +++ b/src/IR/Block.jl @@ -1,18 +1,13 @@ -mutable struct Block - block::API.MlirBlock - @atomic owned::Bool - - function Block(block::API.MlirBlock, owned::Bool=true) - @assert !mlirIsNull(block) "cannot create Block with null MlirBlock" - finalizer(new(block, owned)) do block - if block.owned - API.mlirBlockDestroy(block.block) - end - end - end +""" + Block + +A `Block` is a sequence of [`Operation`](@ref)s with a list of arguments. +""" +@checked struct Block + ref::API.MlirBlock end -Block() = Block(Type[], Location[]) +Block() = Block(mark_alloc(API.mlirBlockCreate(0, C_NULL, C_NULL))) """ Block(args, locs) @@ -21,17 +16,61 @@ Creates a new empty block with the given argument types and transfers ownership """ function Block(args::Vector{Type}, locs::Vector{Location}) @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" - return Block(API.mlirBlockCreate(length(args), args, locs)) + return Block(mark_alloc(API.mlirBlockCreate(length(args), args, locs))) end +""" + dispose(blk::Block) + +Disposes the given block and releases its resources. +After calling this function, the block must not be used anymore. +""" +dispose(blk::Block) = mark_dispose(API.mlirBlockDestroy, blk) + """ ==(block, other) Checks whether two blocks handles point to the same block. This does not perform deep comparison. """ Base.:(==)(a::Block, b::Block) = API.mlirBlockEqual(a, b) + Base.cconvert(::Core.Type{API.MlirBlock}, block::Block) = block -Base.unsafe_convert(::Core.Type{API.MlirBlock}, block::Block) = block.block +Base.unsafe_convert(::Core.Type{API.MlirBlock}, block::Block) = mark_use(block).ref + +Base.IteratorSize(::Core.Type{Block}) = Base.SizeUnknown() +Base.IteratorEltype(::Core.Type{Block}) = Base.HasEltype() +Base.eltype(::Block) = Operation + +""" + Base.iterate(block::Block) + +Iterates over all operations for the given block. +""" +function Base.iterate(it::Block) + raw_op = API.mlirBlockGetFirstOperation(it) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op) + (op, op) + end +end + +function Base.iterate(::Block, op) + raw_op = API.mlirOperationGetNextInBlock(op) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op) + (op, op) + end +end + +function Base.show(io::IO, block::Block) + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) + ref = Ref(io) + return API.mlirBlockPrint(block, c_print_callback, ref) +end """ parent_op(block) @@ -82,8 +121,22 @@ end Appends an argument of the specified type to the block. Returns the newly added argument. """ -push_argument!(block::Block, type; location::Location=Location()) = - Value(API.mlirBlockAddArgument(block, type, location)) +function push_argument!(block::Block, type; location::Location=Location()) + return Value(API.mlirBlockAddArgument(block, type, location)) +end + +""" + erase_argument!(block, i) + +Erase argument `i` of the block. Returns the block. +""" +function erase_argument!(block, i) + if i ∉ 1:nargs(block) + throw(BoundsError(block, i)) + end + API.mlirBlockEraseArgument(block, i - 1) + return block +end """ first_op(block) @@ -93,7 +146,7 @@ Returns the first operation in the block or `nothing` if empty. function first_op(block::Block) op = API.mlirBlockGetFirstOperation(block) mlirIsNull(op) && return nothing - return Operation(op, false) + return Operation(op) end Base.first(block::Block) = first_op(block) @@ -105,7 +158,7 @@ Returns the terminator operation in the block or `nothing` if no terminator. function terminator(block::Block) op = API.mlirBlockGetTerminator(block) mlirIsNull(op) && return nothing - return Operation(op, false) + return Operation(op) end """ @@ -114,7 +167,7 @@ end Takes an operation owned by the caller and appends it to the block. """ function Base.push!(block::Block, op::Operation) - API.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) + API.mlirBlockAppendOwnedOperation(block, mark_donate(op)) return op end @@ -125,7 +178,7 @@ Takes an operation owned by the caller and inserts it as `index` to the block. This is an expensive operation that scans the block linearly, prefer insertBefore/After instead. """ function Base.insert!(block::Block, index, op::Operation) - API.mlirBlockInsertOwnedOperation(block, index - 1, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperation(block, index - 1, mark_donate(op)) return op end @@ -140,7 +193,7 @@ end Takes an operation owned by the caller and inserts it after the (non-owned) reference operation in the given block. If the reference is null, prepends the operation. Otherwise, the reference must belong to the block. """ function insert_after!(block::Block, reference::Operation, op::Operation) - API.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperationAfter(block, reference, mark_donate(op)) return op end @@ -150,19 +203,34 @@ end Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in the given block. If the reference is null, appends the operation. Otherwise, the reference must belong to the block. """ function insert_before!(block::Block, reference::Operation, op::Operation) - API.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) + API.mlirBlockInsertOwnedOperationBefore(block, reference, mark_donate(op)) return op end -function lose_ownership!(block::Block) - @assert block.owned - # API.mlirBlockDetach(block) - @atomic block.owned = false - return block +# to simplify the API, we maintain a stack of contexts in task local storage +# and pass them implicitly to MLIR API's that require them. +function activate(blk::Block) + stack = get!(task_local_storage(), :mlir_block) do + return Block[] + end::Vector{Block} + Base.push!(stack, blk) + return nothing end -function Base.show(io::IO, block::Block) - c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - ref = Ref(io) - return API.mlirBlockPrint(block, c_print_callback, ref) +function deactivate(blk::Block) + current_block() == blk || error("Deactivating wrong block") + return Base.pop!(task_local_storage(:mlir_block)) +end + +function has_block() + return haskey(task_local_storage(), :mlir_block) && + !Base.isempty(task_local_storage(:mlir_block)::Vector{Block}) +end + +function current_block(; throw_error::Core.Bool=true) + if !has_block() + throw_error && error("No MLIR block is active") + return nothing + end + return last(task_local_storage(:mlir_block)::Vector{Block}) end diff --git a/src/IR/Context.jl b/src/IR/Context.jl index eadecdb3..b701a379 100644 --- a/src/IR/Context.jl +++ b/src/IR/Context.jl @@ -1,82 +1,63 @@ -struct Context - context::API.MlirContext - - function Context(context) - @assert !mlirIsNull(context) "cannot create Context with null MlirContext" - return new(context) - end +@checked struct Context + ref::API.MlirContext end """ - Context() + Context(registry=DialectRegistry(); threading = false) -Creates an MLIR context and transfers its ownership to the caller. +Creates an MLIR context. """ -function Context() - context = API.mlirContextCreate() - context = Context(context) - activate!(context) - return context +function Context(registry=DialectRegistry(); threading::Bool=false) + return Context(mark_alloc(API.mlirContextCreateWithRegistry(registry, threading))) end -function Context(f::Core.Function) - ctx = Context() - try - f(ctx) - finally - dispose!(ctx) - end +""" + dispose(ctx::Context) + +Disposes the given context and releases its resources. +After calling this function, the context must not be used anymore. +""" +function dispose(ctx::Context) + # deactivate(ctx) + return mark_dispose(API.mlirContextDestroy, ctx) end -Base.convert(::Core.Type{API.MlirContext}, c::Context) = c.context +Base.cconvert(::Core.Type{API.MlirContext}, c::Context) = c +Base.unsafe_convert(::Core.Type{API.MlirContext}, c::Context) = mark_use(c).ref + +Base.:(==)(a::Context, b::Context) = API.mlirContextEqual(a, b) + +function enable_multithreading!(enable::Bool=true; context::Context=current_context()) + API.mlirContextEnableMultithreading(context, enable) + return context +end # Global state # to simplify the API, we maintain a stack of contexts in task local storage # and pass them implicitly to MLIR API's that require them. -function activate!(ctx::Context) +function activate(ctx::Context) stack = get!(task_local_storage(), :mlir_context_stack) do - Context[] - end + return Context[] + end::Vector{Context} Base.push!(stack, ctx) return nothing end -function deactivate!(ctx::Context) - context() == ctx || error("Deactivating wrong context") - return Base.pop!(task_local_storage(:mlir_context_stack)) -end - -function dispose!(ctx::Context) - deactivate!(ctx) - return API.mlirContextDestroy(ctx.context) +function deactivate(ctx::Context) + current_context() == ctx || error("Deactivating wrong context") + return Base.pop!(task_local_storage(:mlir_context_stack)::Vector{Context}) end -function _has_context() +function has_context() return haskey(task_local_storage(), :mlir_context_stack) && - !Base.isempty(task_local_storage(:mlir_context_stack)) + !Base.isempty(task_local_storage(:mlir_context_stack)::Vector{Context}) end -function context(; throw_error::Core.Bool=true) - if !_has_context() +function current_context(; throw_error::Core.Bool=true) + if !has_context() throw_error && error("No MLIR context is active") return nothing end - return last(task_local_storage(:mlir_context_stack)) -end - -function context!(f, ctx::Context) - activate!(ctx) - try - f() - finally - deactivate!(ctx) - end + return last(task_local_storage(:mlir_context_stack)::Vector{Context}) end - -function enable_multithreading!(enable::Bool=true; context::Context=context()) - API.mlirContextEnableMultithreading(context, enable) - return context -end - -Base.:(==)(a::Context, b::Context) = API.mlirContextEqual(a, b) diff --git a/src/IR/Dialect.jl b/src/IR/Dialect.jl index 7015661e..c961eaf6 100644 --- a/src/IR/Dialect.jl +++ b/src/IR/Dialect.jl @@ -1,13 +1,10 @@ -struct Dialect - dialect::API.MlirDialect - - function Dialect(dialect) - @assert !mlirIsNull(dialect) "cannot create Dialect from null MlirDialect" - return new(dialect) - end +@checked struct Dialect + ref::API.MlirDialect end -Base.convert(::Core.Type{API.MlirDialect}, dialect::Dialect) = dialect.dialect +Base.cconvert(::Core.Type{API.MlirDialect}, dialect::Dialect) = dialect +Base.unsafe_convert(::Core.Type{API.MlirDialect}, dialect::Dialect) = dialect.ref + Base.:(==)(a::Dialect, b::Dialect) = API.mlirDialectEqual(a, b) context(dialect::Dialect) = Context(API.mlirDialectGetContext(dialect)) @@ -17,32 +14,32 @@ function Base.show(io::IO, dialect::Dialect) return print(io, "Dialect(\"", namespace(dialect), "\")") end -function allow_unregistered_dialects(; context::Context=context()) +function allow_unregistered_dialects(; context::Context=current_context()) return API.mlirContextGetAllowUnregisteredDialects(context) end -function allow_unregistered_dialects!(allow::Bool=true; context::Context=context()) +function allow_unregistered_dialects!(allow::Bool=true; context::Context=current_context()) return API.mlirContextSetAllowUnregisteredDialects(context, allow) end -function num_registered_dialects(; context::Context=context()) +function num_registered_dialects(; context::Context=current_context()) return API.mlirContextGetNumRegisteredDialects(context) end -function num_loaded_dialects(; context::Context=context()) +function num_loaded_dialects(; context::Context=current_context()) return API.mlirContextGetNumLoadedDialects(context) end -function load_all_available_dialects(; context::Context=context()) +function load_all_available_dialects(; context::Context=current_context()) return API.mlirContextLoadAllAvailableDialects(context) end -function get_or_load_dialect!(name::String; context::Context=context()) +function get_or_load_dialect!(name::String; context::Context=current_context()) dialect = API.mlirContextGetOrLoadDialect(context, name) mlirIsNull(dialect) && error("could not load dialect $name") return Dialect(dialect) end -struct DialectHandle - handle::API.MlirDialectHandle +@checked struct DialectHandle + ref::API.MlirDialectHandle end function DialectHandle(s::Symbol) @@ -50,44 +47,44 @@ function DialectHandle(s::Symbol) return DialectHandle(getproperty(API, s)()) end -Base.convert(::Core.Type{API.MlirDialectHandle}, handle::DialectHandle) = handle.handle +Base.cconvert(::Core.Type{API.MlirDialectHandle}, handle::DialectHandle) = handle +Base.unsafe_convert(::Core.Type{API.MlirDialectHandle}, handle::DialectHandle) = handle.ref namespace(handle::DialectHandle) = String(API.mlirDialectHandleGetNamespace(handle)) -function get_or_load_dialect!(handle::DialectHandle; context::Context=context()) +function get_or_load_dialect!(handle::DialectHandle; context::Context=current_context()) dialect = API.mlirDialectHandleLoadDialect(handle, context) mlirIsNull(dialect) && error("could not load dialect from handle $handle") return Dialect(dialect) end -function register_dialect!(handle::DialectHandle; context::Context=context()) +function register_dialect!(handle::DialectHandle; context::Context=current_context()) return API.mlirDialectHandleRegisterDialect(handle, context) end -function load_dialect!(handle::DialectHandle; context::Context=context()) +function load_dialect!(handle::DialectHandle; context::Context=current_context()) return Dialect(API.mlirDialectHandleLoadDialect(handle, context)) end -mutable struct DialectRegistry - registry::API.MlirDialectRegistry - - function DialectRegistry(registry) - @assert !mlirIsNull(registry) "cannot create DialectRegistry with null MlirDialectRegistry" - finalizer(DialectRegistry(registry)) do registry - API.mlirDialectRegistryDestroy(registry.registry) - end - end +@checked struct DialectRegistry + ref::API.MlirDialectRegistry end -DialectRegistry() = DialectRegistry(API.mlirDialectRegistryCreate()) +DialectRegistry() = DialectRegistry(mark_alloc(API.mlirDialectRegistryCreate())) -function Base.convert(::Core.Type{API.MlirDialectRegistry}, registry::DialectRegistry) - return registry.registry +dispose(registry::DialectRegistry) = API.mlirDialectRegistryDestroy(registry.ref) + +Base.cconvert(::Core.Type{API.MlirDialectRegistry}, registry::DialectRegistry) = registry +function Base.unsafe_convert( + ::Core.Type{API.MlirDialectRegistry}, registry::DialectRegistry +) + return mark_use(registry).ref end + function Base.push!(registry::DialectRegistry, handle::DialectHandle) return API.mlirDialectHandleInsertDialect(handle, registry) end -# TODO is `append!` the right name? +# TODO(#2245) is `append!` the right name? function Base.append!(registry::DialectRegistry; context::Context) return API.mlirContextAppendDialectRegistry(context, registry) end diff --git a/src/IR/ExecutionEngine.jl b/src/IR/ExecutionEngine.jl index 341e764f..cfe79606 100644 --- a/src/IR/ExecutionEngine.jl +++ b/src/IR/ExecutionEngine.jl @@ -1,10 +1,5 @@ -mutable struct ExecutionEngine - engine::API.MlirExecutionEngine - - function ExecutionEngine(engine) - @assert !mlirIsNull(engine) "cannot create ExecutionEngine with null MlirExecutionEngine" - return finalizer(API.mlirExecutionEngineDestroy, new(engine)) - end +@checked struct ExecutionEngine + ref::API.MlirExecutionEngine end """ @@ -16,7 +11,7 @@ The module ownership stays with the client and can be destroyed as soon as the c `optLevel` is the optimization level to be used for transformation and code generation. LLVM passes at `optLevel` are run before code generation. The number and array of paths corresponding to shared libraries that will be loaded are specified via `numPaths` and `sharedLibPaths` respectively. -TODO: figure out other options. +TODO(#2246): figure out other options. """ function ExecutionEngine( mod::Module, @@ -24,23 +19,23 @@ function ExecutionEngine( sharedlibs::Vector{String}=String[], enableObjectDump::Bool=false, ) - if MLIR_VERSION[] < v"16" - enableObjectDump && @warn "enableObjectDump is only available in LLVM 16 and later" - ExecutionEngine( - API.mlirExecutionEngineCreate(mod, optLevel, length(sharedlibs), sharedlibs) - ) - else - ExecutionEngine( + return ExecutionEngine( + mark_alloc( API.mlirExecutionEngineCreate( mod, optLevel, length(sharedlibs), sharedlibs, enableObjectDump ), - ) - end + ), + ) end -Base.convert(::Core.Type{API.MlirExecutionEngine}, engine::ExecutionEngine) = engine.engine +dispose(engine::ExecutionEngine) = API.mlirExecutionEngineDestroy(engine) -# TODO mlirExecutionEngineInvokePacked +Base.cconvert(::Core.Type{API.MlirExecutionEngine}, engine::ExecutionEngine) = engine +function Base.unsafe_convert(::Core.Type{API.MlirExecutionEngine}, engine::ExecutionEngine) + return mark_use(engine).ref +end + +# TODO(#2246) mlirExecutionEngineInvokePacked """ lookup(jit, name) @@ -56,12 +51,13 @@ function lookup(jit::ExecutionEngine, name::String; packed::Bool=false) return fn == C_NULL ? nothing : fn end -# TODO mlirExecutionEngineRegisterSymbol +# TODO(#2246) mlirExecutionEngineRegisterSymbol """ write(fileName, jit) Dump as an object in `fileName`. """ -Base.write(filename::String, jit::ExecutionEngine) = - API.mlirExecutionEngineDumpToObjectFile(jit, filename) +function Base.write(filename::String, jit::ExecutionEngine) + return API.mlirExecutionEngineDumpToObjectFile(jit, filename) +end diff --git a/src/IR/IR.jl b/src/IR/IR.jl index 2d9c09ad..cc125015 100644 --- a/src/IR/IR.jl +++ b/src/IR/IR.jl @@ -1,38 +1,24 @@ module IR -using ..MLIR: MLIR_VERSION, MLIRException using ..API -# do not export `Type`, as it is already defined in Core -# also, use `Core.Type` inside this module to avoid clash with MLIR `Type` +# WARN do not export `Type` nor `Module` as they are already defined in Core +# also, use `Core.Type` and `Core.Module` inside this module to avoid clash with +# MLIR `Type` and `Module` export Attribute, Block, Context, Dialect, Location, Operation, Region, Value -export activate!, deactivate!, dispose!, enable_multithreading!, context! -export context, type, type!, location, typeid, block, dialect -export nattrs, - attr, - attr!, - rmattr!, - nregions, - region, - nresults, - result, - noperands, - operand, - operand!, - nsuccessors, - successor -export BlockIterator, RegionIterator, OperationIterator +export activate, deactivate, dispose, enable_multithreading! +export context, current_context, has_context +export block, current_block, has_block +export current_module, has_module +export type, settype!, location, typeid, dialect +export nattrs, getattr, setattr!, rmattr! +export nregions, region +export nresults, result, noperands, operand, setoperand! +export nsuccessors, successor export @affinemap -function mlirIsNull(val) - return val.ptr == C_NULL -end - -function print_callback(str::API.MlirStringRef, userdata) - data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) - write(userdata isa Base.RefValue ? userdata[] : userdata, data) - return Cvoid() -end +include("debug.jl") +include("Utils.jl") include("LogicalResult.jl") include("Context.jl") @@ -45,43 +31,15 @@ include("Module.jl") include("Block.jl") include("Region.jl") include("Value.jl") -include("OpOperand.jl") # introduced in LLVM 16 +include("OpOperand.jl") include("Identifier.jl") include("SymbolTable.jl") include("AffineExpr.jl") include("AffineMap.jl") include("Attribute.jl") include("IntegerSet.jl") -include("Iterators.jl") include("ExecutionEngine.jl") include("Pass.jl") -### Utils - -function visit(f, op) - for region in RegionIterator(op) - for block in BlockIterator(region) - for op in OperationIterator(block) - f(op) - end - end - end -end - -""" - verifyall(operation; debug=false) - -Prints the operations which could not be verified. -""" -function verifyall(operation::Operation; debug=false) - io = IOContext(stdout, :debug => debug) - visit(operation) do op - if !verify(op) - show(io, op) - end - end -end -verifyall(module_::IR.Module) = verifyall(Operation(module_)) - end # module IR diff --git a/src/IR/Identifier.jl b/src/IR/Identifier.jl index 2f58836b..2eb52540 100644 --- a/src/IR/Identifier.jl +++ b/src/IR/Identifier.jl @@ -1,5 +1,5 @@ -struct Identifier - identifier::API.MlirIdentifier +@checked struct Identifier + ref::API.MlirIdentifier end """ @@ -7,10 +7,12 @@ end Gets an identifier with the given string value. """ -Identifier(str::String; context::Context=context()) = - Identifier(API.mlirIdentifierGet(context, str)) +function Identifier(str::String; context::Context=current_context()) + return Identifier(API.mlirIdentifierGet(context, str)) +end -Base.convert(::Core.Type{API.MlirIdentifier}, id::Identifier) = id.identifier +Base.cconvert(::Core.Type{API.MlirIdentifier}, id::Identifier) = id +Base.unsafe_convert(::Core.Type{API.MlirIdentifier}, id::Identifier) = id.ref """ ==(ident, other) diff --git a/src/IR/IntegerSet.jl b/src/IR/IntegerSet.jl index e7f89cbc..03251516 100644 --- a/src/IR/IntegerSet.jl +++ b/src/IR/IntegerSet.jl @@ -1,56 +1,62 @@ -struct IntegerSet - set::API.MlirIntegerSet - - function IntegerSet(set) - @assert !mlirIsNull(set) "cannot create IntegerSet with null MlirIntegerSet" - return new(set) - end +@checked struct IntegerSet + ref::API.MlirIntegerSet end """ - Integerset(ndims, nsymbols; context=context()) + Integerset(ndims, nsymbols; context=current_context()) Gets or creates a new canonically empty integer set with the give number of dimensions and symbols in the given context. """ -IntegerSet(ndims, nsymbols; context::Context=context()) = - IntegerSet(API.mlirIntegerSetEmptyGet(context, ndims, nsymbols)) +function IntegerSet(ndims, nsymbols; context::Context=current_context()) + return IntegerSet(API.mlirIntegerSetEmptyGet(context, ndims, nsymbols)) +end """ - IntegerSet(ndims, nsymbols, constraints, eqflags; context=context()) + IntegerSet(ndims, nsymbols, constraints, eqflags; context=current_context()) Gets or creates a new integer set in the given context. The set is defined by a list of affine constraints, with the given number of input dimensions and symbols, which are treated as either equalities (eqflags is 1) or inequalities (eqflags is 0). Both `constraints` and `eqflags` need to be arrays of the same length. """ -IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet( - API.mlirIntegerSetGet( - context, - ndims, - nsymbols, - length(constraints), - pointer(constraints), - pointer(eqflags), - ), +function IntegerSet( + ndims, nsymbols, constraints, eqflags; context::Context=current_context() ) + return IntegerSet( + API.mlirIntegerSetGet( + context, ndims, nsymbols, length(constraints), constraints, eqflags + ), + ) +end + +Base.cconvert(::Core.Type{API.MlirIntegerSet}, set::IntegerSet) = set +Base.unsafe_convert(::Core.Type{API.MlirIntegerSet}, set::IntegerSet) = set.ref + +function Base.show(io::IO, set::IntegerSet) + print(io, "IntegerSet(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) + ref = Ref(io) + API.mlirIntegerSetPrint(set, c_print_callback, ref) + return print(io, " =#)") +end """ - mlirIntegerSetReplaceGet(set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols) + Base.replace(set::IntegerSet, dimReplacements, symbolReplacements, numResultDims, numResultSymbols) Gets or creates a new integer set in which the values and dimensions of the given set are replaced with the given affine expressions. `dimReplacements` and `symbolReplacements` are expected to point to at least as many consecutive expressions as the given set has dimensions and symbols, respectively. The new set will have `numResultDims` and `numResultSymbols` dimensions and symbols, respectively. """ -Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) = IntegerSet( - API.mlirIntegerSetReplaceGet( - set, - dim_replacements, - symbol_replacements, - length(dim_replacements), - length(symbol_replacements), - ), -) - -Base.convert(::Core.Type{API.MlirIntegerSet}, set::IntegerSet) = set.set +function Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) + return IntegerSet( + API.mlirIntegerSetReplaceGet( + set, + dim_replacements, + symbol_replacements, + length(dim_replacements), + length(symbol_replacements), + ), + ) +end """ ==(s1, s2) @@ -67,14 +73,14 @@ Base.:(==)(a::IntegerSet, b::IntegerSet) = API.mlirIntegerSetEqual(a, b) Gets the context in which the given integer set lives. """ -context(set::IntegerSet) = Context(API.mlirIntegerSetGetContext(set.set)) +context(set::IntegerSet) = Context(API.mlirIntegerSetGetContext(set)) """ isempty(set) -Checks whether the given set is a canonical empty set, e.g., the set returned by [`mlirIntegerSetEmptyGet`](@ref). +Checks whether the given set is a canonical empty set. """ -isempty(set::IntegerSet) = API.mlirIntegerSetIsCanonicalEmpty(set) +Base.isempty(set::IntegerSet) = API.mlirIntegerSetIsCanonicalEmpty(set) """ ndims(set) @@ -131,11 +137,3 @@ constraint(set::IntegerSet, i) = API.mlirIntegerSetGetConstraint(set, i) Returns `true` of the `i`-th constraint of the set is an equality constraint, `false` otherwise. """ isconstrainteq(set::IntegerSet, i) = API.mlirIntegerSetIsConstraintEq(set, i) - -function Base.show(io::IO, set::IntegerSet) - print(io, "IntegerSet(#= ") - c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - ref = Ref(io) - API.mlirIntegerSetPrint(set, c_print_callback, ref) - return print(io, " =#)") -end diff --git a/src/IR/Location.jl b/src/IR/Location.jl index 88f40300..1d5d6ded 100644 --- a/src/IR/Location.jl +++ b/src/IR/Location.jl @@ -1,34 +1,26 @@ -struct Location - location::API.MlirLocation - - function Location(location) - @assert !mlirIsNull(location) "cannot create Location with null MlirLocation" - return new(location) - end +@checked struct Location + ref::API.MlirLocation end -Location(; context::Context=context()) = Location(API.mlirLocationUnknownGet(context)) +function Location(; context::Context=current_context()) + return Location(API.mlirLocationUnknownGet(context)) +end -function Location(filename, line, column; context::Context=context()) +function Location(filename, line, column; context::Context=current_context()) return Location(API.mlirLocationFileLineColGet(context, filename, line, column)) end -function Location(callee::Location, caller::Location; context::Context=context()) - return Location(API.mlirLocationCallSiteGet(context, callee, caller)) +function Location(callee::Location, caller::Location; context::Context=current_context()) + return Location(API.mlirLocationCallSiteGet(callee, caller)) end -function Location(name::String, location::Location; context::Context=context()) +function Location(name::String, location::Location; context::Context=current_context()) return Location(API.mlirLocationNameGet(context, name, location)) end -# TODO rename to merge? -function fuse(locations::Vector{Location}, metadata; context::Context=context()) - return Location( - API.mlirLocationFusedGet(context, length(locations), pointer(locations), metadata) - ) -end +Base.cconvert(::Core.Type{API.MlirLocation}, location::Location) = location +Base.unsafe_convert(::Core.Type{API.MlirLocation}, location::Location) = location.ref -Base.convert(::Core.Type{API.MlirLocation}, location::Location) = location.location Base.:(==)(a::Location, b::Location) = API.mlirLocationEqual(a, b) context(location::Location) = Context(API.mlirLocationGetContext(location)) @@ -39,3 +31,10 @@ function Base.show(io::IO, location::Location) API.mlirLocationPrint(location, c_print_callback, ref) return print(io, " =#)") end + +# TODO(#2245): rename to merge? +function fuse(locations::Vector{Location}, metadata; context::Context=current_context()) + return Location( + API.mlirLocationFusedGet(context, length(locations), locations, metadata) + ) +end diff --git a/src/IR/LogicalResult.jl b/src/IR/LogicalResult.jl index b0e2fc83..87c21716 100644 --- a/src/IR/LogicalResult.jl +++ b/src/IR/LogicalResult.jl @@ -6,10 +6,11 @@ LLVM convention for using boolean values to designate success or failure of an o Instances of [`LogicalResult`](@ref) must only be inspected using the associated functions. """ struct LogicalResult - result::API.MlirLogicalResult + ref::API.MlirLogicalResult end -Base.convert(::Core.Type{API.MlirLogicalResult}, result::LogicalResult) = result.result +Base.cconvert(::Core.Type{API.MlirLogicalResult}, result::LogicalResult) = result +Base.unsafe_convert(::Core.Type{API.MlirLogicalResult}, result::LogicalResult) = result.ref """ success() @@ -30,11 +31,11 @@ failure() = LogicalResult(API.MlirLogicalResult(0)) Checks if the given logical result represents a success. """ -issuccess(result::LogicalResult) = result.result.value != 0 +issuccess(result::LogicalResult) = result.ref.value != 0 """ isfailure(res) Checks if the given logical result represents a failure. """ -isfailure(result::LogicalResult) = result.result.value == 0 +isfailure(result::LogicalResult) = result.ref.value == 0 diff --git a/src/IR/Module.jl b/src/IR/Module.jl index d8017172..e200cf5e 100644 --- a/src/IR/Module.jl +++ b/src/IR/Module.jl @@ -1,10 +1,5 @@ -mutable struct Module - module_::API.MlirModule - - function Module(module_) - @assert !mlirIsNull(module_) "cannot create Module with null MlirModule" - return finalizer(API.mlirModuleDestroy, new(module_)) - end +@checked struct Module + ref::API.MlirModule end """ @@ -12,19 +7,29 @@ end Creates a new, empty module and transfers ownership to the caller. """ -Module(loc::Location=Location()) = Module(API.mlirModuleCreateEmpty(loc)) +Module(loc::Location=Location()) = Module(mark_alloc(API.mlirModuleCreateEmpty(loc))) -Module(op::Operation) = Module(API.mlirModuleFromOperation(lose_ownership!(op))) +Module(op::Operation) = Module(API.mlirModuleFromOperation(mark_donate(op))) -Base.convert(::Core.Type{API.MlirModule}, module_::Module) = module_.module_ +""" + dispose(module) +Disposes the given module and releases its resources. +After calling this function, the module must not be used anymore. """ - parse(::Type{Module}, module; context=context()) +dispose(mod_::Module) = mark_dispose(API.mlirModuleDestroy, mod_) + +Base.cconvert(::Core.Type{API.MlirModule}, module_::Module) = module_ +Base.unsafe_convert(::Core.Type{API.MlirModule}, module_::Module) = mark_use(module_).ref + +""" + parse(::Type{Module}, module; context=current_context()) Parses a module from the string and transfers ownership to the caller. """ -Base.parse(::Core.Type{Module}, module_; context::Context=context()) = - Module(API.mlirModuleCreateParse(context, module_)) +function Base.parse(::Core.Type{Module}, module_; context::Context=current_context()) + return Module(API.mlirModuleCreateParse(context, module_)) +end macro mlir_str(code) quote @@ -45,16 +50,45 @@ context(module_::Module) = Context(API.mlirModuleGetContext(module_)) Gets the body of the module, i.e. the only block it contains. """ -body(module_) = Block(API.mlirModuleGetBody(module_), false) +body(module_::Module) = Block(API.mlirModuleGetBody(module_)) """ Operation(module) Views the module as a generic operation. """ -Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false) +Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_)) + +Base.copy(mod::Module) = Module(copy(Operation(mod))) + +Base.show(io::IO, module_::Module) = show(io, Operation(module_)) -function Base.show(io::IO, module_::Module) - println(io, "Module:") - return show(io, Operation(module_)) +verifyall(mod_::Module; debug=false) = verifyall(Operation(mod_); debug) + +# to simplify the API, we maintain a stack of contexts in task local storage +# and pass them implicitly to MLIR API's that require them. +function activate(blk::Module) + stack = get!(task_local_storage(), :mlir_module) do + return Module[] + end::Vector{Module} + Base.push!(stack, blk) + return nothing +end + +function deactivate(blk::Module) + current_module() == blk || error("Deactivating wrong block") + return Base.pop!(task_local_storage(:mlir_module)::Vector{Module}) +end + +function has_module() + return haskey(task_local_storage(), :mlir_module) && + !Base.isempty(task_local_storage(:mlir_module)::Vector{Module}) +end + +function current_module(; throw_error::Core.Bool=true) + if !has_module() + throw_error && error("No MLIR module is active") + return nothing + end + return last(task_local_storage(:mlir_module)::Vector{Module}) end diff --git a/src/IR/OpOperand.jl b/src/IR/OpOperand.jl index 68faabd4..0d0f0156 100644 --- a/src/IR/OpOperand.jl +++ b/src/IR/OpOperand.jl @@ -1,13 +1,9 @@ -struct OpOperand - op::API.MlirOpOperand - - function OpOperand(op::API.MlirOpOperand) - @assert mlirIsNull(op) "cannot create OpOperand with null MlirOpOperand" - return new(op) - end +@checked struct OpOperand + ref::API.MlirOpOperand end -Base.convert(::Core.Type{API.MlirOpOperand}, op::OpOperand) = op.op +Base.cconvert(::Core.Type{API.MlirOpOperand}, op::OpOperand) = op +Base.unsafe_convert(::Core.Type{API.MlirOpOperand}, op::OpOperand) = op.ref """ first_use(value) @@ -15,7 +11,7 @@ Base.convert(::Core.Type{API.MlirOpOperand}, op::OpOperand) = op.op Returns an `OpOperand` representing the first use of the value, or a `nothing` if there are no uses. """ function first_use(value::Value) - operand = API.mlirOperationGetFirstResult(value) + operand = API.mlirValueGetFirstUse(value) mlirIsNull(operand) && return nothing return OpOperand(operand) end @@ -25,7 +21,7 @@ end Returns the owner operation of an op operand. """ -owner(op::OpOperand) = Operation(API.mlirOpOperandGetOwner(op), false) +owner(op::OpOperand) = Operation(API.mlirOpOperandGetOwner(op)) """ operandindex(opOperand) diff --git a/src/IR/Operation.jl b/src/IR/Operation.jl index aad10a09..4bc21ded 100644 --- a/src/IR/Operation.jl +++ b/src/IR/Operation.jl @@ -1,43 +1,101 @@ -mutable struct Operation - operation::API.MlirOperation - @atomic owned::Bool - - function Operation(operation, owned=true) - @assert !mlirIsNull(operation) "cannot create Operation with null MlirOperation" - finalizer(new(operation, owned)) do op - if op.owned - API.mlirOperationDestroy(op.operation) - end - end +@checked struct Operation + ref::API.MlirOperation +end + +dispose(op::Operation) = mark_dispose(API.mlirOperationDestroy, op) + +Base.cconvert(::Core.Type{API.MlirOperation}, op::Operation) = op +Base.unsafe_convert(::Core.Type{API.MlirOperation}, op::Operation) = mark_use(op).ref + +Base.:(==)(op::Operation, other::Operation) = API.mlirOperationEqual(op, other) + +""" + parse(::Type{Operation}, code; context=current_context()) + +Parses an operation from the string and transfers ownership to the caller. +""" +function Base.parse( + ::Core.Type{Operation}, + code; + verify::Bool=false, + context::Context=current_context(), + block=Block(), + location::Location=Location(), +) + return Operation( + mark_alloc(API.mlirOperationParse(context, block, code, location, verify)) + ) +end + +function Base.show(io::IO, operation::Operation) + if mlirIsNull(operation.ref) + return write(io, "Operation(NULL)") + end + + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) + + buffer = IOBuffer() + ref = Ref(buffer) + + flags = API.mlirOpPrintingFlagsCreate() + + API.mlirOpPrintingFlagsEnableDebugInfo(flags, get(io, :debug, false), false) + API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + API.mlirOpPrintingFlagsDestroy(flags) + + return write(io, rstrip(String(take!(buffer)))) +end + +Base.IteratorSize(::Core.Type{Operation}) = Base.HasLength() +Base.IteratorEltype(::Core.Type{Operation}) = Base.HasEltype() +Base.eltype(::Operation) = Region +Base.length(it::Operation) = nregions(it) + +""" + Base.iterate(op::Operation) + +Iterates over all sub-regions for the given operation. +""" +function Base.iterate(it::Operation) + raw_region = API.mlirOperationGetFirstRegion(it) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region) + (region, region) end end -Base.cconvert(::Core.Type{API.MlirOperation}, operation::Operation) = operation -function Base.unsafe_convert(::Core.Type{API.MlirOperation}, operation::Operation) - return operation.operation +function Base.iterate(::Operation, region) + raw_region = API.mlirRegionGetNextInOperation(region) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region) + (region, region) + end end -Base.:(==)(op::Operation, other::Operation) = API.mlirOperationEqual(op, other) """ copy(op) Creates a deep copy of an operation. The operation is not inserted and ownership is transferred to the caller. """ -Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) +Base.copy(op::Operation) = Operation(mark_alloc(API.mlirOperationClone(op))) """ context(op) Gets the context this operation is associated with. """ -context(operation::Operation) = Context(API.mlirOperationGetContext(operation)) +context(op::Operation) = Context(API.mlirOperationGetContext(op)) """ location(op) Gets the location of the operation. """ -location(operation::Operation) = Location(API.mlirOperationGetLocation(operation)) +location(op::Operation) = Location(API.mlirOperationGetLocation(op)) """ typeid(op) @@ -51,52 +109,53 @@ typeid(op::Operation) = TypeID(API.mlirOperationGetTypeID(op)) Gets the name of the operation as an identifier. """ -name(operation::Operation) = String(API.mlirOperationGetName(operation)) +name(op::Operation) = String(API.mlirOperationGetName(op)) """ block(op) Gets the block that owns this operation, returning null if the operation is not owned. """ -block(operation::Operation) = Block(API.mlirOperationGetBlock(operation), false) +block(op::Operation) = Block(API.mlirOperationGetBlock(op)) """ parent_op(op) Gets the operation that owns this operation, returning null if the operation is not owned. """ -parent_op(operation::Operation) = - Operation(API.mlirOperationGetParentOperation(operation), false) +function parent_op(op::Operation) + return Operation(API.mlirOperationGetParentOperation(op)) +end """ - rmfromparent(op) + rmfromparent!(op) Removes the given operation from its parent block. The operation is not destroyed. The ownership of the operation is transferred to the caller. """ -function rmfromparent!(operation::Operation) - API.mlirOperationRemoveFromParent(operation) - @atomic operation.owned = true - return operation +function rmfromparent!(op::Operation) + API.mlirOperationRemoveFromParent(op) + # TODO mark ownership moved to the caller + return op end -dialect(operation::Operation) = Symbol(first(split(name(operation), '.'))) +dialect(op::Operation) = Symbol(first(split(name(op), '.'))) """ nregions(op) Returns the number of regions attached to the given operation. """ -nregions(operation::Operation) = API.mlirOperationGetNumRegions(operation) +nregions(op::Operation) = API.mlirOperationGetNumRegions(op) """ region(op, i) Returns `i`-th region attached to the operation. """ -function region(operation::Operation, i) - i ∉ 1:nregions(operation) && throw(BoundsError(operation, i)) - return Region(API.mlirOperationGetRegion(operation, i - 1), false) +function region(op::Operation, i) + i ∉ 1:nregions(op) && throw(BoundsError(op, i)) + return Region(API.mlirOperationGetRegion(op, i - 1)) end """ @@ -104,44 +163,51 @@ end Returns the number of results of the operation. """ -nresults(operation::Operation) = API.mlirOperationGetNumResults(operation) +nresults(op::Operation) = API.mlirOperationGetNumResults(op) """ result(op, i) Returns `i`-th result of the operation. """ -function result(operation::Operation, i=1) - i ∉ 1:nresults(operation) && throw(BoundsError(operation, i)) - return Value(API.mlirOperationGetResult(operation, i - 1)) +function result(op::Operation, i=1) + i ∉ 1:nresults(op) && throw(BoundsError(op, i)) + return Value(API.mlirOperationGetResult(op, i - 1)) end -results(operation) = [result(operation, i) for i in 1:nresults(operation)] +results(op) = [result(op, i) for i in 1:nresults(op)] """ noperands(op) Returns the number of operands of the operation. """ -noperands(operation::Operation) = API.mlirOperationGetNumOperands(operation) +noperands(op::Operation) = API.mlirOperationGetNumOperands(op) """ operand(op, i) Returns `i`-th operand of the operation. """ -function operand(operation::Operation, i=1) - i ∉ 1:noperands(operation) && throw(BoundsError(operation, i)) - return Value(API.mlirOperationGetOperand(operation, i - 1)) +function operand(op::Operation, i=1) + i ∉ 1:noperands(op) && throw(BoundsError(op, i)) + return Value(API.mlirOperationGetOperand(op, i - 1)) end """ - operand!(op, i, value) + operands(op) + +Return an array of all operands of the operation. +""" +operands(op) = Value[operand(op, i) for i in 1:noperands(op)] + +""" + setoperand!(op, i, value) Sets the `i`-th operand of the operation. """ -function operand!(operation::Operation, i, value) - i ∉ 1:noperands(operation) && throw(BoundsError(operation, i)) - API.mlirOperationSetOperand(operation, i - 1, value) +function setoperand!(op::Operation, i, value) + i ∉ 1:noperands(op) && throw(BoundsError(op, i)) + API.mlirOperationSetOperand(op, i - 1, value) return value end @@ -150,16 +216,16 @@ end Returns the number of successor blocks of the operation. """ -nsuccessors(operation::Operation) = API.mlirOperationGetNumSuccessors(operation) +nsuccessors(op::Operation) = API.mlirOperationGetNumSuccessors(op) """ successor(op, i) Returns `i`-th successor of the operation. """ -function successor(operation::Operation, i) - i ∉ 1:nsuccessors(operation) && throw(BoundsError(operation, i)) - return Block(API.mlirOperationGetSuccessor(operation, i - 1), false) +function successor(op::Operation, i) + i ∉ 1:nsuccessors(op) && throw(BoundsError(op, i)) + return Block(API.mlirOperationGetSuccessor(op, i - 1)) end """ @@ -167,25 +233,25 @@ end Returns the number of attributes attached to the operation. """ -nattrs(operation::Operation) = API.mlirOperationGetNumAttributes(operation) +nattrs(op::Operation) = API.mlirOperationGetNumAttributes(op) """ - attr(op, i) + getattr(op, i) Return `i`-th attribute of the operation. """ -function attr(operation::Operation, i) - i ∉ 1:nattrs(operation) && throw(BoundsError(operation, i)) - return NamedAttribute(API.mlirOperationGetAttribute(operation, i - 1)) +function getattr(op::Operation, i) + i ∉ 1:nattrs(op) && throw(BoundsError(op, i)) + return NamedAttribute(API.mlirOperationGetAttribute(op, i - 1)) end """ - attr(op, name) + getattr(op, name) Returns an attribute attached to the operation given its name. """ -function attr(operation::Operation, name::AbstractString) - raw_attr = API.mlirOperationGetAttributeByName(operation, name) +function getattr(op::Operation, name::AbstractString) + raw_attr = API.mlirOperationGetAttributeByName(op, name) if mlirIsNull(raw_attr) return nothing end @@ -193,13 +259,13 @@ function attr(operation::Operation, name::AbstractString) end """ - attr!(op, name, attr) + setattr!(op, name, attr) Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise. """ -function attr!(operation::Operation, name, attribute) - API.mlirOperationSetAttributeByName(operation, name, attribute) - return operation +function setattr!(op::Operation, name, attribute) + API.mlirOperationSetAttributeByName(op, name, attribute) + return op end """ @@ -207,33 +273,7 @@ end Removes an attribute by name. Returns false if the attribute was not found and true if removed. """ -rmattr!(operation::Operation, name) = - API.mlirOperationRemoveAttributeByName(operation, name) - -function lose_ownership!(operation::Operation) - @assert operation.owned - @atomic operation.owned = false - return operation -end - -function Base.show(io::IO, operation::Operation) - c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - - buffer = IOBuffer() - ref = Ref(buffer) - - flags = API.mlirOpPrintingFlagsCreate() - - if MLIR_VERSION[] >= v"16" - API.mlirOpPrintingFlagsEnableDebugInfo(flags, get(io, :debug, false), true) - else - get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true) - end - API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) - API.mlirOpPrintingFlagsDestroy(flags) - - return write(io, rstrip(String(take!(buffer)))) -end +rmattr!(op::Operation, name) = API.mlirOperationRemoveAttributeByName(op, name) """ verify(op) @@ -248,7 +288,7 @@ verify(operation::Operation) = API.mlirOperationVerify(operation) Moves the given operation immediately after the other operation in its parent block. The given operation may be owned by the caller or by its current block. The other operation must belong to a block. In any case, the ownership is transferred to the block of the other operation. """ function move_after!(operation::Operation, other::Operation) - lose_ownership!(operation) + mark_donate(operation) return API.mlirOperationMoveAfter(operation, other) end @@ -260,23 +300,22 @@ The given operation may be owner by the caller or by its current block. The other operation must belong to a block. In any case, the ownership is transferred to the block of the other operation. """ -function move_before!(operation::Operation, other::Operation) - lose_ownership!(operation) - return API.mlirOperationMoveBefore(operation, other) +function move_before!(op::Operation, other::Operation) + mark_donate(op) + return API.mlirOperationMoveBefore(op, other) end """ - is_registered(name; context=context()) + is_registered(name; context=current_context()) Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context. This will return true if the dialect is loaded and the operation is registered within the dialect. """ -is_registered(opname; context::Context=context()) = - API.mlirContextIsRegisteredOperation(context, opname) - -# TODO mlirOperationWriteBytecode (LLVM 16) +function is_registered(opname; context::Context=current_context()) + return API.mlirContextIsRegisteredOperation(context, opname) +end -function create_operation( +function create_operation_common( name, loc; results=nothing, @@ -298,7 +337,7 @@ function create_operation( API.mlirOperationStateAddOperands(state, length(operands), operands) end if !isnothing(owned_regions) - lose_ownership!.(owned_regions) + mark_donate.(owned_regions) GC.@preserve owned_regions begin mlir_regions = Base.unsafe_convert.(API.MlirRegion, owned_regions) API.mlirOperationStateAddOwnedRegions( @@ -322,6 +361,50 @@ function create_operation( if mlirIsNull(op) error("Create Operation '$name' failed") end - Operation(op, true) + return Operation(op) + end +end + +function create_operation(args...; kwargs...) + res = create_operation_common(args...; kwargs...) + if has_block() + push!(current_block(), res) + end + return res +end + +function create_operation_at_front(args...; kwargs...) + res = create_operation_common(args...; kwargs...) + Base.pushfirst!(current_block(), res) + return res +end + +function FunctionType(op::Operation) + is_function_op = API.mlirIsFunctionOpInterface(op) + if is_function_op + return Type(API.mlirGetFunctionTypeFromOperation(op)) + else + throw("operation is not a function operation") + end +end + +""" + verifyall(operation; debug=false) + +Prints the operations which could not be verified. +""" +function verifyall(operation::Operation; debug=false) + io = IOBuffer() + visit(operation) do op + ok = verifyall(op; debug) + if !ok || !verify(op) + if ok + show(IOContext(io, :debug => debug), op) + error(String(take!(io))) + end + false + else + true + end end end diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl index 1ca2fdb3..0458b8b3 100644 --- a/src/IR/Pass.jl +++ b/src/IR/Pass.jl @@ -1,50 +1,63 @@ abstract type AbstractPass end +using StableRNGs: StableRNG + mutable struct ExternalPassHandle ctx::Union{Nothing,Context} pass::AbstractPass end -mutable struct PassManager - pass::API.MlirPassManager +@checked struct PassManager + ref::API.MlirPassManager allocator::TypeIDAllocator passes::Dict{TypeID,ExternalPassHandle} +end - function PassManager(pm::API.MlirPassManager) - @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" - finalizer(new(pm, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm - API.mlirPassManagerDestroy(pm.pass) - end - end +function PassManager(pm::API.MlirPassManager) + return PassManager(pm, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}()) end """ - PassManager(; context=context()) + PassManager(; context=current_context()) Create a new top-level PassManager. """ -PassManager(; context::Context=context()) = PassManager(API.mlirPassManagerCreate(context)) +function PassManager(; context::Context=current_context()) + return PassManager(mark_alloc(API.mlirPassManagerCreate(context))) +end """ - PassManager(anchorOp; context=context()) + PassManager(anchorOp; context=current_context()) Create a new top-level PassManager anchored on `anchorOp`. """ -function PassManager(anchor_op::Operation; context::Context=context()) - MLIR_VERSION[] >= v"16" || - throw(MLIRException("`PassManager(::Operation)` requires MLIR version 16 or later")) - return PassManager(API.mlirPassManagerCreateOnOperation(context, anchor_op)) +function PassManager(anchor_op::Operation; context::Context=current_context()) + return PassManager(mark_alloc(API.mlirPassManagerCreateOnOperation(context, anchor_op))) end -Base.convert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass.pass +dispose(pass::PassManager) = mark_dispose(API.mlirPassManagerDestroy, pass) + +Base.cconvert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass +function Base.unsafe_convert(::Core.Type{API.MlirPassManager}, pass::PassManager) + return mark_use(pass).ref +end """ enable_ir_printing!(passManager) Enable mlir-print-ir-after-all. """ -function enable_ir_printing!(pm) - API.mlirPassManagerEnableIRPrinting(pm) +function enable_ir_printing!( + pm; + before_all=false, + after_all=false, + module_scope=false, + after_only_on_change=false, + after_only_on_failure=false, +) + API.mlirPassManagerEnableIRPrinting( + pm, before_all, after_all, module_scope, after_only_on_change, after_only_on_failure + ) return pm end @@ -59,30 +72,133 @@ function enable_verifier!(pm, enable=true) end """ - run!(passManager, module) + add_owned_pass!(passManager, pass) -Run the provided `passManager` on the given `module`. +Add a pass and transfer ownership to the provided top-level `PassManager`. If the pass is not a generic operation pass or a `ModulePass`, a new `OpPassManager` is implicitly nested under the provided PassManager. """ -function run!(pm::PassManager, mod::Module) - status = if MLIR_VERSION[] >= v"17" - LogicalResult(API.mlirPassManagerRunOnOp(pm, Operation(mod))) - else - LogicalResult(API.mlirPassManagerRun(pm, mod)) +function add_owned_pass!(pm::PassManager, pass) + API.mlirPassManagerAddOwnedPass(pm, pass) + return pm +end + +# Where to dump the MLIR modules +const DUMP_MLIR_DIR = Ref{Union{Nothing,String}}(nothing) +# Whether to always dump MLIR, regardless of failure +const DUMP_MLIR_ALWAYS = Ref{Bool}(false) +# Counter for dumping MLIR modules +const MLIR_DUMP_COUNTER = Threads.Atomic{Int}(0) + +const DUMP_RNG = StableRNG(0) + +function dump_mlir( + mod::Module, pm::Union{Nothing,PassManager}=nothing, mode::String=""; failed::Bool=false +) + return dump_mlir(Operation(mod), pm, mode; failed) +end + +# Utilities for dumping to a file the module of a failed compilation, useful for +# debugging purposes. +function dump_mlir( + op::Operation, + pm::Union{Nothing,PassManager}=nothing, + mode::String=""; + failed::Bool=false, +) + try + # If `DUMP_MLIR_DIR` is `nothing`, create a persistent new temp + # directory, otherwise use the provided path. + dir = if isnothing(DUMP_MLIR_DIR[]) + mkpath(tempdir()) + # Use the same directory for this session + DUMP_MLIR_DIR[] = mktempdir(; prefix="mlir_", cleanup=false) + else + DUMP_MLIR_DIR[] + end + + # Make sure the directory exists + mkpath(dir) + + # Attempt to get the name of the module if that exists + mod_name = getattr(op, String(API.mlirSymbolTableGetSymbolAttributeName())) + fname = mod_name === nothing ? randstring(DUMP_RNG, 4) : String(mod_name) + fname = "module_" * lpad(MLIR_DUMP_COUNTER[], 3, "0") * "_$(fname)" + if isempty(mode) + fname *= ".mlir" + else + if length(mode) > 100 + mode = mode[1:100] + end + fname *= "_$(mode).mlir" + end + MLIR_DUMP_COUNTER[] += 1 + path = joinpath(dir, fname) + + open(path, "w") do io + if !isnothing(pm) + println(io, "// Pass pipeline:") + print(io, "// ") + print_pass_pipeline(io, OpPassManager(pm)) + println(io) + end + show(IOContext(io, :debug => true), op) + end + if failed + @error "Compilation failed, MLIR module written to $(path)" + else + @debug "MLIR module written to $(path)" + end + catch err + @error "Couldn't save MLIR module" exception = err end - if isfailure(status) - throw("failed to run pass manager on module") + flush(stdout) + flush(stderr) + return nothing +end + +function try_compile_dump_mlir(f, mod::Module, pm=nothing) + failed = false + # Dump MLIR before calling `f`. We set `pm` to nothing because the pass + # manager isn't called yet here. + DUMP_MLIR_ALWAYS[] && dump_mlir(mod, nothing, "pre") + try + f() + catch + failed = true + rethrow() + finally + if failed || DUMP_MLIR_ALWAYS[] + dump_mlir(Operation(mod), pm, "post"; failed) + end end - return mod end -struct OpPassManager - op_pass::API.MlirOpPassManager - pass::PassManager +function run!(pm::PassManager, mod::Module, key::String="") + return run!(pm, Operation(mod), key) +end - function OpPassManager(op_pass, pass) - @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" - return new(op_pass, pass) +""" + run!(passManager, operation, key="") + +Run the provided `passManager` on the given `operation`. +""" +function run!(pm::PassManager, operation, key::String="") + # Dump MLIR before running the pass manager, but also print the list of passes that will be called later. + DUMP_MLIR_ALWAYS[] && + dump_mlir(operation, pm, isempty(key) ? "pre_pm" : "pre_$(key)_pm") + status = LogicalResult(API.mlirPassManagerRunOnOp(pm, operation)) + failed = isfailure(status) + if failed || DUMP_MLIR_ALWAYS[] + dump_mlir(operation, pm, isempty(key) ? "post_pm" : "post_$(key)_pm"; failed) end + if failed + throw("failed to run pass manager on module") + end + return operation +end + +@checked struct OpPassManager + ref::API.MlirOpPassManager + pass::PassManager end """ @@ -90,8 +206,9 @@ end Cast a top-level `PassManager` to a generic `OpPassManager`. """ -OpPassManager(pm::PassManager) = - OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +function OpPassManager(pm::PassManager) + return OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +end """ OpPassManager(passManager, operationName) @@ -99,26 +216,45 @@ OpPassManager(pm::PassManager) = Nest an `OpPassManager` under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. To further nest more `OpPassManager` under the newly returned one, see `mlirOpPassManagerNest` below. """ -OpPassManager(pm::PassManager, opname) = - OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +function OpPassManager(pm::PassManager, opname) + return OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +end """ OpPassManager(opPassManager, operationName) Nest an `OpPassManager` under the provided `OpPassManager`, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. """ -OpPassManager(opm::OpPassManager, opname) = - OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) +function OpPassManager(opm::OpPassManager, opname) + return OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.ref) +end -Base.convert(::Core.Type{API.MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass +Base.cconvert(::Core.Type{API.MlirOpPassManager}, op_pass::OpPassManager) = op_pass +Base.unsafe_convert(::Core.Type{API.MlirOpPassManager}, opm::OpPassManager) = opm.ref -function Base.show(io::IO, op_pass::OpPassManager) +""" + pass_pipeline(opPassManager) -> String + +Returns the pass pipeline. +""" +pass_pipeline(op_pass::OpPassManager) = sprint(print_pass_pipeline, op_pass) + +""" + print_pass_pipeline(io::IO, opPassManager) + +Prints the pass pipeline to the IO. +""" +function print_pass_pipeline(io::IO, op_pass::OpPassManager) c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) ref = Ref(io) - println(io, "OpPassManager(\"\"\"") API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) - println(io) - return print(io, "\"\"\")") + return io +end + +function Base.show(io::IO, op_pass::OpPassManager) + println(io, "OpPassManager(\"\"\"") + print_pass_pipeline(io, op_pass) + return print(io, "\n\"\"\")") end struct AddPipelineException <: Exception @@ -130,16 +266,6 @@ function Base.showerror(io::IO, err::AddPipelineException) return nothing end -""" - add_owned_pass!(passManager, pass) - -Add a pass and transfer ownership to the provided top-level `PassManager`. If the pass is not a generic operation pass or a `ModulePass`, a new `OpPassManager` is implicitly nested under the provided PassManager. -""" -function add_owned_pass!(pm::PassManager, pass) - API.mlirPassManagerAddOwnedPass(pm, pass) - return pm -end - """ add_owned_pass!(opPassManager, pass) @@ -151,19 +277,15 @@ function add_owned_pass!(opm::OpPassManager, pass) end """ - parse(passManager, pipeline) + parse(opPassManager, pipeline) Parse a textual MLIR pass pipeline and add it to the provided `OpPassManager`. """ function Base.parse(opm::OpPassManager, pipeline::String) + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) result = LogicalResult( - if MLIR_VERSION[] >= v"16" - io = IOBuffer() - c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) - API.mlirParsePassPipeline(opm, pipeline, c_print_callback, Ref(io)) - else - API.mlirParsePassPipeline(opm, pipeline) - end, + API.mlirParsePassPipeline(opm, pipeline, c_print_callback, Ref(io)) ) if isfailure(result) @@ -173,12 +295,12 @@ function Base.parse(opm::OpPassManager, pipeline::String) end """ - add_pipeline!(passManager, pipelineElements, callback, userData) + add_pipeline!(opPassManager, pipeline) Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ function add_pipeline!(op_pass::OpPassManager, pipeline) - if MLIR_VERSION[] >= v"16" + @static if isdefined(API, :mlirOpPassManagerAddPipeline) io = IOBuffer() c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) result = LogicalResult( @@ -201,25 +323,26 @@ end # AbstractPass interface: opname(::AbstractPass) = "" -function pass_run(::Context, ::P, op) where {P<:AbstractPass} +function pass_run(::Context, ::P, _) where {P<:AbstractPass} return error("pass $P does not implement `MLIR.pass_run`") end -function _pass_construct(ptr::ExternalPassHandle) +function _pass_construct(::ExternalPassHandle) return nothing end -function _pass_destruct(ptr::ExternalPassHandle) +function _pass_destruct(::ExternalPassHandle) return nothing end function _pass_initialize(ctx, handle::ExternalPassHandle) - try + (; ref) = try handle.ctx = Context(ctx) success() catch failure() end + return ref end function _pass_clone(handle::ExternalPassHandle) @@ -227,7 +350,7 @@ function _pass_clone(handle::ExternalPassHandle) end function _pass_run(rawop, external_pass, handle::ExternalPassHandle) - op = Operation(rawop, false) + op = Operation(rawop) try pass_run(handle.ctx, handle.pass, op) catch ex @@ -238,8 +361,9 @@ function _pass_run(rawop, external_pass, handle::ExternalPassHandle) end function create_external_pass!(oppass::OpPassManager, args...) - return create_external_pass!(oppass.pass, args...) + return create_external_pass!(oppass.ref, args...) end + function create_external_pass!( manager, pass, @@ -250,8 +374,6 @@ function create_external_pass!( dependent_dialects=API.MlirDialectHandle[], ) passid = TypeID(manager.allocator) - MLIR_VERSION[] >= v"15" || - throw(MLIRException("`create_external_pass!` requires MLIR version 15 or later")) callbacks = API.MlirExternalPassCallbacks( @cfunction(_pass_construct, Cvoid, (Any,)), @cfunction(_pass_destruct, Cvoid, (Any,)), diff --git a/src/IR/Region.jl b/src/IR/Region.jl index 38d38c8f..f0491928 100644 --- a/src/IR/Region.jl +++ b/src/IR/Region.jl @@ -1,15 +1,5 @@ -mutable struct Region - region::API.MlirRegion - @atomic owned::Bool - - function Region(region, owned=true) - @assert !mlirIsNull(region) - finalizer(new(region, owned)) do region - if region.owned - API.mlirRegionDestroy(region.region) - end - end - end +@checked struct Region + ref::API.MlirRegion end """ @@ -17,10 +7,18 @@ end Creates a new empty region and transfers ownership to the caller. """ -Region() = Region(API.mlirRegionCreate()) +Region() = Region(mark_alloc(API.mlirRegionCreate())) + +""" + dispose(region::Region) + +Disposes the given region and releases its resources. +After calling this function, the region must not be used anymore. +""" +dispose(region::Region) = mark_dispose(API.mlirRegionDestroy, region) Base.cconvert(::Core.Type{API.MlirRegion}, region::Region) = region -Base.unsafe_convert(::Core.Type{API.MlirRegion}, region::Region) = region.region +Base.unsafe_convert(::Core.Type{API.MlirRegion}, region::Region) = mark_use(region).ref """ ==(region, other) @@ -29,13 +27,42 @@ Checks whether two region handles point to the same region. This does not perfor """ Base.:(==)(a::Region, b::Region) = API.mlirRegionEqual(a, b) +Base.IteratorSize(::Core.Type{Region}) = Base.SizeUnknown() +Base.IteratorEltype(::Core.Type{Region}) = Base.HasEltype() +Base.eltype(::Region) = Block + +""" + Base.iterate(region::Region) + +Iterates over all [`Block`](@ref) in the given region. +""" +function Base.iterate(it::Region) + raw_block = API.mlirRegionGetFirstBlock(it) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block) + (b, b) + end +end + +function Base.iterate(::Region, block) + raw_block = API.mlirBlockGetNextInRegion(block) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block) + (b, b) + end +end + """ push!(region, block) Takes a block owned by the caller and appends it to the given region. """ function Base.push!(region::Region, block::Block) - API.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) + API.mlirRegionAppendOwnedBlock(region, mark_donate(block)) return block end @@ -45,7 +72,7 @@ end Takes a block owned by the caller and inserts it at `index` to the given region. This is an expensive operation that linearly scans the region, prefer insertAfter/Before instead. """ function Base.insert!(region::Region, index, block::Block) - API.mlirRegionInsertOwnedBlock(region, index - 1, lose_ownership!(block)) + API.mlirRegionInsertOwnedBlock(region, index - 1, mark_donate(block)) return block end @@ -59,16 +86,18 @@ end Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, prepends the block to the region. """ -insert_after!(region::Region, reference::Block, block::Block) = - API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) +function insert_after!(region::Region, reference::Block, block::Block) + return API.mlirRegionInsertOwnedBlockAfter(region, reference, mark_donate(block)) +end """ insert_before!(region, reference, block) Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, appends the block to the region. """ -insert_before!(region::Region, reference::Block, block::Block) = - API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) +function insert_before!(region::Region, reference::Block, block::Block) + return API.mlirRegionInsertOwnedBlockBefore(region, reference, mark_donate(block)) +end """ first_block(region) @@ -78,12 +107,6 @@ Gets the first block in the region. function first_block(region::Region) block = API.mlirRegionGetFirstBlock(region) mlirIsNull(block) && return nothing - return Block(block, false) + return Block(block) end Base.first(region::Region) = first_block(region) - -function lose_ownership!(region::Region) - @assert region.owned - @atomic region.owned = false - return region -end diff --git a/src/IR/SymbolTable.jl b/src/IR/SymbolTable.jl index 3b5d2073..8c96bbe8 100644 --- a/src/IR/SymbolTable.jl +++ b/src/IR/SymbolTable.jl @@ -1,23 +1,22 @@ -mutable struct SymbolTable - st::API.MlirSymbolTable - - function SymbolTable(st) - @assert !mlirIsNull(st) "cannot create SymbolTable with null MlirSymbolTable" - return finalizer(API.mlirSymbolTableDestroy, new(st)) - end +@checked struct SymbolTable + ref::API.MlirSymbolTable end """ - mlirSymbolTableCreate(operation) + SymbolTable(operation) Creates a symbol table for the given operation. If the operation does not have the SymbolTable trait, returns a null symbol table. """ -SymbolTable(op::Operation) = SymbolTable(API.mlirSymbolTableCreate(op)) +SymbolTable(op::Operation) = SymbolTable(mark_alloc(API.mlirSymbolTableCreate(op))) +SymbolTable(mod::Module) = SymbolTable(Operation(mod)) + +dispose(st::SymbolTable) = mark_dispose(API.mlirSymbolTableDestroy(st)) -Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st +Base.cconvert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st +Base.unsafe_convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = mark_use(st).ref -# TODO mlirSymbolTableGetSymbolAttributeName -# TODO mlirSymbolTableGetVisibilityAttributeName +# TODO(#2246) mlirSymbolTableGetSymbolAttributeName +# TODO(#2246) mlirSymbolTableGetVisibilityAttributeName """ lookup(symboltable, name) @@ -25,12 +24,20 @@ Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol. If the symbol cannot be found, returns a null operation. """ -lookup(st::SymbolTable, name::AbstractString) = - Operation(API.mlirSymbolTableLookup(st, name)) -Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name) +function lookup(st::SymbolTable, name::AbstractString) + raw_op = API.mlirSymbolTableLookup(st, name) + if raw_op.ptr == C_NULL + nothing + else + Operation(raw_op) + end +end +function Base.getindex(st::SymbolTable, name::AbstractString) + @something(lookup(st, name), throw(KeyError(name))) +end """ - push!(symboltable, operation) + Base.push!(symboltable, operation) Inserts the given operation into the given symbol table. The operation must have the symbol trait. If the symbol table already has a symbol with the same name, renames the symbol being inserted to ensure name uniqueness. @@ -40,11 +47,11 @@ Returns the name of the symbol after insertion. Base.push!(st::SymbolTable, op::Operation) = Attribute(API.mlirSymbolTableInsert(st, op)) """ - delete!(symboltable, operation) + Base.delete!(symboltable, operation) Removes the given operation from the symbol table and erases it. """ -delete!(st::SymbolTable, op::Operation) = API.mlirSymbolTableErase(st, op) +Base.delete!(st::SymbolTable, op::Operation) = API.mlirSymbolTableErase(st, op) -# TODO mlirSymbolTableReplaceAllSymbolUses -# TODO mlirSymbolTableWalkSymbolTables +# TODO(#2246) mlirSymbolTableReplaceAllSymbolUses +# TODO(#2246) mlirSymbolTableWalkSymbolTables diff --git a/src/IR/Type.jl b/src/IR/Type.jl index a62c2925..9bb244e7 100644 --- a/src/IR/Type.jl +++ b/src/IR/Type.jl @@ -1,21 +1,18 @@ -struct Type - type::API.MlirType - - function Type(type) - @assert !mlirIsNull(type) "cannot create Type with null MlirType" - return new(type) - end +@checked struct Type + ref::API.MlirType end -Base.convert(::Core.Type{API.MlirType}, type::Type) = type.type +Base.cconvert(::Core.Type{API.MlirType}, type::Type) = type +Base.unsafe_convert(::Core.Type{API.MlirType}, type::Type) = type.ref """ - parse(type; context=context()) + Base.parse(type; context=current_context()) Parses a type. The type is owned by the context. """ -Base.parse(::Core.Type{Type}, s; context::Context=context()) = - Type(API.mlirTypeParseGet(context, s)) +function Base.parse(::Core.Type{Type}, s; context::Context=current_context()) + return Type(API.mlirTypeParseGet(context, s)) +end """ ==(t1, t2) @@ -40,11 +37,12 @@ typeid(type::Type) = TypeID(API.mlirTypeGetTypeID(type)) # None type """ - Type(::Core.Type{Nothing}; context=context()) + Type(::Core.Type{Nothing}; context=current_context()) Creates a None type in the given context. The type is owned by the context. """ -Type(::Core.Type{Nothing}; context::Context=context()) = Type(API.mlirNoneTypeGet(context)) +Type(::Core.Type{Nothing}; context::Context=current_context()) = + Type(API.mlirNoneTypeGet(context)) """ mlirTypeIsANone(type) @@ -55,11 +53,11 @@ isnone(type::Type) = API.mlirTypeIsANone(type) # Index type """ - IndexType(; context=context()) + IndexType(; context=current_context()) Creates an index type in the given context. The type is owned by the context. """ -IndexType(; context::Context=context()) = Type(API.mlirIndexTypeGet(context)) +IndexType(; context::Context=current_context()) = Type(API.mlirIndexTypeGet(context)) """ isindex(type) @@ -69,37 +67,41 @@ Checks whether the given type is an index type. isindex(type::Type) = API.mlirTypeIsAIndex(type) """ - Type(T::Core.Type{Bool}; context=context() + Type(T::Core.Type{Bool}; context=current_context() Creates a 1-bit signless integer type in the context. The type is owned by the context. """ -Type(::Core.Type{Bool}; context::Context=context()) = - Type(API.mlirIntegerTypeGet(context, 1)) +function Type(::Core.Type{Bool}; context::Context=current_context()) + return Type(API.mlirIntegerTypeGet(context, 1)) +end # Integer types """ - Type(T::Core.Type{<:Integer}; context=context() + Type(T::Core.Type{<:Integer}; context=current_context() Creates a signless integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Integer}; context::Context=context()) = - Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +function Type(T::Core.Type{<:Integer}; context::Context=current_context()) + return Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +end """ - Type(T::Core.Type{<:Signed}; context=context() + Type(T::Core.Type{<:Signed}; context=current_context() Creates a signed integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Signed}; context::Context=context()) = - Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +function Type(T::Core.Type{<:Signed}; context::Context=current_context()) + return Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +end """ - Type(T::Core.Type{<:Unsigned}; context=context() + Type(T::Core.Type{<:Unsigned}; context=current_context() Creates an unsigned integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Unsigned}; context::Context=context()) = - Type(API.mlirIntegerTypeUnsignedGet(context, sizeof(T) * 8)) +function Type(T::Core.Type{<:Unsigned}; context::Context=current_context()) + return Type(API.mlirIntegerTypeUnsignedGet(context, sizeof(T) * 8)) +end """ isinteger(type) @@ -141,76 +143,61 @@ end # Floating point types """ - Float8E5M2(; context=context()) + Float8E5M2(; context=current_context()) Creates an f8E5M2 type in the given context. The type is owned by the context. """ -function Float8E5M2(; context::Context=context()) - MLIR_VERSION[] >= v"16" || - throw(MLIRException("`Float8E5M2()` requires MLIR version 16 or later")) - return Type(API.mlirFloat8E5M2TypeGet(context)) -end +Float8E5M2(; context::Context=current_context()) = Type(API.mlirFloat8E5M2TypeGet(context)) """ - Float8E4M3FN(; context=context()) + Float8E4M3FN(; context=current_context()) Creates an f8E4M3FN type in the given context. The type is owned by the context. """ -function Float8E4M3FN(; context::Context=context()) - MLIR_VERSION[] >= v"16" || - throw(MLIRException("`Float8E4M3FN()` requires MLIR version 16 or later")) - return Type(API.mlirFloat8E4M3FNTypeGet(context)) -end +Float8E4M3FN(; context::Context=current_context()) = + Type(API.mlirFloat8E4M3FNTypeGet(context)) """ -BFloat16Type(; context=context()) +BFloat16Type(; context=current_context()) Creates a bf16 type in the given context. The type is owned by the context. """ -BFloat16Type(; context::Context=context()) = Type(API.mlirBF16TypeGet(context)) +BFloat16Type(; context::Context=current_context()) = Type(API.mlirBF16TypeGet(context)) """ - Type(::Core.Type{Float16}; context=context()) + Type(::Core.Type{Float16}; context=current_context()) Creates an f16 type in the given context. The type is owned by the context. """ -Type(::Core.Type{Float16}; context::Context=context()) = Type(API.mlirF16TypeGet(context)) - -""" - Type(Core.Type{Float32}; context=context()) - -Creates an f32 type in the given context. The type is owned by the context. -""" -Type(::Core.Type{Float32}; context::Context=context()) = Type(API.mlirF32TypeGet(context)) +Type(::Core.Type{Float16}; context::Context=current_context()) = + Type(API.mlirF16TypeGet(context)) -""" - Type(Core.Type{Float64}; context=context()) +if isdefined(Core, :BFloat16) + """ + Type(::Core.Type{Core.BFloat16}; context=current_context()) -Creates a f64 type in the given context. The type is owned by the context. -""" -Type(::Core.Type{Float64}; context::Context=context()) = Type(API.mlirF64TypeGet(context)) + Creates an bf16 type in the given context. The type is owned by the context. + """ + function Type(::Core.Type{Core.BFloat16}; context::Context=current_context()) + return BFloat16Type(; context) + end +end """ - isf8e5m2(type) + Type(Core.Type{Float32}; context=current_context()) -Checks whether the given type is an f8E5M2 type. +Creates an f32 type in the given context. The type is owned by the context. """ -function isf8e5m2(type::Type) - MLIR_VERSION[] >= v"16" || - throw(MLIRException("`isf8e5m2()` requires MLIR version 16 or later")) - return API.mlirTypeIsAFloat8E5M2(type) -end +Type(::Core.Type{Float32}; context::Context=current_context()) = + Type(API.mlirF32TypeGet(context)) """ - isf8e4m3fn(type) + Type(Core.Type{Float64}; context=current_context()) -Checks whether the given type is an f8E4M3FN type. +Creates a f64 type in the given context. The type is owned by the context. """ -function isf8e4m3fn(type::Type) - MLIR_VERSION[] >= v"16" || - throw(MLIRException("`isf8e4m3fn()` requires MLIR version 16 or later")) - return API.mlirTypeIsAFloat8E4M3FN(type) -end +Type(::Core.Type{Float64}; context::Context=current_context()) = + Type(API.mlirF64TypeGet(context)) """ isbf16(type) @@ -365,14 +352,17 @@ isvector(type::Type) = API.mlirTypeIsAVector(type) TensorType(shape, elementType, encoding=Attribute(); location=Location(), check=false) Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in the same context as the element type. -The type is owned by the context. Tensor types without any specific encoding field should assign [`mlirAttributeGetNull`](@ref) to this parameter. +The type is owned by the context. Tensor types without any specific encoding field should assign [`MLIR.API.mlirAttributeGetNull`](@ref) to this parameter. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function TensorType( - shape, elem_type, encoding=Attribute(); location::Location=Location(), check::Bool=false +Base.@nospecializeinfer function TensorType( + shape::Vector{Int}, + @nospecialize(elem_type::Type), + encoding=Attribute(); + location::Location=Location(), + check::Bool=false, ) rank = length(shape) - shape = shape isa AbstractVector ? shape : collect(shape) return Type( if check API.mlirRankedTensorTypeGetChecked(location, rank, shape, elem_type, encoding) @@ -388,7 +378,7 @@ end Creates an unranked tensor type with the given element type in the same context as the element type. The type is owned by the context. If `check=true`, emits appropriate diagnostics on illegal arguments. """ -function TensorType(elem_type; location::Location=Location(), check::Bool=false) +function TensorType(elem_type::Type; location::Location=Location(), check::Bool=false) return Type( if check API.mlirUnrankedTensorTypeGetChecked(location, elem_type) @@ -398,7 +388,7 @@ function TensorType(elem_type; location::Location=Location(), check::Bool=false) ) end -# TODO maybe add these helper methods? +# TODO(#2245) maybe add these helper methods? # Type(a::AbstractArray{T}) where {T} = Type(Type(T), size(a)) # Type(::Core.Type{<:AbstractArray{T,N}}, dims) where {T,N} = # Type(API.mlirRankedTensorTypeGetChecked( @@ -465,15 +455,11 @@ function MemRefType( if check Type( API.mlirMemRefTypeGetChecked( - location, elem_type, length(shape), pointer(shape), layout, memspace + location, elem_type, length(shape), shape, layout, memspace ), ) else - Type( - API.mlirMemRefTypeGet( - elem_type, length(shape), pointer(shape), layout, memspace - ), - ) + Type(API.mlirMemRefTypeGet(elem_type, length(shape), shape, layout, memspace)) end end @@ -490,15 +476,11 @@ function MemRefType( if check Type( API.mlirMemRefTypeContiguousGetChecked( - location, elem_type, length(shape), pointer(shape), memspace + location, elem_type, length(shape), shape, memspace ), ) else - Type( - API.mlirMemRefTypeContiguousGet( - elem_type, length(shape), pointer(shape), memspace - ), - ) + Type(API.mlirMemRefTypeContiguousGet(elem_type, length(shape), shape, memspace)) end end @@ -551,7 +533,7 @@ Returns the affine map of the given MemRef type. """ function affinemap(type::Type) @assert ismemref(type) "expected a MemRef type" - return AffineMap(API.mlirMemRefTypeGetAffineMaps(type)) + return AffineMap(API.mlirMemRefTypeGetAffineMap(type)) end """ @@ -570,17 +552,19 @@ end # Tuple type """ - Type(elements; context=context()) - Type(::Core.Type{<:Tuple{T...}}; context=context()) + Type(elements; context=current_context()) + Type(::Core.Type{<:Tuple{T...}}; context=current_context()) Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ -Type(elements::Vector{Type}; context::Context=context()) = - Type(API.mlirTupleTypeGet(context, length(elements), pointer(elements))) -function Type(@nospecialize(elements::NTuple{N,Type}); context::Context=context()) where {N} +Type(elements::Vector{Type}; context::Context=current_context()) = + Type(API.mlirTupleTypeGet(context, length(elements), elements)) +function Type( + @nospecialize(elements::NTuple{N,Type}); context::Context=current_context() +) where {N} return Type(collect(elements); context) end -function Type(T::Core.Type{<:Tuple}; context::Context=context()) +function Type(T::Core.Type{<:Tuple}; context::Context=current_context()) return Type(map(Type, T.parameters); context) end @@ -600,19 +584,17 @@ Checks whether the given type is a function type. isfunction(type::Type) = API.mlirTypeIsAFunction(type) """ - FunctionType(inputs, results; context=context()) + FunctionType(inputs, results; context=current_context()) Creates a function type, mapping a list of input types to result types. """ -function FunctionType(inputs, results; context::Context=context()) +function FunctionType(inputs, results; context::Context=current_context()) return Type( - API.mlirFunctionTypeGet( - context, length(inputs), pointer(inputs), length(results), pointer(results) - ), + API.mlirFunctionTypeGet(context, length(inputs), inputs, length(results), results) ) end -# TODO maybe add this helper method? +# TODO(#2245) maybe add this helper method? # Type(ft::Pair) = Type(API.mlirFunctionTypeGet(context(), # length(ft.first), [Type(t) for t in ft.first], # length(ft.second), [Type(t) for t in ft.second])) @@ -659,11 +641,11 @@ end # Opaque type """ - OpaqueType(dialectNamespace, typeData; context=context()) + OpaqueType(dialectNamespace, typeData; context=current_context()) Creates an opaque type in the given context associated with the dialect identified by its namespace. The type contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueType(namespace, data; context::Context=context()) = +OpaqueType(namespace, data; context::Context=current_context()) = Type(API.mlirOpaqueTypeGet(context, namespace, data)) """ @@ -747,6 +729,8 @@ function julia_type(type::Type) throw("could not convert unsigned $width-bit integer type to julia") end end + elseif isbf16(type) + Core.BFloat16 elseif isf16(type) Float16 elseif isf32(type) diff --git a/src/IR/TypeID.jl b/src/IR/TypeID.jl index 89e8745a..c7413475 100644 --- a/src/IR/TypeID.jl +++ b/src/IR/TypeID.jl @@ -1,24 +1,13 @@ -struct TypeID - typeid::API.MlirTypeID - - function TypeID(typeid) - @assert !mlirIsNull(typeid) "cannot create TypeID with null MlirTypeID" - return new(typeid) - end +@checked struct TypeID + ref::API.MlirTypeID end TypeID(type::Type) = TypeID(API.mlirTypeGetTypeID(type)) # mlirTypeIDCreate -""" - hash(typeID) - -Returns the hash value of the type id. -""" -Base.hash(typeid::TypeID) = API.mlirTypeIDHashValue(typeid.typeid) - -Base.convert(::Core.Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid +Base.cconvert(::Core.Type{API.MlirTypeID}, typeid::TypeID) = typeid +Base.unsafe_convert(::Core.Type{API.MlirTypeID}, typeid::TypeID) = typeid.ref """ ==(typeID1, typeID2) @@ -27,21 +16,26 @@ Checks if two type ids are equal. """ Base.:(==)(a::TypeID, b::TypeID) = API.mlirTypeIDEqual(a, b) -mutable struct TypeIDAllocator - allocator::API.MlirTypeIDAllocator +""" + hash(typeID) + +Returns the hash value of the type id. +""" +Base.hash(typeid::TypeID) = API.mlirTypeIDHashValue(typeid) - function TypeIDAllocator() - MLIR_VERSION[] >= v"15" || - throw(MLIRException("`TypeIDAllocator` requires MLIR version 15 or later")) - ptr = API.mlirTypeIDAllocatorCreate() - @assert ptr != C_NULL "cannot create TypeIDAllocator" - return finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) - end +@checked struct TypeIDAllocator + ref::API.MlirTypeIDAllocator end -Base.cconvert(::Core.Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator -Base.unsafe_convert(::Core.Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator +TypeIDAllocator() = TypeIDAllocator(mark_alloc(API.mlirTypeIDAllocatorCreate())) + +dispose(alloc::TypeIDAllocator) = mark_dispose(API.mlirTypeIDAllocatorDestroy(alloc)) + +Base.cconvert(::Core.Type{API.MlirTypeIDAllocator}, alloc::TypeIDAllocator) = alloc +function Base.unsafe_convert(::Core.Type{API.MlirTypeIDAllocator}, alloc::TypeIDAllocator) + return mark_use(alloc).ref +end -function TypeID(allocator::TypeIDAllocator) - return TypeID(API.mlirTypeIDAllocatorAllocateTypeID(allocator)) +function TypeID(alloc::TypeIDAllocator) + return TypeID(mark_alloc(API.mlirTypeIDAllocatorAllocateTypeID(alloc))) end diff --git a/src/IR/Utils.jl b/src/IR/Utils.jl new file mode 100644 index 00000000..cdbe2399 --- /dev/null +++ b/src/IR/Utils.jl @@ -0,0 +1,112 @@ +function mlirIsNull(val) + return val.ptr == C_NULL +end + +function print_callback(str::API.MlirStringRef, userdata) + data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) + write(userdata isa Base.RefValue ? userdata[] : userdata, data) + return Cvoid() +end + +# MlirStringRef is a non-owning reference to a string, +# we thus need to ensure that the Julia string remains alive +# over the use. For that we use the cconvert/unsafe_convert mechanism +# for foreign-calls. The returned value of the cconvert is rooted across +# foreign-call. +Base.cconvert(::Core.Type{API.MlirStringRef}, s::Union{Symbol,String}) = s +function Base.cconvert(::Core.Type{API.MlirStringRef}, s::AbstractString) + return Base.cconvert(API.MlirStringRef, String(s)::String) +end + +# Directly create `MlirStringRef` instead of adding an extra ccall. +function Base.unsafe_convert( + ::Core.Type{API.MlirStringRef}, s::Union{Symbol,String,AbstractVector{UInt8}} +) + p = Base.unsafe_convert(Ptr{Cchar}, s) + return API.MlirStringRef(p, sizeof(s)) +end + +function Base.String(str::API.MlirStringRef) + return Base.unsafe_string(pointer(str.data), str.length) +end + +Base.String(str::API.MlirIdentifier) = String(API.mlirIdentifierStr(str)) + +function visit(f, op) + all_ok = true + for region in op + for block in region + for op in block + all_ok &= f(op) + end + end + end + return all_ok +end + +""" + @dispose foo=Foo() bar=Bar() begin + ... + end + +Helper macro for disposing resources (by calling the `dispose` function for every resource +in reverse order) after executing a block of code. This is often equivalent to calling the +recourse constructor with do-block syntax, but without using (potentially costly) closures. +""" +macro dispose(ex...) + resources = ex[1:(end - 1)] + code = ex[end] + + Meta.isexpr(code, :block) || + error("Expected a code block as final argument to LLVM.@dispose") + + cleanup = quote end + for res in reverse(resources) + Meta.isexpr(res, :(=)) || + error("Resource arguments to LLVM.@dispose should be assignments") + push!(cleanup.args, :($dispose($(res.args[1])))) + end + + ex = quote + let $(resources...) + try + $code + finally + $(cleanup.args...) + end + end + end + return esc(ex) +end + +# TODO potentially move to `ScopedValues.@with` if we move from task-local storage to ScopedValues +""" + @scope obj begin + body + end + +Activates `obj` for the duration of `body`, then deactivates it. +""" +macro scope(obj, body) + bodybody = if Base.isexpr(body, :block) + body.args + else + [body] + end + if Base.isexpr(obj, :(=)) + prologue = esc(obj) + symbol = obj.args[1] + else + prologue = nothing + symbol = esc(obj) + end + quote + $prologue + activate($symbol) + try + $(esc.(bodybody)...) + finally + deactivate($symbol) + end + end +end diff --git a/src/IR/Value.jl b/src/IR/Value.jl index 6538447c..ce4b0f1b 100644 --- a/src/IR/Value.jl +++ b/src/IR/Value.jl @@ -1,13 +1,10 @@ -struct Value - value::API.MlirValue - - function Value(value) - @assert !mlirIsNull(value) "cannot create Value with null MlirValue" - return new(value) - end +@checked struct Value + ref::API.MlirValue end -Base.convert(::Core.Type{API.MlirValue}, value::Value) = value.value +Base.cconvert(::Core.Type{API.MlirValue}, value::Value) = value +Base.unsafe_convert(::Core.Type{API.MlirValue}, value::Value) = value.ref + Base.size(value::Value) = Base.size(type(value)) Base.ndims(value::Value) = Base.ndims(type(value)) @@ -39,7 +36,7 @@ Returns the block in which this value is defined as an argument. Asserts if the """ function block_owner(value::Value) @assert is_block_arg(value) "could not get owner, value is not a block argument" - return Block(API.mlirBlockArgumentGetOwner(value), false) + return Block(API.mlirBlockArgumentGetOwner(value)) end """ @@ -49,18 +46,18 @@ Returns an operation that produced this value as its result. Asserts if the valu """ function op_owner(value::Value) @assert is_op_res(value) "could not get owner, value is not an op result" - return Operation(API.mlirOpResultGetOwner(value), false) + return Operation(API.mlirOpResultGetOwner(value)) end function owner(value::Value) if is_block_arg(value) raw_block = API.mlirBlockArgumentGetOwner(value) mlirIsNull(raw_block) && return nothing - return Block(raw_block, false) + return Block(raw_block) elseif is_op_res(value) raw_op = API.mlirOpResultGetOwner(value) mlirIsNull(raw_op) && return nothing - return Operation(raw_op, false) + return Operation(raw_op) else error("Value is neither a block argument nor an op result") end @@ -104,12 +101,12 @@ Returns the type of the value. type(value::Value) = Type(API.mlirValueGetType(value)) """ - set_type!(value, type) + settype!(value, type) Sets the type of the block argument to the given type. """ -function type!(value, type) - @assert is_a_block_argument(value) "could not set type, value is not a block argument" +function settype!(value, type) + @assert is_block_arg(value) "could not set type, value is not a block argument" API.mlirBlockArgumentSetType(value, type) return value end @@ -117,5 +114,7 @@ end function Base.show(io::IO, value::Value) c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) ref = Ref(io) - return API.mlirValuePrint(value, c_print_callback, ref) + GC.@preserve value ref begin + API.mlirValuePrint(value, c_print_callback, ref) + end end diff --git a/src/IR/debug.jl b/src/IR/debug.jl new file mode 100644 index 00000000..57af2401 --- /dev/null +++ b/src/IR/debug.jl @@ -0,0 +1,205 @@ +# inspired by LLVM.jl's memory tracking utilities, but with an extra `mark_donate` function to track when objects are given another object as owner +using Preferences + +const MEMCHECK_ENABLED = parse(Bool, @load_preference("memcheck", "false")) + +# object => (alloc_bt, dispose_bt) +const tracked_objects = Dict{Any,Any}() + +# the most basic check is asserting that we don't use a null pointer +@inline function refcheck(::Type, ref::Ptr) + return ref == C_NULL && throw(UndefRefError()) +end + +function mark_alloc(obj; allow_overwrite::Bool=false) + @static if MEMCHECK_ENABLED + io = Core.stdout + new_alloc_bt = backtrace()[2:end] + + if haskey(tracked_objects, obj) && !allow_overwrite + old_alloc_bt, dispose_bt = tracked_objects[obj] + if dispose_bt == nothing + print( + "\nWARNING: An instance of $(typeof(obj)) was not properly disposed of, and a new allocation will overwrite it.", + ) + print("\nThe original allocation was at:") + Base.show_backtrace(io, old_alloc_bt) + print("\nThe new allocation is at:") + Base.show_backtrace(io, new_alloc_bt) + println(io) + end + end + + tracked_objects[obj] = (new_alloc_bt, nothing) + end + return obj +end + +function mark_use(obj::Any) + @static if MEMCHECK_ENABLED + io = Core.stdout + + if !haskey(tracked_objects, obj) + # we have to ignore unknown objects, as they may originate externally. + # for example, a Julia-created Type we call `context` on. + return obj + end + + alloc_bt, dispose_bt = tracked_objects[obj] + if dispose_bt !== nothing + print( + "\nWARNING: An instance of $(typeof(obj)) is being used after it was disposed of.", + ) + print("\nThe object was allocated at:") + Base.show_backtrace(io, alloc_bt) + print("\nThe object was disposed of at:") + Base.show_backtrace(io, dispose_bt) + print("\nThe object is being used at:") + Base.show_backtrace(io, backtrace()[2:end]) + println(io) + end + end + return obj +end + +mark_dispose(obj) = mark_dispose(Returns(nothing), obj) + +function mark_dispose(f, obj) + data = @static if MEMCHECK_ENABLED + io = Core.stdout + new_dispose_bt = backtrace()[2:end] + + if !haskey(tracked_objects, obj) + print( + io, "\nWARNING: An unknown instance of $(typeof(obj)) is being disposed of." + ) + Base.show_backtrace(io, new_dispose_bt) + nothing + else + alloc_bt, old_dispose_bt = tracked_objects[obj] + if old_dispose_bt !== nothing + print( + "\nWARNING: An instance of $(typeof(obj)) is being disposed of twice." + ) + print("\nThe object was allocated at:") + Base.show_backtrace(io, alloc_bt) + print("\nThe object was already disposed of at:") + Base.show_backtrace(io, old_dispose_bt) + print("\nThe object is being disposed of again at:") + Base.show_backtrace(io, new_dispose_bt) + println(io) + end + + (alloc_bt, new_dispose_bt) + end + end + ret = f(obj) + @static if MEMCHECK_ENABLED + if data !== nothing + tracked_objects[obj] = data + end + end + return nothing +end + +# we could potentially track ownership here +mark_donate(new_owner, obj) = mark_dispose(obj) + +# MLIR.API types +for AT in [ + :MlirDialect, + :MlirDialectHandle, + :MlirDialectRegistry, + :MlirContext, + :MlirLocation, + :MlirType, + :MlirTypeID, + :MlirTypeIDAllocator, + :MlirModule, + :MlirOperation, + :MlirOpOperand, + :MlirBlock, + :MlirRegion, + :MlirValue, + # :MlirLogicalResult, + :MlirAffineExpr, + :MlirAffineMap, + # :MlirAttribute, + # :MlirNamedAttribute, + :MlirIntegerSet, + :MlirIdentifier, + :MlirSymbolTable, + :MlirExecutionEngine, + :MlirPassManager, + :MlirOpPassManager, +] + @eval refcheck(T::Core.Type, ref::API.$AT) = refcheck(T, ref.ptr) +end + +function report_leaks(code=0) + # if we errored, we can't trust the memory state + if code != 0 + return nothing + end + + @static if MEMCHECK_ENABLED + io = Core.stdout + for (obj, (alloc_bt, dispose_bt)) in tracked_objects + if dispose_bt === nothing + print( + io, + "\nWARNING: An instance of $(typeof(obj)) was not properly disposed of.", + ) + print("\nThe object was allocated at:") + Base.show_backtrace(io, alloc_bt) + println(io) + end + end + end +end + +# macro that adds an inner constructor to a type definition, +# calling `refcheck` on the ref field argument +macro checked(typedef) + # decode structure definition + if Meta.isexpr(typedef, :struct) + structure = typedef.args[2] + body = typedef.args[3] + else + error("argument is not a structure definition") + end + if isa(structure, Symbol) + # basic type definition + typename = structure + elseif Meta.isexpr(structure, :<:) + # typename <: parentname + all(e -> isa(e, Symbol), structure.args) || + error("typedef should consist of plain types, ie. not parametric ones") + typename = structure.args[1] + else + error("malformed type definition: cannot decode type name") + end + + # decode fields + field_names = Symbol[] + field_defs = Union{Symbol,Expr}[] + for arg in body.args + if isa(arg, LineNumberNode) + continue + elseif isa(arg, Symbol) + push!(field_names, arg) + push!(field_defs, arg) + elseif Meta.isexpr(arg, :(::)) + push!(field_names, arg.args[1]) + push!(field_defs, arg) + end + end + :ref in field_names || error("structure definition should contain 'ref' field") + + # insert checked constructor + push!(body.args, :(function $typename($(field_defs...)) + return ($refcheck($typename, ref); new($(field_names...))) + end)) + + return esc(typedef) +end