What happened?
While debugging something else, I ran into an odd edge case where parallel matrix multiplication returns a wrong result. The details are below. The error appears only when running with two tasks, and I am not sure what makes this case special.
I first ran into this issue with QR decomposition. You can try:
import heat as ht
comm = ht.comm
A = ht.random.randn(comm.size, comm.size, dtype=ht.float32, split=0)
QR = ht.linalg.qr(A)
matmul_success_loc = ht.allclose(QR.Q @ QR.R, A)
matmul_success_glob = ht.allclose((QR.Q.resplit(None) @ QR.R.resplit(None)).resplit(QR.Q.split), A)
print(matmul_success_loc, matmul_success_glob)
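which prints False True: the distributed product QR.Q @ QR.R does not match A, while the product of the resplit (non-distributed) factors does.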
I am a bit lost here; in particular, the matrix multiplication function is very long and hard to follow. Since the bug seems pretty serious, I am reporting it in the hope that somebody can give me some pointers.
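For comparison, both checks in the QR example test the same property: the factors must reconstruct the input, so Q @ R equals A up to round-off. A minimal serial sketch with NumPy (an assumption: NumPy stands in for Heat on a single process, so no distribution is involved) shows the result the distributed check should also give:

```python
import numpy as np

# Serial reference for the QR check: with no distribution, Q @ R must
# reconstruct A up to float32 round-off.
n = 2  # stands in for comm.size in a two-task run
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n)).astype(np.float32)

Q, R = np.linalg.qr(A)
reconstructs = np.allclose(Q @ R, A, atol=1e-5)
orthonormal = np.allclose(Q.T @ Q, np.eye(n), atol=1e-5)
print(reconstructs, orthonormal)  # True True
```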
Code snippet triggering the error
import heat as ht
split = 0
shape = (4, 3)
A = ht.ones(shape, split=split)
B = ht.ones(shape[::-1], split=split)
C = A @ B
C_glob = (A.resplit(None) @ B.resplit(None)).resplit(C.split)
print(C)
print(C_glob)
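For reference, when both operands are split along axis 0, the contraction axis of B is the distributed one, so a correct result is the sum of per-block partial products, C = sum_k A[:, K_k] @ B[K_k, :], where K_k are the rows of B held by task k. The sketch below simulates this serially with NumPy. It is not Heat's matmul implementation; the helper names and the chunking rule (remainder rows go to the lowest ranks) are assumptions for illustration.

```python
import numpy as np

def chunk_bounds(n, nprocs):
    """Hypothetical helper: contiguous chunks, remainder on the lowest ranks."""
    base, rem = divmod(n, nprocs)
    bounds, start = [], 0
    for rank in range(nprocs):
        stop = start + base + (1 if rank < rem else 0)
        bounds.append((start, stop))
        start = stop
    return bounds

def blockwise_matmul_split0(A, B, nprocs):
    """Serial simulation of A @ B with both operands split along axis 0.

    Each block of B's rows pairs with the matching columns of A, and the
    partial products are summed into the full result.
    """
    C = np.zeros((A.shape[0], B.shape[1]), dtype=np.result_type(A, B))
    for start, stop in chunk_bounds(B.shape[0], nprocs):
        C += A[:, start:stop] @ B[start:stop, :]
    return C

A = np.ones((4, 3), dtype=np.float32)
B = np.ones((3, 4), dtype=np.float32)
C = blockwise_matmul_split0(A, B, nprocs=2)
# B's 3 rows split as 2 + 1, so every entry is 2 (first block) + 1 (second block).
print(C)
assert np.array_equal(C, A @ B)
```

With two tasks every entry is 3.0, which matches C_glob but not the distributed C.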
Error message or erroneous outcome
Output when run with two tasks:
DNDarray([[3., 3., 3., 3.],
          [6., 6., 6., 6.],
          [3., 3., 3., 3.],
          [6., 6., 6., 6.]], dtype=ht.float32, device=cpu:0, split=0)
DNDarray([[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]], dtype=ht.float32, device=cpu:0, split=0)
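To make the discrepancy concrete, a small plain-Python comparison of the two printed results (values copied from the output above) locates the wrong rows. The mapping to ranks assumes the 4 rows are split evenly, 2 per task; the variable names are hypothetical.

```python
# print(C) from the two-task run, copied from the output above.
observed = [
    [3.0, 3.0, 3.0, 3.0],
    [6.0, 6.0, 6.0, 6.0],
    [3.0, 3.0, 3.0, 3.0],
    [6.0, 6.0, 6.0, 6.0],
]
# print(C_glob): every entry is 3.0.
reference = [[3.0] * 4 for _ in range(4)]

# Global indices of rows where the distributed result disagrees.
bad_rows = [i for i, (o, r) in enumerate(zip(observed, reference)) if o != r]
# Ratio observed / reference in each wrong row.
ratios = [observed[i][0] / reference[i][0] for i in bad_rows]

rows_per_rank = 2  # assumption: 4 rows split evenly over 2 tasks
located = [divmod(i, rows_per_rank) for i in bad_rows]  # (rank, local row)
print(bad_rows, ratios, located)  # [1, 3] [2.0, 2.0] [(0, 1), (1, 1)]
```

Under that assumption, the wrong rows are the second local row on each rank, and they hold exactly twice the expected value.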
Version
main (development branch)
Python version
3.13
PyTorch version
2.9
MPI version