# executor_MPI.py — MPI-backed Executor (68 lines, 2.25 KB)
from mpi4py import MPI
import numpy as np
from scipy.special import logsumexp
from discretesampling.base.executor import Executor
from discretesampling.base.util import gather_all
from discretesampling.base.executor.MPI.distributed_fixed_size_redistribution.prefix_sum import (
inclusive_prefix_sum
)
from discretesampling.base.executor.MPI.variable_size_redistribution import (
variable_size_redistribution
)
def LSE(xmem, ymem, dt):
    """Custom MPI reduction: merge two partial log-sum-exp results.

    Both buffers are viewed as float64 arrays; the log-sum-exp of their
    concatenation is written back into ``ymem`` (the in/out buffer of the
    reduction).  ``dt`` is the MPI datatype handle that the ``MPI.Op``
    callback signature requires; it is not used here.
    """
    incoming = np.frombuffer(xmem, dtype='d')
    accumulated = np.frombuffer(ymem, dtype='d')
    combined = np.concatenate((incoming, accumulated))
    accumulated[:] = logsumexp(combined)
class Executor_MPI(Executor):
    """Distributed :class:`Executor` implementation backed by MPI collectives.

    Provides the reductions and data-movement primitives used by the
    samplers (max, sum, gather, broadcast, log-sum-exp, prefix sum and
    particle redistribution), each implemented as a collective over the
    executor's communicator via mpi4py, so every rank receives the same
    result.
    """

    def __init__(self):
        self.comm = MPI.COMM_WORLD
        self.P = self.comm.Get_size()  # number of MPI nodes/ranks
        self.rank = self.comm.Get_rank()

    def max(self, x):
        """Return the global maximum of array ``x`` over all ranks (0-d array)."""
        local_max = np.max(x)
        x_dtype = MPI._typedict[x.dtype.char]
        max_dim = np.zeros_like(1, dtype=x.dtype)  # 0-d receive buffer
        self.comm.Allreduce(sendbuf=[local_max, x_dtype], recvbuf=[max_dim, x_dtype], op=MPI.MAX)
        return max_dim

    def sum(self, x):
        """Return the global sum of array ``x`` over all ranks (0-d array)."""
        x_dtype = MPI._typedict[x.dtype.char]
        sum_of_x = np.array(1, dtype=x.dtype)  # recvbuf; value is overwritten by Allreduce
        self.comm.Allreduce(sendbuf=[np.sum(x), x_dtype], recvbuf=[sum_of_x, x_dtype], op=MPI.SUM)
        return sum_of_x

    def gather(self, x, all_x_shape):
        """Allgather each rank's local ``x`` into one array of shape
        ``all_x_shape``; every rank receives the full concatenated result."""
        x_dtype = MPI._typedict[x.dtype.char]
        all_x = np.zeros(all_x_shape, dtype=x.dtype)
        self.comm.Allgather(sendbuf=[x, x_dtype], recvbuf=[all_x, x_dtype])
        return all_x

    def bcast(self, x):
        """Broadcast ``x`` from rank 0 to all ranks, in place (no return value)."""
        self.comm.Bcast(buf=[x, MPI._typedict[x.dtype.char]], root=0)

    def logsumexp(self, x):
        """Return the global log-sum-exp of ``x`` across all ranks (0-d array).

        A custom commutative reduction op (:func:`LSE`) merges the per-rank
        partial log-sum-exp values; an empty local array contributes
        ``-inf``, the identity element for log-sum-exp.
        """
        op = MPI.Op.Create(LSE, commute=True)
        log_sum = np.zeros_like(1, x.dtype)
        MPI_dtype = MPI._typedict[x.dtype.char]
        leaf_node = np.array([-np.inf]).astype(x.dtype) if len(x) == 0 else logsumexp(x)
        # Fix: use the executor's communicator rather than MPI.COMM_WORLD
        # directly, consistent with every other collective in this class.
        self.comm.Allreduce(sendbuf=[leaf_node, MPI_dtype], recvbuf=[log_sum, MPI_dtype], op=op)
        op.Free()
        return log_sum

    def cumsum(self, x):
        """Distributed inclusive prefix (cumulative) sum of ``x``."""
        return inclusive_prefix_sum(x)

    def redistribute(self, particles, ncopies):
        """Redistribute particles across ranks according to ``ncopies``."""
        return variable_size_redistribution(particles, ncopies, self)

    def gather_all(self, particles):
        """Gather the complete particle set onto every rank."""
        particles = gather_all(particles, exec=self)
        return particles