1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
# --------------------------------------------------------------------
from mpi4py import MPI
import sys, os
class Counter:
    """Shared counter backed by a dynamically spawned server process.

    Creating a ``Counter`` duplicates the given communicator and spawns one
    extra Python process that runs this same script with the ``--child``
    argument.  Each call to :meth:`next` sends an increment to that process
    (tag 0) and receives a counter value back; :meth:`free` sends the stop
    message (tag 1) and releases both communicators.
    """

    def __init__(self, comm):
        """Duplicate *comm* and start the counter process.

        Parameters
        ----------
        comm
            Communicator whose ``Is_inter()`` is false; it is duplicated
            with ``Dup()`` so the caller's communicator is left alone.
        """
        assert not comm.Is_inter()
        self.comm = comm.Dup()
        # start counter process
        script = os.path.abspath(__file__)
        # When running from a compiled file, drop the trailing 'c'/'o' so the
        # child is launched from the '.py' source instead.
        if script.endswith(('.pyc', '.pyo')):
            script = script[:-1]
        self.child = self.comm.Spawn(sys.executable,
                                     [script, '--child'], 1)

    def free(self):
        """Stop the counter process and free the duplicated communicator."""
        # NOTE(review): the barrier presumably ensures no process is still
        # issuing requests when the stop message goes out -- confirm.
        self.comm.Barrier()
        # stop counter process: only rank 0 sends the tag-1 stop message
        if self.child.Get_rank() == 0:
            self.child.send(None, 0, 1)
        self.child.Disconnect()
        self.comm.Free()

    def next(self, incr=1):
        """Request a counter value from the server process.

        Parameters
        ----------
        incr
            Amount sent to the server to be added to its counter.  Defaults
            to 1, which matches the previously hard-coded increment.

        Returns
        -------
        The value the server replied with.  The server loop in this file
        replies with the value held before the increment is applied.
        """
        self.child.send(incr, 0, 0)
        return self.child.recv(None, 0, 0)
# --------------------------------------------------------------------
def _counter_child():
    """Serve counter requests from the parent until told to stop.

    Receives messages from any source with any tag on the parent
    communicator.  A message with tag 1 ends the loop.  For any other
    message the current counter value is sent back to the sender with tag 0,
    and the received increment is then added to the counter.  The parent
    communicator is disconnected on the way out, even if an error occurs.
    """
    parent = MPI.Comm.Get_parent()
    assert parent != MPI.COMM_NULL
    try:
        total = 0
        info = MPI.Status()
        while True:  # server loop
            step = parent.recv(None, MPI.ANY_SOURCE, MPI.ANY_TAG, info)
            if info.tag == 1:
                # stop request
                break
            # reply first, so the sender gets the value before the increment
            parent.send(total, info.source, 0)
            total += step
    finally:
        parent.Disconnect()
if __name__ == '__main__':
    # Child mode: this script was launched with '--child' as its first
    # argument, so run the counter server and exit before reaching anything
    # defined further down the file.
    #
    # Both paths are compared in absolute form: sys.argv[0] is kept exactly
    # as it was typed, while __file__ may already be absolute, so a raw
    # string comparison can fail even though both name this file.
    if (len(sys.argv) > 1 and
            sys.argv[1] == '--child' and
            os.path.abspath(sys.argv[0]) == os.path.abspath(__file__)):
        _counter_child()
        sys.exit(0)
# --------------------------------------------------------------------
def test():
    """Draw five values from a shared counter and check the combined result.

    The values from all processes are combined with ``allreduce`` on
    ``MPI.COMM_WORLD``; once sorted they must equal ``0 .. len - 1``.
    """
    world = MPI.COMM_WORLD
    counter = Counter(world)
    drawn = [counter.next() for _ in range(5)]
    counter.free()
    # combine the per-process lists, then verify the sorted result
    combined = world.allreduce(drawn)
    assert sorted(combined) == list(range(len(combined)))
# Script entry point: run the self-test when executed directly.
if __name__ == "__main__":
    test()
# --------------------------------------------------------------------
|