Skip to content

batched_mul doesn't work with MtlArrays #581

@forrestlaine

Description

@forrestlaine

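Environment: Julia 1.10.3, Metal v1.1.0, NNlib v0.9.14.

`NNlib.batched_mul` works on CPU `Array`s but throws a `MethodError` when given `MtlArray`s. No `_batched_gemm!` method exists for `MtlArray`; the only candidate accepts `Type{<:Array}`. Reproducer: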

julia> using Metal, NNlib
julia> A = randn(Float32, 3,4,5);
julia> B = randn(Float32, 4,6,5);
julia> Ag = MtlArray(A);
julia> Bg = MtlArray(B);
julia> NNlib.batched_mul(A,B); # works no prob
julia> NNlib.batched_mul(Ag, Bg)
ERROR: MethodError: no method matching _batched_gemm!(::Type{MtlArray{Float32, 3, Private}}, ::Char, ::Char, ::Float32, ::MtlArray{Float32, 3, Private}, ::MtlArray{Float32, 3, Private}, ::Float32, ::MtlArray{Float32, 3, Private})

Closest candidates are:
  _batched_gemm!(::Type{<:Array}, ::Char, ::Char, ::Number, ::Any, ::Any, ::Number, ::Any)
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:262

Stacktrace:
 [1] _batched_try_gemm!(::Type{MtlArray{Float32, 3, Private}}, C::MtlArray{Float32, 3, Private}, A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private}, α::Float32, β::Float32)
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:258
 [2] _batched_mul!(::Type{MtlArray{Float32, 3, Private}}, C::MtlArray{Float32, 3, Private}, A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private}, α::Float32, β::Float32)
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:222
 [3] batched_mul!(C::MtlArray{Float32, 3, Private}, A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private}, α::Float32, β::Float32)
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:216
 [4] batched_mul!(C::MtlArray{Float32, 3, Private}, A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private})
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:216
 [5] _batched_mul(::Type{MtlArray{Float32, 3, Private}}, A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private})
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:72
 [6] batched_mul(A::MtlArray{Float32, 3, Private}, B::MtlArray{Float32, 3, Private})
   @ NNlib ~/.julia/packages/NNlib/c3RdJ/src/batched/batchedmul.jl:59
 [7] top-level scope
   @ REPL[24]:1
 [8] top-level scope
   @ ~/.julia/packages/Metal/q9oGt/src/initialization.jl:57

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions