# amm4a.py: 4D matrix multiplication test (non-distributed, then distributed).
import arkouda as ak ; import numpy as np
import sys
from arkouda.comm_diagnostics \
import (start_comm_diagnostics, stop_comm_diagnostics, print_comm_diagnostics_table, reset_comm_diagnostics)
def _run_with_comm_diagnostics(func, *args):
    """Call ``func(*args)`` while communication diagnostics record, then print them.

    Counters are reset first so each printed table reflects only this call.
    Recording is stopped even if ``func`` raises.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    reset_comm_diagnostics()
    start_comm_diagnostics()
    try:
        result = func(*args)
    finally:
        stop_comm_diagnostics()
    print_comm_diagnostics_table()
    return result


ak.connect()
try:
    # An optional first CLI argument overrides the default seed, so a failing
    # run can be reproduced exactly.
    iseed = int(sys.argv[1]) if len(sys.argv) > 1 else 1701
    np.random.seed(iseed)

    # np.matmul on 4D inputs multiplies the trailing two axes,
    # (11 x 250) @ (250 x 11), across the leading (11, 11) batch axes.
    nleft = np.random.randint(-10, 10, (11, 11, 11, 250))
    nright = np.random.randint(-10, 10, (11, 11, 250, 11))
    nprod = np.matmul(nleft, nright)  # NumPy reference result

    print("\n\nTesting 4D matrix multiplication: non-distributed, then distributed.")
    print()
    pleft = ak.array(nleft)
    pright = ak.array(nright)

    pprod1 = _run_with_comm_diagnostics(ak.matmul, pleft, pright)
    print()
    pprod2 = _run_with_comm_diagnostics(ak.distmatmulmultidim, pleft, pright)

    n1 = pprod1.to_ndarray()
    n2 = pprod2.to_ndarray()
    print()
    # np.array_equal also compares shapes. Elementwise `==` followed by .all()
    # could broadcast a wrongly-shaped result into a misleading True, or raise.
    print(np.array_equal(nprod, n1))
    print(np.array_equal(nprod, n2))
finally:
    # Shut the server down even when a step above fails.
    ak.shutdown()