1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner
import boost.mpi as mpi
import random
import sys
# A history is sent back to rank 0 once it holds at least this many entries.
MAX_GENERATIONS = 20

# Message tags used between rank 0 and the worker ranks.
# Debug payload; rank 0 prints it with a "[DBG <source>]" prefix.
TAG_DEBUG = 0
# A history (list of the ranks it has visited) being passed around.
TAG_DATA = 1
# Shutdown request from rank 0; each worker echoes it back before exiting.
TAG_TERMINATE = 2
# Sent to rank 0 by a worker for every history it receives.
TAG_PROGRESS_REPORT = 3
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive is kept posted per tag; `wait` blocks until
    any one of them completes.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- communicator whose `irecv` is used to post the receives
        tags -- iterable of message tags to listen for
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> outstanding receive request posted for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message carrying one of the tags has arrived.

        Returns a `(status, data)` pair describing the received message.
        """
        # (Re)post a receive for every tag that has none outstanding; the
        # request consumed by the previous call is replaced here.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)

        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(requests)

        # The completed request is looked up via the tag of the message it
        # received, since requests are stored per tag.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget all of them."""
        # NOTE(review): the cancelled requests are not waited on afterwards;
        # the original code carried a disabled `r.wait()` at this point --
        # confirm whether completing them is required.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}
def rank0():
    """Coordinate the test run on rank 0.

    Sends 15 empty histories per worker rank to randomly chosen workers,
    then collects returned histories, progress reports, debug messages and
    termination acknowledgements until the run is complete, and finally
    prints "OK".
    """
    sent_histories = (mpi.size-1)*15
    # print(...) with a single parenthesised argument behaves exactly like
    # the print statement on Python 2.
    print("sending %d packets on their way" % sent_histories)

    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
    mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length at report time -> number of reports received.
    progress_reports = {}
    # Ranks that have acknowledged termination.
    dead_kids = []

    # A worker reports every history it receives before extending it, so
    # reports arrive for lengths 0 .. MAX_GENERATIONS-1 (length 0 is always
    # reported, even for a non-positive MAX_GENERATIONS).
    expected_generations = max(MAX_GENERATIONS, 1)

    tgl = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # Without this check a generation for which no report has arrived
        # yet would go unnoticed and the run could be declared complete
        # while messages are still pending.
        if len(progress_reports) != expected_generations:
            return False
        for count in progress_reports.values():
            if count != sent_histories:
                return False
        return len(dead_kids) == mpi.size-1

    while True:
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                # Ask every worker to shut down; each one acknowledges
                # with a TAG_TERMINATE message of its own.
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

        if is_complete():
            break

    print("OK")
def comm_rank():
    """Run the worker loop on a non-zero rank.

    Receives messages until told to terminate.  Every received history is
    reported to rank 0, extended with this rank's number, and then passed
    on: to rank 0 once it holds at least MAX_GENERATIONS entries, otherwise
    to a randomly chosen worker rank.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)

        if status.tag == TAG_DATA:
            # Report first, so the report carries the history as it was
            # received rather than as it is forwarded.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge the shutdown request to rank 0 before leaving.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            # print(...) with a single parenthesised argument behaves
            # exactly like the print statement on Python 2.
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))
def main():
    """Run the role that matches this process's rank.

    This program sends around messages consisting of lists of visited
    nodes randomly.  After MAX_GENERATIONS, they are returned to rank 0.
    Rank 0 acts as the coordinator; every other rank runs the worker loop.
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|