Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions src/Dialects/Utils.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,4 @@
import ..IR: IR, Attribute, NamedAttribute, context
import ..API
import ..IR: NamedAttribute

namedattribute(name, val) = namedattribute(name, Attribute(val))
namedattribute(name, val::Attribute) = NamedAttribute(name, val)
function namedattribute(name, val::NamedAttribute)
@assert true # TODO(jm): check whether name of attribute is correct, getting the name might need to be added to IR.jl?
return val
end

function operandsegmentsizes(segments)
return namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))
end
operandsegmentsizes(segments) = NamedAttribute("operand_segment_sizes", Int32.(segments))
resultsegmentsizes(segments) = NamedAttribute("result_segment_sizes", Int32.(segments))
47 changes: 25 additions & 22 deletions src/IR/AffineExpr.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
struct AffineExpr
expr::API.MlirAffineExpr

function AffineExpr(expr)
@assert !mlirIsNull(expr) "cannot create AffineExpr with null MlirAffineExpr"
return new(expr)
end
@checked struct AffineExpr
ref::API.MlirAffineExpr
end

Base.convert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr.expr
Base.cconvert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr
Base.unsafe_convert(::Core.Type{API.MlirAffineExpr}, expr::AffineExpr) = expr.ref

"""
==(a, b)
Expand Down Expand Up @@ -56,8 +52,9 @@ ismultipleof(expr::AffineExpr, factor) = API.mlirAffineExprIsMultipleOf(expr, fa

Checks whether the given affine expression involves AffineDimExpr 'position'.
"""
isfunctionofdimexpr(expr::AffineExpr, position) =
API.mlirAffineExprIsFunctionOfDim(expr, position)
function isfunctionofdimexpr(expr::AffineExpr, position)
return API.mlirAffineExprIsFunctionOfDim(expr, position)
end

"""
isdimexpr(affineExpr)
Expand All @@ -71,8 +68,9 @@ isdimexpr(expr::AffineExpr) = API.mlirAffineExprIsADim(expr)

Creates an affine dimension expression with 'position' in the context.
"""
AffineDimensionExpr(position; context::Context=context()) =
AffineExpr(API.mlirAffineDimExprGet(context, position))
function AffineDimensionExpr(position; context::Context=current_context())
return AffineExpr(API.mlirAffineDimExprGet(context, position))
end

"""
issymbolexpr(affineExpr)
Expand All @@ -82,12 +80,13 @@ Checks whether the given affine expression is a symbol expression.
issymbolexpr(expr::AffineExpr) = API.mlirAffineExprIsASymbol(expr)

"""
SymbolExpr(position; context=context())
SymbolExpr(position; context=current_context())

Creates an affine symbol expression with 'position' in the context.
"""
SymbolExpr(position; context::Context=context()) =
AffineExpr(API.mlirAffineSymbolExprGet(context, position))
function SymbolExpr(position; context::Context=current_context())
return AffineExpr(API.mlirAffineSymbolExprGet(context, position))
end

"""
position(affineExpr)
Expand Down Expand Up @@ -116,12 +115,13 @@ Checks whether the given affine expression is a constant expression.
isconstantexpr(expr::AffineExpr) = API.mlirAffineExprIsAConstant(expr)

"""
ConstantExpr(constant::Int; context=context())
ConstantExpr(constant::Int; context=current_context())

Creates an affine constant expression with 'constant' in the context.
"""
ConstantExpr(constant; context::Context=context()) =
AffineExpr(API.mlirAffineConstantExprGet(context, constant))
function ConstantExpr(constant; context::Context=current_context())
return AffineExpr(API.mlirAffineConstantExprGet(context, constant))
end

"""
value(affineExpr)
Expand Down Expand Up @@ -189,8 +189,10 @@ isfloordiv(expr::AffineExpr) = API.mlirAffineExprIsAFloorDiv(expr)

Creates an affine floordiv expression with 'lhs' and 'rhs'.
"""
Base.div(lhs::AffineExpr, rhs::AffineExpr) =
AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs))
function Base.div(lhs::AffineExpr, rhs::AffineExpr)
return AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs))
end

Base.fld(lhs::AffineExpr, rhs::AffineExpr) = div(lhs, rhs)

"""
Expand All @@ -205,8 +207,9 @@ isceildiv(expr::AffineExpr) = API.mlirAffineExprIsACeilDiv(expr)

Creates an affine ceildiv expression with 'lhs' and 'rhs'.
"""
Base.cld(lhs::AffineExpr, rhs::AffineExpr) =
AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs))
function Base.cld(lhs::AffineExpr, rhs::AffineExpr)
return AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs))
end

"""
isbinary(affineExpr)
Expand Down
127 changes: 66 additions & 61 deletions src/IR/AffineMap.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
struct AffineMap
map::API.MlirAffineMap

function AffineMap(map::API.MlirAffineMap)
@assert !mlirIsNull(map) "cannot create AffineMap with null MlirAffineMap"
return new(map)
end
@checked struct AffineMap
ref::API.MlirAffineMap
end

"""
AffineMap(; context=context())
AffineMap(; context=current_context())

Creates a zero result affine map with no dimensions or symbols in the context.
The affine map is owned by the context.
"""
AffineMap(; context::Context=context()) = AffineMap(API.mlirAffineMapEmptyGet(context))
function AffineMap(; context::Context=current_context())
return AffineMap(API.mlirAffineMapEmptyGet(context))
end

Base.convert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map.map
Base.cconvert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map
Base.unsafe_convert(::Core.Type{API.MlirAffineMap}, map::AffineMap) = map.ref

"""
==(a, b)
Expand All @@ -25,78 +23,90 @@ Checks if two affine maps are equal.
Base.:(==)(a::AffineMap, b::AffineMap) = API.mlirAffineMapEqual(a, b)

"""
compose(affineExpr, affineMap)
context(affineMap)

Composes the given map with the given expression.
Gets the context that the given affine map was created with.
"""
compose(expr::AffineExpr, map::AffineMap) = AffineExpr(API.mlirAffineExprCompose(expr, map))
context(map::AffineMap) = Context(API.mlirAffineMapGetContext(map))

function Base.show(io::IO, map::AffineMap)
print(io, "AffineMap(#= ")
c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any))
ref = Ref(io)
API.mlirAffineMapPrint(map, c_print_callback, ref)
return print(io, " =#)")
end

"""
context(affineMap)
compose(affineExpr, affineMap)

Gets the context that the given affine map was created with.
Composes the given map with the given expression.
"""
context(map::AffineMap) = API.mlirAffineMapGetContext(map)
compose(expr::AffineExpr, map::AffineMap) = AffineExpr(API.mlirAffineExprCompose(expr, map))

"""
AffineMap(ndims, nsymbols; context=context())
AffineMap(ndims, nsymbols; context=current_context())

Creates a zero result affine map of the given dimensions and symbols in the context.
The affine map is owned by the context.
"""
AffineMap(ndims, nsymbols; context::Context=context()) =
AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols))
function AffineMap(ndims, nsymbols; context::Context=current_context())
return AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols))
end

"""
AffineMap(ndims, nsymbols, affineExprs; context=context())
AffineMap(ndims, nsymbols, affineExprs; context=current_context())

Creates an affine map with results defined by the given list of affine expressions.
The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results.
"""
AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) =
AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), pointer(exprs)))
function AffineMap(
ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=current_context()
)
return AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs))
end

"""
ConstantAffineMap(val; context=context())
ConstantAffineMap(val; context=current_context())

Creates a single constant result affine map in the context. The affine map is owned by the context.
"""
ConstantAffineMap(val; context::Context=context()) =
AffineMap(API.mlirAffineMapConstantGet(context, val))
function ConstantAffineMap(val; context::Context=current_context())
return AffineMap(API.mlirAffineMapConstantGet(context, val))
end

"""
IdentityAffineMap(ndims; context=context())
IdentityAffineMap(ndims; context=current_context())

Creates an affine map with 'ndims' identity in the context. The affine map is owned by the context.
"""
IdentityAffineMap(ndims; context::Context=context()) =
AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims))
function IdentityAffineMap(ndims; context::Context=current_context())
return AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims))
end

"""
MinorIdentityAffineMap(ndims, nresults; context=context())
MinorIdentityAffineMap(ndims, nresults; context=current_context())

Creates an identity affine map on the most minor dimensions in the context. The affine map is owned by the context.
The function asserts that the number of dimensions is greater or equal to the number of results.
"""
function MinorIdentityAffineMap(ndims, nresults; context::Context=context())
function MinorIdentityAffineMap(ndims, nresults; context::Context=current_context())
@assert ndims >= nresults "number of dimensions must be greater or equal to the number of results"
return AffineMap(API.mlirAffineMapMinorIdentityGet(context, ndims, nresults))
end

"""
PermutationAffineMap(permutation; context=context())
PermutationAffineMap(permutation; context=current_context())

Creates an affine map with a permutation expression and its size in the context.
The permutation expression is a non-empty vector of integers.
The elements of the permutation vector must be continuous from 0 and cannot be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is an invalid invalid permutation).
The affine map is owned by the context.
"""
function PermutationAffineMap(permutation; context::Context=context())
function PermutationAffineMap(permutation; context::Context=current_context())
@assert Base.isperm(permutation) "$permutation must be a valid permutation"
zero_perm = permutation .- 1
return AffineMap(
API.mlirAffineMapPermutationGet(context, length(zero_perm), pointer(zero_perm))
)
return AffineMap(API.mlirAffineMapPermutationGet(context, length(zero_perm), zero_perm))
end

"""
Expand Down Expand Up @@ -191,8 +201,9 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map)

Returns the affine map consisting of the `positions` subset.
"""
submap(map::AffineMap, pos::Vector{Int}) =
AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pointer(pos)))
function submap(map::AffineMap, pos::Vector{Int})
return AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos))
end

"""
majorsubmap(affineMap, nresults)
Expand All @@ -201,28 +212,34 @@ Returns the affine map consisting of the most major `nresults` results.
Returns the null AffineMap if the `nresults` is equal to zero.
Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map.
"""
majorsubmap(map::AffineMap, nresults) =
AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults))
function majorsubmap(map::AffineMap, nresults)
return AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults))
end

"""
minorsubmap(affineMap, nresults)

Returns the affine map consisting of the most minor `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero.
Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map.
"""
minorsubmap(map::AffineMap, nresults) =
AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults))
function minorsubmap(map::AffineMap, nresults)
return AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults))
end

"""
mlirAffineMapReplace(affineMap, expression => replacement, numResultDims, numResultSyms)

Apply `AffineExpr::replace(map)` to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols.
"""
Base.replace(
function Base.replace(
map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms
) = AffineMap(
API.mlirAffineMapReplace(map, old_new.first, old_new.second, nresultdims, nresultsyms),
)
return AffineMap(
API.mlirAffineMapReplace(
map, old_new.first, old_new.second, nresultdims, nresultsyms
),
)
end

"""
simplify(affineMaps, size, result, populateResult)
Expand All @@ -232,15 +249,7 @@ Asserts that all maps in `affineMaps` are normalized to the same number of dims
Takes a callback `populateResult` to fill the `res` container with value `m` at entry `idx`.
This allows returning without worrying about ownership considerations.
"""
# TODO simplify(map::AffineMap, ...) = AffineMap(API.mlirAffineMapCompressUnusedSymbols(map, ...))

function Base.show(io::IO, map::AffineMap)
print(io, "AffineMap(#= ")
c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any))
ref = Ref(io)
API.mlirAffineMapPrint(map, c_print_callback, ref)
return print(io, " =#)")
end
# TODO(#2246) simplify(map::AffineMap, ...) = AffineMap(API.mlirAffineMapCompressUnusedSymbols(map, ...))

walk(f, other) = f(other)
function walk(f, expr::Expr)
Expand All @@ -258,12 +267,8 @@ On the right hand side are allowed the following function calls:

The rhs can only contains dimensions and symbols present on the left hand side or integer literals.

```juliadoctest
julia> using MLIR: IR, AffineUtils

julia> IR.context!(IR.Context()) do
IR.@affinemap (d1, d2)[s0] -> (d1 + s0, d2 % 10)
end
```julia
julia> @affinemap (d1, d2)[s0] -> (d1 + s0, d2 % 10)
MLIR.IR.AffineMap(#= (d0, d1)[s0] -> (d0 + s0, d1 mod 10) =#)
```
"""
Expand All @@ -287,11 +292,11 @@ macro affinemap(expr)
@assert all(x -> x isa Symbol, syms) "invalid symbols $syms"

dimexprs = map(enumerate(dims)) do (i, dim)
:($dim = AffineDimensionExpr($(i - 1)))
return :($dim = AffineDimensionExpr($(i - 1)))
end

symexprs = map(enumerate(syms)) do (i, sym)
:($sym = SymbolExpr($(i - 1)))
return :($sym = SymbolExpr($(i - 1)))
end

known_binops = [:+, :-, :*, :÷, :%, :fld, :cld]
Expand Down
Loading
Loading