• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
#  Authors: Andreas Kloeckner
9
10from __future__ import print_function
11import mpi
12import random
13import sys
14
# Once a history (list of visited ranks) holds this many entries, the
# worker that extended it sends it back to rank 0 instead of onward.
MAX_GENERATIONS = 20

# Message tags used to tell the kinds of point-to-point traffic apart.
TAG_DEBUG, TAG_DATA, TAG_TERMINATE, TAG_PROGRESS_REPORT = 0, 1, 2, 3
20
21
22
23
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    Keeps at most one outstanding non-blocking receive per tag and, on each
    call to ``wait``, returns whichever of them completes first.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter."""

    def __init__(self, comm, tags):
        # comm: communicator on which receives are posted (e.g. mpi.world).
        # tags: iterable of message tags to listen for.
        self.tags = tags
        self.comm = comm
        # Maps tag -> pending irecv request for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message carrying one of ``self.tags`` arrives.

        Posts an ``irecv`` for every tag that has no pending request, waits
        for any one of them to complete, and retires only the completed
        request; the others stay posted for the next call.

        Returns:
            ``(status, data)`` of the completed receive.
        """
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        # Pass a concrete list so the argument has the same type on
        # Python 2 (where values() is a list) and Python 3 (a view).
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(requests)
        # The completed request is identified by the tag it matched.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every pending receive and forget all of them."""
        # dict.itervalues() does not exist on Python 3; values() works on
        # both Python 2 and 3.
        for request in self.active_requests.values():
            request.cancel()
            # NOTE(review): cancelled requests are not completed with
            # wait() here; the original code had that call commented out.
        self.active_requests = {}
48
49
50
def rank0():
    """Coordinate the history-passing run from rank 0.

    Seeds ``(mpi.size - 1) * 15`` empty histories to random worker ranks,
    then handles incoming completed histories, progress reports, debug
    messages and termination acknowledgements. Once every history has come
    back, tells each worker to terminate. Prints "OK" when every recorded
    history length has been reported once per history and every worker has
    acknowledged termination.
    """
    sent_histories = (mpi.size - 1) * 15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length at report time -> number of reports received.
    progress_reports = {}
    # Ranks that have acknowledged TAG_TERMINATE.
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # Every length seen so far must have been reported by every history,
        # and every worker must have acknowledged termination.
        return (all(count == sent_histories
                    for count in progress_reports.values())
                and len(dead_kids) == mpi.size - 1)

    # Check before blocking: with a single process there are no workers,
    # nothing is ever sent, and waiting first would hang forever.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    # NOTE(review): tgl still has receives posted for the tags that did not
    # complete last; tgl.cancel() is not called before returning.
    print("OK")
97
def comm_rank():
    """Worker loop run on every rank other than 0.

    Receives messages (``recv()`` with its default source/tag, presumably any
    source and any tag; verify against the mpi bindings) until told to stop:

    * TAG_DATA: send the history to rank 0 as a TAG_PROGRESS_REPORT, append
      this rank to it, then forward it to a random worker rank, or back to
      rank 0 once it holds at least MAX_GENERATIONS entries.
    * TAG_TERMINATE: acknowledge to rank 0 with TAG_TERMINATE and return.
    * any other tag: print a diagnostic and keep receiving.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report before appending, so the reported length is the number
            # of hops the history had completed before reaching this rank.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))
115
116
def main():
    """Run this process's role in the demo.

    Messages are lists of visited ranks, passed between processes at
    random; after MAX_GENERATIONS hops they are returned to rank 0. Rank 0
    acts as coordinator and every other rank runs the worker loop.
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()


if __name__ == "__main__":
    main()
130