1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
|
from mpi4py import MPI
class Intracomm(MPI.Intracomm):
    """
    Intracommunicator class with scalable, point-to-point based
    implementations of global reduction operations.

    Every operation is built only from ``send``/``recv``/``sendrecv`` on
    this communicator and takes ceil(log2(size)) communication rounds
    (``reduce`` needs one extra message when ``root != 0``).

    ``op`` is invoked as ``op(a, b)``. Operands are always passed in rank
    order, with the contribution of lower ranks on the left, so ``op`` does
    not need to be commutative.
    """

    def __new__(cls, comm=None):
        # Delegate construction to mpi4py, forwarding ``comm`` unchanged.
        return super().__new__(cls, comm)

    @staticmethod
    def _reduction_tag():
        # Single tag shared by all collectives in this class: one below the
        # upper tag bound that COMM_WORLD reports through MPI.TAG_UB.
        # NOTE(review): user point-to-point traffic on the same communicator
        # that uses this tag could be matched by these collectives.
        return MPI.COMM_WORLD.Get_attr(MPI.TAG_UB) - 1

    def reduce(self, sendobj=None, recvobj=None, op=MPI.SUM, root=0):
        """
        Combine ``sendobj`` from every rank with ``op`` and deliver it to ``root``.

        Partial results flow up a binomial tree rooted at rank 0. If
        ``root`` is not 0, rank 0 then forwards the final value to ``root``.

        :param sendobj: this rank's contribution.
        :param recvobj: ignored; its value is overwritten before use.
        :param op: binary callable that combines two partial results.
        :param root: rank that receives the result.
        :returns: the reduced object on ``root``, ``None`` on every other rank.
        :raises ValueError: if ``root`` is not in ``range(self.size)``.
        """
        size = self.size
        rank = self.rank
        if not 0 <= root < size:
            raise ValueError(
                "root {} out of range for communicator of size {}".format(
                    root, size))
        tag = self._reduction_tag()

        recvobj = sendobj
        mask = 1
        while mask < size:
            if rank & mask:
                # Lowest set bit reached: pass the accumulated subtree value
                # to the parent (this bit cleared) and leave the tree. Looping
                # further would only exchange values nobody uses.
                self.send(recvobj, dest=rank & ~mask, tag=tag)
                break
            child = rank | mask
            if child < size:
                # The child's subtree covers higher ranks, so it goes on the right.
                recvobj = op(recvobj, self.recv(None, source=child, tag=tag))
            mask <<= 1

        # The tree always finishes at rank 0; move the result if needed.
        if root != 0:
            if rank == 0:
                self.send(recvobj, dest=root, tag=tag)
            elif rank == root:
                recvobj = self.recv(None, source=0, tag=tag)

        return recvobj if rank == root else None

    def allreduce(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Combine ``sendobj`` from every rank with ``op``; every rank gets the result.

        Implemented as :meth:`reduce` to rank 0 followed by ``bcast`` from
        rank 0. ``recvobj`` is passed through to :meth:`reduce`, which
        ignores it.
        """
        result = self.reduce(sendobj, recvobj, op, 0)
        return self.bcast(result, 0)

    def scan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Inclusive prefix reduction.

        Rank ``r`` returns ``op`` folded over the contributions of ranks
        ``0..r`` in rank order. The algorithm is recursive doubling: in each
        round, every rank swaps its running block total with the rank whose
        id differs in the current bit.

        ``recvobj`` is ignored; its value is overwritten before use.
        """
        size = self.size
        rank = self.rank
        tag = self._reduction_tag()

        prefix = sendobj   # combination of all ranks <= rank seen so far
        partial = sendobj  # combination of the whole block exchanged so far
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                tmp = self.sendrecv(partial, dest=peer, source=peer,
                                    sendtag=tag, recvtag=tag)
                if rank > peer:
                    # Peer's block lies below ours: it prefixes both values.
                    partial = op(tmp, partial)
                    prefix = op(tmp, prefix)
                else:
                    # Peer's block lies above ours: it only extends the block.
                    partial = op(partial, tmp)
            mask <<= 1
        return prefix

    def exscan(self, sendobj=None, recvobj=None, op=MPI.SUM):
        """
        Exclusive prefix reduction.

        Rank ``r`` returns ``op`` folded over the contributions of ranks
        ``0..r-1`` in rank order. Rank 0 returns ``None``. It uses the same
        recursive-doubling exchange as :meth:`scan`.

        ``recvobj`` is ignored; its value is overwritten before use.
        """
        size = self.size
        rank = self.rank
        tag = self._reduction_tag()

        # Rank 0 never sees a lower peer, so its prefix stays None.
        prefix = None
        # A flag, not a None check, because a received value may itself be None.
        have_prefix = False
        partial = sendobj
        mask = 1
        while mask < size:
            peer = rank ^ mask
            if peer < size:
                tmp = self.sendrecv(partial, dest=peer, source=peer,
                                    sendtag=tag, recvtag=tag)
                if rank > peer:
                    partial = op(tmp, partial)
                    # The first lower block received starts the exclusive
                    # prefix; later (lower) blocks are prepended to it.
                    prefix = op(tmp, prefix) if have_prefix else tmp
                    have_prefix = True
                else:
                    partial = op(partial, tmp)
            mask <<= 1
        return prefix
|