The BLAS interface defines a matrix multiplication primitive for 2D matrices, but in modern neural networks "batched matrix multiplication" is very important.
For example, consider the following multiplication, taken from a BERT model with seqLen of 128:
[4, 12, 128, 128] x [4, 12, 128, 64] = [4, 12, 128, 64].
This operation effectively defines 48 independent multiplications of [128, 128] x [128, 64] matrices: one per entry of the 4 x 12 batch.
With current APIs that would be a sequential single-threaded loop, which is quite inefficient since the inner matrices are relatively small.
We get much better performance by running these multiplications in parallel along the batch dimension.
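For illustration, here is a minimal sketch of what "parallel along the batch dimension" could look like, assuming contiguous row-major float tensors and a plain CBLAS sgemm; the helper name `batchedSgemm` and the layout handling are illustrative assumptions, not the actual libnd4j code:

```cpp
// Minimal sketch, NOT the actual libnd4j implementation: a batch-parallel
// GEMM over the flattened batch dimensions, assuming contiguous row-major
// float tensors and a standard CBLAS sgemm. All names are illustrative.
#include <cblas.h>
#include <cstddef>

// Computes C[b] = A[b] * B[b] for b in [0, batch), where each A[b] is
// M x K, each B[b] is K x N, and batch entries are stored back to back.
static void batchedSgemm(const float* A, const float* B, float* C,
                         int batch, int M, int N, int K) {
    #pragma omp parallel for
    for (int b = 0; b < batch; b++) {
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    M, N, K,
                    1.0f, A + (std::size_t)b * M * K, K,
                          B + (std::size_t)b * K * N, N,
                    0.0f, C + (std::size_t)b * M * N, N);
    }
}
```

For the BERT example above this would run with batch = 4 * 12 = 48, M = 128, K = 128, N = 64, so the 48 small multiplications are distributed across threads instead of executing one after another.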
Here's the relatively simple batched_gemm code we've added:
deeplearning4j/libnd4j/include/helpers/cpu/MmulHelper.cpp, line 1018 in 3bf785f:

```cpp
static void batchedGemmUnPackC(const NDArray* vA, const NDArray* vB, NDArray* vC,
```
It provides roughly 12x better performance than the sequential single-threaded loop.
So it would be awesome to have proper support for such a primitive in the NEC BLAS library.