File: nonblocking_test.py

package info (click to toggle)
boost1.90 1.90.0-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 593,120 kB
  • sloc: cpp: 4,190,908; xml: 196,648; python: 34,618; ansic: 23,145; asm: 5,468; sh: 3,774; makefile: 1,161; perl: 1,020; sql: 728; ruby: 676; yacc: 478; java: 77; lisp: 24; csh: 6
file content (129 lines) | stat: -rw-r--r-- 3,803 bytes parent folder | download | duplicates (13)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# (C) Copyright 2007 
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
#  Authors: Andreas Kloeckner

from __future__ import print_function
import mpi
import random
import sys

# Once a history (list of visited ranks) reaches this length, comm_rank()
# sends it back to rank 0 instead of to another random worker.
MAX_GENERATIONS = 20
# Message tags. rank0() prints the payload of TAG_DEBUG messages.
TAG_DEBUG = 0
# Payload is a history list that is passed between ranks.
TAG_DATA = 1
# Sent by rank 0 to stop a worker; the worker sends the same tag back to
# rank 0 before leaving its loop.
TAG_TERMINATE = 2
# Sent by a worker to rank 0 for every TAG_DATA message it receives.
TAG_PROGRESS_REPORT = 3




class TagGroupListener:
    """Class to help listen for only a given set of tags.

    This is contrived: Typically you could just listen for
    mpi.any_tag and filter.

    Attributes:
        tags: iterable of tags to listen for.
        comm: communicator on which the receives are posted.
        active_requests: dict mapping a tag to its outstanding receive
            request.
    """
    def __init__(self, comm, tags):
        self.tags = tags
        self.comm = comm
        self.active_requests = {}

    def wait(self):
        """Wait for a message on any of the listened-for tags.

        A non-blocking receive is posted for every tag that does not
        already have an outstanding request, then mpi.wait_any() is used to
        wait for one of them to complete.

        Returns:
            A (status, data) tuple for the completed receive.
        """
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        # Pass a concrete list rather than a dict view so the call is the
        # same on Python 2 and Python 3.
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _ = mpi.wait_any(requests)
        # The finished request is dropped so that the next wait() posts a
        # fresh receive for that tag.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it.

        The requests are cancelled without being waited on.
        """
        # dict.itervalues() only exists on Python 2; values() works on both.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}



def rank0():
    """Drive the test from rank 0.

    Sends 15 empty histories per worker rank to randomly chosen workers,
    then listens for returned histories, progress reports, debug messages
    and termination acknowledgements. Once every history has come back,
    each worker is sent TAG_TERMINATE. The loop ends when every recorded
    progress-report count equals the number of histories sent and every
    worker has acknowledged termination; "OK" is printed afterwards.
    """
    worker_count = mpi.size - 1
    total_histories = worker_count * 15
    print("sending %d packets on their way" % total_histories)

    # Launch all initial sends without blocking, then wait for the lot.
    outgoing = mpi.RequestList()
    for _ in range(total_histories):
        target = random.randrange(1, mpi.size)
        outgoing.append(mpi.world.isend(target, TAG_DATA, []))
    mpi.wait_all(outgoing)

    returned = []
    # Maps the length of a reported history to how many reports had it.
    reports_by_length = {}
    finished_workers = []

    listener = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        every_count_full = all(
            count == total_histories
            for count in reports_by_length.values())
        return every_count_full and len(finished_workers) == worker_count

    while True:
        status, payload = listener.wait()
        tag = status.tag

        if tag == TAG_DATA:
            returned.append(payload)
            if len(returned) == total_histories:
                print("all histories received, exiting")
                for worker in range(1, mpi.size):
                    mpi.world.send(worker, TAG_TERMINATE, None)
        elif tag == TAG_PROGRESS_REPORT:
            length = len(payload)
            reports_by_length[length] = reports_by_length.get(length, 0) + 1
        elif tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, payload))
        elif tag == TAG_TERMINATE:
            finished_workers.append(status.source)
        else:
            print("unexpected tag %d from %d" % (tag, status.source))

        # Checked only after a message has been handled, never up front.
        if is_complete():
            break

    print("OK")

def comm_rank():
    """Message loop run by every rank other than 0.

    Receives messages until TAG_TERMINATE arrives:

    * TAG_DATA: the received history is reported to rank 0 as a progress
      report, this rank is appended to it, and it is forwarded -- to rank 0
      once it holds at least MAX_GENERATIONS entries, otherwise to a
      randomly chosen non-zero rank.
    * TAG_TERMINATE: the tag is sent back to rank 0 and the loop ends.
    * any other tag: a diagnostic line is printed.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # The progress report carries the history as received, before
            # this rank is appended to it.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge so rank 0 can count this worker as finished.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print ("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))


def main():
    """Run rank0() on rank 0 and comm_rank() on every other rank."""
    # This program passes messages around at random; each message is a list
    # of the nodes it has visited. After MAX_GENERATIONS, a message is
    # returned to rank 0.
    role = rank0 if mpi.rank == 0 else comm_rank
    role()
        


# Only start the test when this file is executed as a script.
if __name__ == "__main__":
    main()