
NEC BLAS batched_gemm support #331

@raver119

Description

The BLAS interface defines a matrix-multiplication primitive for 2D matrices, but in modern neural networks "batched matrix multiplication" is very important.

For example, consider the following multiplication, taken from a BERT model with seqLen of 128:
[4, 12, 128, 128] x [4, 12, 128, 64] = [4, 12, 128, 64].

Since 4 x 12 = 48, this operation basically defines 48 independent multiplications of [128, 128] x [128, 64] matrices.
With the current APIs that would be a sequential single-threaded loop, which is quite inefficient, since the inner matrices are relatively small.

We get much better performance if we run these multiplications in parallel along the batch dimension.
Here's the entry point of the relatively simple batched_gemm code we've added:

static void batchedGemmUnPackC(const NDArray* vA, const NDArray* vB, NDArray* vC,

It provides roughly 12x better performance than the sequential single-threaded loop.
So it would be awesome to have proper support for such a primitive in the NEC BLAS library.
