#
# Module providing the `Pool` class for managing a process pool
#
# multiprocessing/pool.py
#
# Copyright (c) 2006-2008, R Oudkerk
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
# 3. Neither the name of author nor the names of any contributors may be
#    used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

__all__ = ['Pool']

#
# Imports
#

import threading
import Queue
import itertools
import collections
import time

from multiprocessing import Process, cpu_count, TimeoutError
from multiprocessing.util import Finalize, debug

#
# Constants representing the state of a pool
#

RUN = 0
CLOSE = 1
TERMINATE = 2

#
# Miscellaneous
#

job_counter = itertools.count()

def mapstar(args):
    return map(*args)
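
# Illustrative note (not part of the original module): `mapstar` simply
# unpacks a `(func, sequence)` pair so that a whole chunk of a map job can
# run as a single task in a worker, e.g.
#
#     mapstar((abs, (-1, -2, 3)))    # equivalent to map(abs, (-1, -2, 3))
#                                    # -> [1, 2, 3]
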
#
# Code run by worker processes
#

class MaybeEncodingError(Exception):
    """Wraps possible unpickleable errors, so they can be
    safely sent through the socket."""

    def __init__(self, exc, value):
        self.exc = repr(exc)
        self.value = repr(value)
        super(MaybeEncodingError, self).__init__(self.exc, self.value)

    def __str__(self):
        return "Error sending result: '%s'. Reason: '%s'" % (self.value,
                                                             self.exc)

    def __repr__(self):
        return "<MaybeEncodingError: %s>" % str(self)


def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
    put = outqueue.put
    get = inqueue.get
    if hasattr(inqueue, '_writer'):
        inqueue._writer.close()
        outqueue._reader.close()

    if initializer is not None:
        initializer(*initargs)

    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):
        try:
            task = get()
        except (EOFError, IOError):
            debug('worker got EOFError or IOError -- exiting')
            break

        if task is None:
            debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))
        except Exception as e:
            result = (False, e)
        try:
            put((job, i, result))
        except Exception as e:
            wrapped = MaybeEncodingError(e, result[1])
            debug("Possible encoding error while sending result: %s" % (
                wrapped))
            put((job, i, (False, wrapped)))
        completed += 1
    debug('worker exiting after %d tasks' % completed)
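
# Illustrative sketch (not part of the original module): each task taken off
# `inqueue` is a 5-tuple and each result put on `outqueue` is a 3-tuple whose
# last element records success or failure, roughly:
#
#     task   = (job, i, func, args, kwds)              # e.g. (0, None, pow, (2, 10), {})
#     result = (job, i, (True, func(*args, **kwds)))   # -> (0, None, (True, 1024))
#     # on failure: (job, i, (False, exception_instance))
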
#
# Class representing a process pool
#

class Pool(object):
    '''
    Class which supports an async version of the `apply()` builtin
    '''
    Process = Process

    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None):
        self._setup_queues()
        self._taskqueue = Queue.Queue()
        self._cache = {}
        self._state = RUN
        self._maxtasksperchild = maxtasksperchild
        self._initializer = initializer
        self._initargs = initargs

        if processes is None:
            try:
                processes = cpu_count()
            except NotImplementedError:
                processes = 1
        if processes < 1:
            raise ValueError("Number of processes must be at least 1")

        if initializer is not None and not hasattr(initializer, '__call__'):
            raise TypeError('initializer must be a callable')

        self._processes = processes
        self._pool = []
        self._repopulate_pool()

        self._worker_handler = threading.Thread(
            target=Pool._handle_workers,
            args=(self, )
            )
        self._worker_handler.daemon = True
        self._worker_handler._state = RUN
        self._worker_handler.start()

        self._task_handler = threading.Thread(
            target=Pool._handle_tasks,
            args=(self._taskqueue, self._quick_put, self._outqueue,
                  self._pool, self._cache)
            )
        self._task_handler.daemon = True
        self._task_handler._state = RUN
        self._task_handler.start()

        self._result_handler = threading.Thread(
            target=Pool._handle_results,
            args=(self._outqueue, self._quick_get, self._cache)
            )
        self._result_handler.daemon = True
        self._result_handler._state = RUN
        self._result_handler.start()

        self._terminate = Finalize(
            self, self._terminate_pool,
            args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
                  self._worker_handler, self._task_handler,
                  self._result_handler, self._cache),
            exitpriority=15
            )

    def _join_exited_workers(self):
        """Cleanup after any worker processes which have exited due to reaching
        their specified lifetime.  Returns True if any workers were cleaned up.
        """
        cleaned = False
        for i in reversed(range(len(self._pool))):
            worker = self._pool[i]
            if worker.exitcode is not None:
                # worker exited
                debug('cleaning up worker %d' % i)
                worker.join()
                cleaned = True
                del self._pool[i]
        return cleaned

    def _repopulate_pool(self):
        """Bring the number of pool processes up to the specified number,
        for use after reaping workers which have exited.
        """
        for i in range(self._processes - len(self._pool)):
            w = self.Process(target=worker,
                             args=(self._inqueue, self._outqueue,
                                   self._initializer,
                                   self._initargs, self._maxtasksperchild)
                            )
            self._pool.append(w)
            w.name = w.name.replace('Process', 'PoolWorker')
            w.daemon = True
            w.start()
            debug('added worker')

    def _maintain_pool(self):
        """Clean up any exited workers and start replacements for them.
        """
        if self._join_exited_workers():
            self._repopulate_pool()

    def _setup_queues(self):
        from .queues import SimpleQueue
        self._inqueue = SimpleQueue()
        self._outqueue = SimpleQueue()
        self._quick_put = self._inqueue._writer.send
        self._quick_get = self._outqueue._reader.recv

    def apply(self, func, args=(), kwds={}):
        '''
        Equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        return self.apply_async(func, args, kwds).get()

    def map(self, func, iterable, chunksize=None):
        '''
        Equivalent of `map()` builtin
        '''
        assert self._state == RUN
        return self.map_async(func, iterable, chunksize).get()

    def imap(self, func, iterable, chunksize=1):
        '''
        Equivalent of `itertools.imap()` -- can be MUCH slower than `Pool.map()`
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def imap_unordered(self, func, iterable, chunksize=1):
        '''
        Like `imap()` method but ordering of results is arbitrary
        '''
        assert self._state == RUN
        if chunksize == 1:
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, func, (x,), {})
                         for i, x in enumerate(iterable)), result._set_length))
            return result
        else:
            assert chunksize > 1
            task_batches = Pool._get_tasks(func, iterable, chunksize)
            result = IMapUnorderedIterator(self._cache)
            self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                     for i, x in enumerate(task_batches)), result._set_length))
            return (item for chunk in result for item in chunk)

    def apply_async(self, func, args=(), kwds={}, callback=None):
        '''
        Asynchronous equivalent of `apply()` builtin
        '''
        assert self._state == RUN
        result = ApplyResult(self._cache, callback)
        self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
        return result

    def map_async(self, func, iterable, chunksize=None, callback=None):
        '''
        Asynchronous equivalent of `map()` builtin
        '''
        assert self._state == RUN
        if not hasattr(iterable, '__len__'):
            iterable = list(iterable)

        if chunksize is None:
            chunksize, extra = divmod(len(iterable), len(self._pool) * 4)
            if extra:
                chunksize += 1
        if len(iterable) == 0:
            chunksize = 0

        task_batches = Pool._get_tasks(func, iterable, chunksize)
        result = MapResult(self._cache, chunksize, len(iterable), callback)
        self._taskqueue.put((((result._job, i, mapstar, (x,), {})
                              for i, x in enumerate(task_batches)), None))
        return result
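
    # Worked example (illustrative, not part of the original module): with a
    # pool of 4 processes and an iterable of 100 items, the default chunksize
    # computed above is
    #
    #     chunksize, extra = divmod(100, 4 * 4)   # -> (6, 4)
    #     if extra: chunksize += 1                # -> chunksize == 7
    #
    # so the 100 items are sent to the workers in ceil(100/7) == 15 batches.
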
    @staticmethod
    def _handle_workers(pool):
        thread = threading.current_thread()

        # Keep maintaining workers until the cache gets drained, unless the pool
        # is terminated.
        while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
            pool._maintain_pool()
            time.sleep(0.1)
        # send sentinel to stop workers
        pool._taskqueue.put(None)
        debug('worker handler exiting')

    @staticmethod
    def _handle_tasks(taskqueue, put, outqueue, pool, cache):
        thread = threading.current_thread()

        for taskseq, set_length in iter(taskqueue.get, None):
            task = None
            i = -1
            try:
                for i, task in enumerate(taskseq):
                    if thread._state:
                        debug('task handler found thread._state != RUN')
                        break
                    try:
                        put(task)
                    except Exception as e:
                        job, ind = task[:2]
                        try:
                            cache[job]._set(ind, (False, e))
                        except KeyError:
                            pass
                else:
                    if set_length:
                        debug('doing set_length()')
                        set_length(i+1)
                    continue
                break
            except Exception as ex:
                job, ind = task[:2] if task else (0, 0)
                if job in cache:
                    cache[job]._set(ind + 1, (False, ex))
                if set_length:
                    debug('doing set_length()')
                    set_length(i+1)
        else:
            debug('task handler got sentinel')

        try:
            # tell result handler to finish when cache is empty
            debug('task handler sending sentinel to result handler')
            outqueue.put(None)

            # tell workers there is no more work
            debug('task handler sending sentinel to workers')
            for p in pool:
                put(None)
        except IOError:
            debug('task handler got IOError when sending sentinels')

        debug('task handler exiting')

    @staticmethod
    def _handle_results(outqueue, get, cache):
        thread = threading.current_thread()

        while 1:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if thread._state:
                assert thread._state == TERMINATE
                debug('result handler found thread._state=TERMINATE')
                break

            if task is None:
                debug('result handler got sentinel')
                break

            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        while cache and thread._state != TERMINATE:
            try:
                task = get()
            except (IOError, EOFError):
                debug('result handler got EOFError/IOError -- exiting')
                return

            if task is None:
                debug('result handler ignoring extra sentinel')
                continue
            job, i, obj = task
            try:
                cache[job]._set(i, obj)
            except KeyError:
                pass

        if hasattr(outqueue, '_reader'):
            debug('ensuring that outqueue is not full')
            # If we don't make room available in outqueue then
            # attempts to add the sentinel (None) to outqueue may
            # block.  There is guaranteed to be no more than 2 sentinels.
            try:
                for i in range(10):
                    if not outqueue._reader.poll():
                        break
                    get()
            except (IOError, EOFError):
                pass

        debug('result handler exiting: len(cache)=%s, thread._state=%s',
              len(cache), thread._state)

    @staticmethod
    def _get_tasks(func, it, size):
        it = iter(it)
        while 1:
            x = tuple(itertools.islice(it, size))
            if not x:
                return
            yield (func, x)
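
    # Illustrative sketch (not part of the original module): `_get_tasks`
    # batches an iterable into `(func, chunk)` tuples of at most `size`
    # elements, e.g.
    #
    #     list(Pool._get_tasks(str, range(5), 2))
    #     # -> [(str, (0, 1)), (str, (2, 3)), (str, (4,))]
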
    def __reduce__(self):
        raise NotImplementedError(
              'pool objects cannot be passed between processes or pickled'
              )

    def close(self):
        debug('closing pool')
        if self._state == RUN:
            self._state = CLOSE
            self._worker_handler._state = CLOSE

    def terminate(self):
        debug('terminating pool')
        self._state = TERMINATE
        self._worker_handler._state = TERMINATE
        self._terminate()

    def join(self):
        debug('joining pool')
        assert self._state in (CLOSE, TERMINATE)
        self._worker_handler.join()
        self._task_handler.join()
        self._result_handler.join()
        for p in self._pool:
            p.join()

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # task_handler may be blocked trying to put items on inqueue
        debug('removing tasks from inqueue until task handler finished')
        inqueue._rlock.acquire()
        while task_handler.is_alive() and inqueue._reader.poll():
            inqueue._reader.recv()
            time.sleep(0)

    @classmethod
    def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
                        worker_handler, task_handler, result_handler, cache):
        # this is guaranteed to only be called once
        debug('finalizing pool')

        worker_handler._state = TERMINATE
        task_handler._state = TERMINATE

        debug('helping task handler/workers to finish')
        cls._help_stuff_finish(inqueue, task_handler, len(pool))

        assert result_handler.is_alive() or len(cache) == 0

        result_handler._state = TERMINATE
        outqueue.put(None)                  # sentinel

        # We must wait for the worker handler to exit before terminating
        # workers because we don't want workers to be restarted behind our back.
        debug('joining worker handler')
        if threading.current_thread() is not worker_handler:
            worker_handler.join(1e100)

        # Terminate workers which haven't already finished.
        if pool and hasattr(pool[0], 'terminate'):
            debug('terminating workers')
            for p in pool:
                if p.exitcode is None:
                    p.terminate()

        debug('joining task handler')
        if threading.current_thread() is not task_handler:
            task_handler.join(1e100)

        debug('joining result handler')
        if threading.current_thread() is not result_handler:
            result_handler.join(1e100)

        if pool and hasattr(pool[0], 'terminate'):
            debug('joining pool workers')
            for p in pool:
                if p.is_alive():
                    # worker has not yet exited
                    debug('cleaning up worker %d' % p.pid)
                    p.join()
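
# A minimal usage sketch (illustrative only, not part of the original module).
# The worker function must be importable by the child processes, and on
# platforms that spawn rather than fork the pool must be created under the
# `if __name__ == '__main__'` guard:
#
#     from multiprocessing import Pool
#
#     def square(x):
#         return x * x
#
#     if __name__ == '__main__':
#         pool = Pool(processes=4)
#         print pool.map(square, range(10))        # [0, 1, 4, 9, ..., 81]
#         res = pool.apply_async(square, (7,))
#         print res.get(timeout=5)                 # 49
#         pool.close()
#         pool.join()
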
#
# Class whose instances are returned by `Pool.apply_async()`
#

class ApplyResult(object):

    def __init__(self, cache, callback):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._ready = False
        self._callback = callback
        cache[self._job] = self

    def ready(self):
        return self._ready

    def successful(self):
        assert self._ready
        return self._success

    def wait(self, timeout=None):
        self._cond.acquire()
        try:
            if not self._ready:
                self._cond.wait(timeout)
        finally:
            self._cond.release()

    def get(self, timeout=None):
        self.wait(timeout)
        if not self._ready:
            raise TimeoutError
        if self._success:
            return self._value
        else:
            raise self._value

    def _set(self, i, obj):
        self._success, self._value = obj
        if self._callback and self._success:
            self._callback(self._value)
        self._cond.acquire()
        try:
            self._ready = True
            self._cond.notify()
        finally:
            self._cond.release()
        del self._cache[self._job]

AsyncResult = ApplyResult       # create alias -- see #17805
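
# Illustrative sketch (not part of the original module): given a running
# `pool` as in the example above, an `ApplyResult` is what callers poll or
# block on, e.g.
#
#     res = pool.apply_async(pow, (2, 8))
#     res.ready()            # False until a worker has sent the result back
#     res.get(timeout=10)    # 256, or raises TimeoutError / the worker's exception
#     res.successful()       # True (only meaningful once ready() is True)
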
#
# Class whose instances are returned by `Pool.map_async()`
#

class MapResult(ApplyResult):

    def __init__(self, cache, chunksize, length, callback):
        ApplyResult.__init__(self, cache, callback)
        self._success = True
        self._value = [None] * length
        self._chunksize = chunksize
        if chunksize <= 0:
            self._number_left = 0
            self._ready = True
            del cache[self._job]
        else:
            self._number_left = length//chunksize + bool(length % chunksize)

    def _set(self, i, success_result):
        success, result = success_result
        if success:
            self._value[i*self._chunksize:(i+1)*self._chunksize] = result
            self._number_left -= 1
            if self._number_left == 0:
                if self._callback:
                    self._callback(self._value)
                del self._cache[self._job]
                self._cond.acquire()
                try:
                    self._ready = True
                    self._cond.notify()
                finally:
                    self._cond.release()
        else:
            self._success = False
            self._value = result
            del self._cache[self._job]
            self._cond.acquire()
            try:
                self._ready = True
                self._cond.notify()
            finally:
                self._cond.release()
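
# Worked example (illustrative, not part of the original module): for
# `length == 10` items and `chunksize == 3`, `_number_left` starts at
# 10//3 + bool(10 % 3) == 4, and the chunk with index `i` fills the slots
# `self._value[i*3:(i+1)*3]`; the callback fires and `ready()` becomes True
# only after all 4 chunks have been set.
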
#
# Class whose instances are returned by `Pool.imap()`
#

class IMapIterator(object):

    def __init__(self, cache):
        self._cond = threading.Condition(threading.Lock())
        self._job = job_counter.next()
        self._cache = cache
        self._items = collections.deque()
        self._index = 0
        self._length = None
        self._unsorted = {}
        cache[self._job] = self

    def __iter__(self):
        return self

    def next(self, timeout=None):
        self._cond.acquire()
        try:
            try:
                item = self._items.popleft()
            except IndexError:
                if self._index == self._length:
                    raise StopIteration
                self._cond.wait(timeout)
                try:
                    item = self._items.popleft()
                except IndexError:
                    if self._index == self._length:
                        raise StopIteration
                    raise TimeoutError
        finally:
            self._cond.release()

        success, value = item
        if success:
            return value
        raise value

    __next__ = next                    # XXX

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            if self._index == i:
                self._items.append(obj)
                self._index += 1
                while self._index in self._unsorted:
                    obj = self._unsorted.pop(self._index)
                    self._items.append(obj)
                    self._index += 1
                self._cond.notify()
            else:
                self._unsorted[i] = obj

            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

    def _set_length(self, length):
        self._cond.acquire()
        try:
            self._length = length
            if self._index == self._length:
                self._cond.notify()
                del self._cache[self._job]
        finally:
            self._cond.release()
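
# Illustrative sketch (not part of the original module): `Pool.imap()` returns
# one of these iterators, which yields results lazily and in input order,
# blocking until the next result (or an optional timeout) is available. Given
# a `pool` as in the earlier example:
#
#     it = pool.imap(abs, [-1, -2, -3])
#     it.next()              # 1
#     it.next(timeout=5)     # 2, or raises TimeoutError if not ready in time
#     list(it)               # [3]
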
#
# Class whose instances are returned by `Pool.imap_unordered()`
#

class IMapUnorderedIterator(IMapIterator):

    def _set(self, i, obj):
        self._cond.acquire()
        try:
            self._items.append(obj)
            self._index += 1
            self._cond.notify()
            if self._index == self._length:
                del self._cache[self._job]
        finally:
            self._cond.release()

#
# Class representing a pool which uses threads rather than processes
#

class ThreadPool(Pool):

    from .dummy import Process

    def __init__(self, processes=None, initializer=None, initargs=()):
        Pool.__init__(self, processes, initializer, initargs)

    def _setup_queues(self):
        self._inqueue = Queue.Queue()
        self._outqueue = Queue.Queue()
        self._quick_put = self._inqueue.put
        self._quick_get = self._outqueue.get

    @staticmethod
    def _help_stuff_finish(inqueue, task_handler, size):
        # put sentinels at head of inqueue to make workers finish
        inqueue.not_empty.acquire()
        try:
            inqueue.queue.clear()
            inqueue.queue.extend([None] * size)
            inqueue.not_empty.notify_all()
        finally:
            inqueue.not_empty.release()
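
# A minimal ThreadPool sketch (illustrative only, not part of the original
# module): the same Pool API backed by threads instead of processes, which
# avoids pickling but offers no protection from the GIL:
#
#     from multiprocessing.pool import ThreadPool
#
#     pool = ThreadPool(processes=4)
#     print pool.map(len, ['a', 'bb', 'ccc'])      # [1, 2, 3]
#     pool.close()
#     pool.join()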