# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner

"""Random-walk message-passing test for the Python MPI bindings.

Rank 0 seeds ``(mpi.size - 1) * 15`` empty lists ("histories") to random
worker ranks.  Each worker reports every history it receives to rank 0
(TAG_PROGRESS_REPORT), appends its own rank, and forwards the history to a
random worker -- or back to rank 0 once it holds MAX_GENERATIONS entries.
When rank 0 has collected every history it tells all workers to terminate,
waits for their acknowledgements and all progress reports, and prints "OK".
"""

from __future__ import print_function
import mpi
import random
import sys

MAX_GENERATIONS = 20
TAG_DEBUG = 0
TAG_DATA = 1
TAG_TERMINATE = 2
TAG_PROGRESS_REPORT = 3


class TagGroupListener:
    """Class to help listen for only a given set of tags.

    This is contrived: Typically you could just listen for
    mpi.any_tag and filter."""

    def __init__(self, comm, tags):
        # comm: communicator on which receives are posted.
        # tags: the tags to listen for.
        self.tags = tags
        self.comm = comm
        # tag -> outstanding irecv request posted for that tag
        self.active_requests = {}

    def wait(self):
        """Block until a message carrying one of ``self.tags`` arrives.

        Posts an irecv for every tag that has none outstanding, waits for
        any of them to complete, and retires the completed request so the
        next call re-posts a receive for that tag.

        Returns:
            (status, data) of the received message.
        """
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(requests)
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it."""
        # .values() instead of the Python-2-only .itervalues(): this file
        # already targets Python 3 as well (print_function import).
        for request in self.active_requests.values():
            request.cancel()
            # NOTE(review): MPI requires a cancelled request to still be
            # completed (wait/test); the original had r.wait() commented out
            # here -- confirm whether these bindings need it.
        self.active_requests = {}


def rank0():
    """Rank-0 driver: seed histories, collect them, shut the workers down."""
    if mpi.size < 2:
        # Without worker ranks nothing is ever sent to rank 0, so the
        # receive loop below would block forever.
        print("this test needs at least 2 ranks, got %d" % mpi.size)
        return

    sent_histories = (mpi.size - 1) * 15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    completed_histories = []
    # history length at report time -> number of reports seen with it
    progress_reports = {}
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # Workers report a history before appending to it, so each history
        # yields exactly one report at every length 0..MAX_GENERATIONS-1.
        # Check every expected length (not only those seen so far), so that
        # reports still in flight for a whole length are not overlooked.
        for length in range(MAX_GENERATIONS):
            if progress_reports.get(length, 0) != sent_histories:
                return False
        return len(dead_kids) == mpi.size - 1

    while True:
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            #print ("received completed history %s from %d" % (data, status.source))
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

        if is_complete():
            break

    # NOTE(review): tgl still holds one outstanding irecv per tag at this
    # point; they are never cancelled or completed -- confirm this is
    # acceptable at MPI shutdown.
    print("OK")


def comm_rank():
    """Worker loop for ranks >= 1: relay histories until told to stop."""
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report the history as received, before this rank appends
            # itself to it.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                # NOTE(review): dest may be this very rank; a blocking send
                # to self only completes if the implementation buffers it.
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            # Acknowledge so rank 0 can count finished workers.
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))


def main():
    # this program sends around messages consisting of lists of visited nodes
    # randomly. After MAX_GENERATIONS, they are returned to rank 0.

    if mpi.rank == 0:
        rank0()
    else:
        comm_rank()


if __name__ == "__main__":
    main()