# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

from multiprocessing import Queue, queues
import time


# Clock used to compute barrier deadlines. time.monotonic() is not affected
# by wall-clock adjustments; fall back to time.time() on interpreters that do
# not provide it.
_monotonic = getattr(time, 'monotonic', time.time)


class QueueBarrierTimeout(Exception):
    """QueueBarrier timeout exception."""


class QueueBarrier(object):
    """This class implements a simple barrier to synchronize processes. The
    barrier relies on the fact that there is a single process "master" and |n|
    different "slaves" to make the implementation simpler. Also, given this
    hierarchy, the slaves and the master can exchange a token while passing
    through the barrier.

    The so called "master" shall call master_barrier() while the "slave" shall
    call the slave_barrier() method.

    If the same group of |n| slaves and the same master are participating in
    the barrier, it is totally safe to reuse the barrier several times with the
    same group of processes.
    """


    def __init__(self, n):
        """Initializes the barrier with |n| slave processes and a master.

        @param n: The number of slave processes."""
        self.n_ = n
        # Carries one token per slave towards the master.
        self.queue_master_ = Queue()
        # Carries one release token per slave back from the master.
        self.queue_slave_ = Queue()


    def master_barrier(self, token=None, timeout=None):
        """Makes the master wait until all the "n" slaves have reached this
        point.

        @param token: A value passed to every slave.
        @param timeout: The timeout, in seconds, to wait for the slaves. The
                timeout bounds the total time spent waiting for all the
                slaves, not the wait for each individual slave. A None value
                will block forever.

        @raises QueueBarrierTimeout: if fewer than "n" slaves reached the
                barrier before the timeout expired. The tokens received so far
                are discarded and no slave is released.

        Returns the list of received tokens from the slaves.
        """
        # A single deadline shared by every get() below; otherwise each slave
        # would get a fresh |timeout| and the master could wait n * timeout.
        if timeout is None:
            deadline = None
        else:
            deadline = _monotonic() + timeout

        # Wait for all the slaves.
        result = []
        try:
            for _ in range(self.n_):
                if deadline is None:
                    remaining = None
                else:
                    # Clamp at zero so an expired deadline still performs one
                    # non-blocking attempt instead of passing a negative value.
                    remaining = max(0.0, deadline - _monotonic())
                result.append(self.queue_master_.get(timeout=remaining))
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()

        # Release all the blocked slaves.
        for _ in range(self.n_):
            self.queue_slave_.put(token)
        return result


    def slave_barrier(self, token=None, timeout=None):
        """Makes a slave wait until all the "n" slaves and the master have
        reached this point.

        @param token: A value passed to the master.
        @param timeout: The timeout, in seconds, to wait for the master to
                release the barrier. A None value will block forever.

        @raises QueueBarrierTimeout: if the master did not release the barrier
                before the timeout expired. The token was already sent to the
                master at that point.

        Returns the token passed by the master to master_barrier().
        """
        # Announce the arrival first, then block until the master releases us.
        self.queue_master_.put(token)
        try:
            return self.queue_slave_.get(timeout=timeout)
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()