# -----------------------------------------------------------------------------
from array import array
from threading import Thread
from mpi4py import MPI
class Counter:
    """Shared integer counter served by a thread on rank 0.

    The constructor duplicates the given intracommunicator and, on rank 0
    only, starts a thread that answers requests arriving on the duplicate.
    Any rank (including rank 0 itself) obtains a value with ``next()``.
    ``free()`` stops the server thread and releases the duplicate.

    On rank 0 the server thread and the main thread both issue MPI calls,
    so the MPI library has to allow calls from more than one thread.
    """

    # Message tags used on the private communicator.
    _TAG_NEXT = 0  # request for / reply with a counter value
    _TAG_STOP = 1  # tells the server thread to exit

    def __init__(self, comm):
        """Duplicate *comm* and start the server thread on rank 0.

        *comm* must not be an intercommunicator (checked with ``assert``).
        """
        assert not comm.Is_inter()
        # Work on a duplicate so counter traffic stays on its own
        # communicator.
        self.comm = comm.Dup()
        self.thread = None
        if self.comm.Get_rank() == 0:
            self.thread = Thread(target=self._counter_thread)
            self.thread.start()

    def _counter_thread(self):
        """Server loop: hand out the current value, then add the increment."""
        comm = self.comm
        step = array("i", [0])
        current = array("i", [0])
        info = MPI.Status()
        while True:
            comm.Recv(
                [step, MPI.INT],
                source=MPI.ANY_SOURCE,
                tag=MPI.ANY_TAG,
                status=info,
            )
            if info.Get_tag() == self._TAG_STOP:
                break
            # Reply first, increment afterwards: the requester gets the
            # value the counter held before its own increment is applied.
            comm.Ssend(
                [current, MPI.INT],
                dest=info.Get_source(),
                tag=self._TAG_NEXT,
            )
            current[0] += step[0]

    def free(self):
        """Stop the server thread and free the duplicated communicator.

        Starts with a barrier on the duplicate, so every rank has to call
        this method.
        """
        self.comm.Barrier()
        if self.comm.Get_rank() == 0:
            # An empty message with the stop tag, sent by rank 0 to itself,
            # makes the server loop exit.
            self.comm.Ssend([None, MPI.INT], dest=0, tag=self._TAG_STOP)
            self.thread.join()
        self.comm.Free()

    def next(self):
        """Return the value received from rank 0 after requesting +1."""
        request = array("i", [1])
        reply = array("i", [0])
        self.comm.Ssend([request, MPI.INT], dest=0, tag=self._TAG_NEXT)
        self.comm.Recv([reply, MPI.INT], source=0, tag=self._TAG_NEXT)
        return reply[0]
# -----------------------------------------------------------------------------
def test_thread_level():
    """Exit with status 0 unless rank 0 reports ``MPI.THREAD_MULTIPLE``.

    Rank 0's answer is broadcast so that all ranks take the same branch;
    only rank 0 writes the explanatory message to stderr.
    """
    import sys

    world = MPI.COMM_WORLD
    supported = world.bcast(
        MPI.Query_thread() == MPI.THREAD_MULTIPLE, root=0
    )
    if supported:
        return
    if world.Get_rank() == 0:
        sys.stderr.write("MPI does not provide enough thread support\n")
    sys.exit(0)
def test():
    """Draw five values per rank and check the combined set has no gaps.

    The per-rank lists are combined with ``allreduce``; the sorted result
    must equal ``0 .. N-1`` where ``N`` is the total number of values.
    """
    world = MPI.COMM_WORLD
    counter = Counter(world)
    drawn = [counter.next() for _ in range(5)]
    counter.free()
    everything = world.allreduce(drawn)
    assert sorted(everything) == list(range(len(everything)))
if __name__ == "__main__":
    # Check thread support first (it may exit), then run the counter test.
    for _step in (test_thread_level, test):
        _step()
# -----------------------------------------------------------------------------