
Batched dot(x, A, y) #641

@3f6a

Description


Motivation and description

Say I have the arrays x[i,b], y[j,b] and A[i,j,b]. Is there an efficient way to do the following "batched dot" operation:

[sum(x[i,b] * A[i,j,b] * y[j,b] for i = axes(A,1) for j = axes(A,2)) for b = ...]

where b runs over the batch dimension. As usual, we could have size(x,2) == 1, size(A,3) == 1, ..., meaning that array has no batch dimension of its own and is broadcast across all batches.
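A minimal reference sketch in plain Julia, assuming real element types, may help pin down the semantics. `batched_dot` and its inner `pick` are hypothetical names, not NNlib functions. Each batch entry uses the three-argument `LinearAlgebra.dot(x, A, y)`, which computes x'Ay without storing A*y. Any argument whose batch dimension has size 1 is reused for every batch, following the broadcasting rule described above.

```julia
using LinearAlgebra: dot

# Hypothetical helper (not an NNlib function): out[b] = x[:,b]' * A[:,:,b] * y[:,b].
# An argument whose batch dimension has size 1 is reused for every batch entry.
function batched_dot(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    size(x, 1) == size(A, 1) || throw(DimensionMismatch("size(x,1) must equal size(A,1)"))
    size(y, 1) == size(A, 2) || throw(DimensionMismatch("size(y,1) must equal size(A,2)"))
    bx, bA, by = size(x, 2), size(A, 3), size(y, 2)
    nb = max(bx, bA, by)
    # Every batch size must equal nb or be 1 (broadcast), as in ordinary broadcasting.
    all(n -> n == 1 || n == nb, (bx, bA, by)) ||
        throw(DimensionMismatch("batch sizes must match or be 1"))
    # Batch index to read from an array whose batch size is n.
    pick(n, b) = n == 1 ? 1 : b
    T = promote_type(eltype(x), eltype(A), eltype(y))
    out = Vector{T}(undef, nb)
    for b in 1:nb
        xb = view(x, :, pick(bx, b))
        Ab = view(A, :, :, pick(bA, b))
        yb = view(y, :, pick(by, b))
        # Three-argument dot evaluates xb' * Ab * yb without allocating Ab * yb
        # (for complex inputs it conjugates xb, unlike the sum written above).
        out[b] = dot(xb, Ab, yb)
    end
    return out
end

x = rand(3, 4); A = rand(3, 5, 4); y = rand(5, 4)
batched_dot(x, A, y)             # length-4 vector
batched_dot(x, A[:, :, 1:1], y)  # one A slice shared by all 4 batches
```

This version allocates only the length-B output, but it runs one small `dot` per batch entry in a serial loop.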

Apologies if there is already a way to do this (efficiently) with existing functions in NNlib; I could not figure it out.
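For comparison, the same operation can be composed from existing Base primitives alone (broadcasting plus `sum` with `dims`), with no explicit batch loop. This is a sketch, not an NNlib API: `batched_dot_bcast` is a hypothetical name, and it does not settle whether NNlib already offers a fused equivalent. Reshaping x to I×1×B and y to 1×J×B lets ordinary broadcasting handle the size-1 batch case automatically.

```julia
# Hypothetical helper (not an NNlib function): computes
# out[b] = sum over i, j of x[i,b] * A[i,j,b] * y[j,b]
# using Base broadcasting and a reduction instead of an explicit batch loop.
function batched_dot_bcast(x::AbstractMatrix, A::AbstractArray{<:Any,3}, y::AbstractMatrix)
    # View x as I×1×Bx and y as 1×J×By so they line up with A (I×J×BA).
    xr = reshape(x, size(x, 1), 1, size(x, 2))
    yr = reshape(y, 1, size(y, 1), size(y, 2))
    # Broadcasting expands any size-1 batch dimension and throws DimensionMismatch
    # otherwise; note that it materialises a full I×J×B temporary array.
    terms = xr .* A .* yr
    # Summing over i and j leaves a 1×1×B array; vec flattens it to length B.
    return vec(sum(terms; dims = (1, 2)))
end

x = rand(3, 4); A = rand(3, 5, 4); y = rand(5, 4)
batched_dot_bcast(x, A, y)          # length-4 vector
batched_dot_bcast(x[:, 1:1], A, y)  # first column of x shared by all batches
```

The trade-off is memory: the broadcast builds an I×J×B intermediate before reducing it, whereas a per-batch loop needs only the length-B output.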

Possible Implementation

No response
