# (C) Copyright 2007
# Andreas Kloeckner
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner

from __future__ import print_function
import mpi
import random
import sys

# Number of hops a message makes between worker ranks before it is sent
# back to rank 0.
MAX_GENERATIONS = 20

# Message tags used by this program.
TAG_DEBUG = 0
TAG_DATA = 1
TAG_TERMINATE = 2
TAG_PROGRESS_REPORT = 3


class TagGroupListener:
    """Class to help listen for only a given set of tags.

    This is contrived: Typically you could just listen for mpi.any_tag
    and filter.
    """

    def __init__(self, comm, tags):
        """Remember the communicator and the tags to listen on.

        comm -- communicator object providing ``irecv(tag=...)``.
        tags -- iterable of tags; one receive per tag is kept posted.
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> outstanding non-blocking receive request for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of the tags arrives.

        Posts a non-blocking receive for every tag that does not currently
        have one outstanding, waits for any of them to complete, and drops
        the completed request so it is re-posted on the next call.

        Returns a ``(status, data)`` tuple for the completed receive.
        """
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)

        # Pass a concrete list rather than a dict view so the same code
        # runs on both Python 2 and Python 3.
        pending = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(pending)

        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget all of them."""
        # dict.values() works on Python 2 and 3; itervalues() is Python 2
        # only and raises AttributeError on Python 3.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}


def rank0():
    """Run the coordinator role.

    Sends ``(mpi.size - 1) * 15`` empty histories to randomly chosen worker
    ranks, then collects completed histories, progress reports and debug
    messages. Once all histories have come back, every worker is told to
    terminate; the function returns after all workers have acknowledged
    and all recorded progress counters have reached the number of sent
    histories.
    """
    sent_histories = (mpi.size - 1) * 15
    print("sending %d packets on their way" % sent_histories)

    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
    mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length at report time -> number of reports received
    # with that length.
    progress_reports = {}
    # Ranks that have acknowledged TAG_TERMINATE.
    dead_kids = []

    tgl = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        """Return True once all reports are in and all workers are done."""
        all_reported = all(
            count == sent_histories for count in progress_reports.values())
        return all_reported and len(dead_kids) == mpi.size - 1

    # Test for completion *before* waiting: with a single rank nothing is
    # ever sent, so waiting first would block forever. With two or more
    # ranks the first check is always False (no worker has terminated
    # yet), so the sequence of waits is the same as checking afterwards.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            hops = len(data)
            progress_reports[hops] = progress_reports.get(hops, 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    # NOTE(review): receives posted by tgl for the other tags are still
    # outstanding here and are not cancelled, matching the original
    # behaviour -- confirm whether tgl.cancel() should be called.
    print("OK")


def comm_rank():
    """Run the worker role.

    Receives messages in a loop. For a TAG_DATA message the worker reports
    the current history to rank 0, appends its own rank to the history and
    forwards it: to rank 0 once the history holds at least MAX_GENERATIONS
    entries, otherwise to a randomly chosen worker rank. A TAG_TERMINATE
    message is acknowledged to rank 0 and ends the loop.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)

        if status.tag == TAG_DATA:
            # The progress report is sent before this rank is appended.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))


def main():
    """Dispatch to the coordinator or worker role based on the MPI rank."""
    # this program sends around messages consisting of lists of visited nodes
    # randomly. After MAX_GENERATIONS, they are returned to rank 0.
    if mpi.rank == 0:
        rank0()
    else:
        comm_rank()


if __name__ == "__main__":
    main()