1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
# -----------------------------------------------------------------------------
from mpi4py import MPI
from array import array
from threading import Thread
class Counter:
    """Shared fetch-and-increment counter over an MPI intracommunicator.

    Rank 0 of a private duplicate of the communicator runs a background
    thread that owns the counter value and serves requests from every rank,
    including rank 0's own main thread.  Because that thread makes MPI calls
    concurrently with the main thread, the MPI library is expected to
    provide ``MPI.THREAD_MULTIPLE`` thread support.

    Message protocol on the private communicator:

    * tag 0 -- request: a client sends one ``MPI.INT`` increment to rank 0;
      the server replies (tag 0) with the counter value *before* adding the
      increment.
    * tag 1 -- shutdown: an empty message rank 0 sends to itself from
      :meth:`free`, which makes the server thread return.
    """

    def __init__(self, comm):
        """Create the counter on *comm* (must be called on all of its ranks).

        :param comm: an MPI intracommunicator.  It is duplicated so that the
            counter's messages use their own communication context and
            cannot be matched by the caller's own traffic on *comm*.
        :raises ValueError: if *comm* is an intercommunicator.
        """
        # Explicit check rather than ``assert``: asserts are removed under
        # ``python -O``, which would silently accept an intercommunicator.
        if comm.Is_inter():
            raise ValueError("Counter requires an intracommunicator")
        self.comm = comm.Dup()
        # Only rank 0 hosts the server thread; on other ranks it stays None.
        self.thread = None
        if self.comm.Get_rank() == 0:
            self.thread = Thread(target=self._counter_thread)
            self.thread.start()

    def _counter_thread(self):
        """Serve counter requests on rank 0 until a tag-1 message arrives."""
        incr = array('i', [0])  # increment received from a client
        ival = array('i', [0])  # current counter value, starts at 0
        status = MPI.Status()
        while True:  # server loop
            self.comm.Recv([incr, MPI.INT],
                           MPI.ANY_SOURCE, MPI.ANY_TAG,
                           status)
            if status.Get_tag() == 1:
                # Shutdown message sent by free().
                return
            # Reply with the pre-increment value.  Ssend does not complete
            # until the client's Recv has matched it, so this thread cannot
            # post its next ANY_SOURCE/ANY_TAG Recv while the reply is still
            # unmatched.  On rank 0 that Recv could otherwise consume a reply
            # meant for the main thread.
            self.comm.Ssend([ival, MPI.INT],
                            status.Get_source(), 0)
            ival[0] += incr[0]

    def free(self):
        """Stop the server thread and free the private communicator.

        Must be called on every rank of the communicator.  The barrier makes
        sure every rank has finished its next() calls (each one waits for its
        reply) before rank 0 shuts the server down.
        """
        self.comm.Barrier()
        if self.comm.Get_rank() == 0:
            # Empty tag-1 message to ourselves; the server thread's
            # ANY_SOURCE/ANY_TAG Recv picks it up and returns.
            self.comm.Ssend([None, MPI.INT], 0, 1)
            self.thread.join()
        self.comm.Free()

    def next(self):
        """Fetch the current counter value and increment it by one.

        Requests from all ranks are handled one at a time by the single
        server loop on rank 0, so every call receives a distinct value.

        :returns: the counter value held before this call's increment.
        """
        incr = array('i', [1])
        ival = array('i', [0])
        self.comm.Ssend([incr, MPI.INT], 0, 0)
        self.comm.Recv([ival, MPI.INT], 0, 0)
        return ival[0]
# -----------------------------------------------------------------------------
def test_thread_level():
    """Terminate the program unless MPI provides MPI.THREAD_MULTIPLE.

    Rank 0's view of the provided thread level is broadcast, so every rank
    takes the same decision.  When support is missing, rank 0 prints a
    notice to stderr and all ranks exit with status 0 (a skip, not a
    failure).
    """
    import sys
    world = MPI.COMM_WORLD
    supported = (MPI.Query_thread() == MPI.THREAD_MULTIPLE)
    supported = world.bcast(supported, root=0)
    if supported:
        return
    if world.Get_rank() == 0:
        sys.stderr.write("MPI does not provide enough thread support\n")
    sys.exit(0)
def test():
    """Draw five values per rank from a shared Counter and verify them.

    Gathered over every rank of MPI.COMM_WORLD, the drawn values must be
    exactly 0, 1, ..., N-1 in some order (N = total number of draws): no
    value repeated and none skipped.
    """
    world = MPI.COMM_WORLD
    counter = Counter(world)
    drawn = [counter.next() for _ in range(5)]
    counter.free()
    # mpi4py's object allreduce defaults to op=MPI.SUM, which for Python
    # lists concatenates the per-rank lists.
    everything = world.allreduce(drawn)
    assert sorted(everything) == list(range(len(everything)))
if __name__ == '__main__':
    # Check thread support first (test_thread_level calls sys.exit(0) when
    # MPI_THREAD_MULTIPLE is unavailable), then run the counter test.
    for _step in (test_thread_level, test):
        _step()
# -----------------------------------------------------------------------------
|