• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1import sys
2import multiprocessing
3
4
# Shared progress counters: _current counts tasks processed so far and _total
# holds the number of tasks. pmap() replaces both with multiprocessing.Value
# objects before any work starts; pool workers receive them through _init().
_current = _total = None
7
8
def _init(current, total):
    """Install the shared progress counters as this module's globals.

    Intended as the ``initializer`` of the worker pool, so every worker
    process uses the very objects handed to it via ``initargs``.

    :param current: shared counter of tasks processed so far.
    :param total: shared value holding the total number of tasks.
    """
    global _current, _total
    _current, _total = current, total
14
15
def _wrapped_func(func_and_args):
    """Run a single task and optionally report progress.

    :param func_and_args: a ``(func, argument, should_print_progress, filter_)``
        tuple. It is packed into one object because ``Pool.map`` passes
        exactly one item per call.
    :returns: ``func(argument, filter_)``.

    When ``should_print_progress`` is true, the shared counter ``_current`` is
    incremented only *after* ``func`` returns. The printed "N of M" therefore
    counts tasks that have finished, not tasks that have merely started.
    """
    func, argument, should_print_progress, filter_ = func_and_args

    result = func(argument, filter_)

    if should_print_progress:
        # Increment, read and print while holding the lock. Re-reading
        # _current.value after releasing the lock could observe another
        # process's increment, which gives duplicated/skipped numbers and
        # lets an older "N of M" overwrite a newer one.
        with _current.get_lock():
            _current.value += 1
            sys.stdout.write('\r\t{} of {}'.format(_current.value, _total.value))
            sys.stdout.flush()

    return result
26
27
def pmap(func, iterable, processes, should_print_progress, filter_=None, *args, **kwargs):
    """
    A parallel map function that reports on its progress.

    Applies ``func(item, filter_)`` to every item of `iterable` and returns a
    list of the results in input order. `iterable` may be any iterable,
    including a generator; it is consumed once, up front.

    If `processes` is 1, the work runs serially in this process. Otherwise a
    process pool with that many workers is used, so `func` and the items must
    be picklable.

    `should_print_progress` is a boolean value that indicates whether a string
    'N of M' should be printed to indicate how many of the functions have
    finished being run.

    Extra `args` / `kwargs` are forwarded to ``multiprocessing.Pool.map``
    (i.e. ``chunksize``). They have no meaning for the serial path and are
    ignored there.
    """
    global _current
    global _total

    # Materialize first, so non-sized iterables work and the total is known.
    func_and_args = [(func, arg, should_print_progress, filter_) for arg in iterable]

    _current = multiprocessing.Value('i', 0)
    _total = multiprocessing.Value('i', len(func_and_args))

    if processes == 1:
        # Do not forward args/kwargs here. The builtin map() would treat extra
        # positional arguments as additional iterables and reject keywords.
        result = [_wrapped_func(item) for item in func_and_args]
    else:
        pool = multiprocessing.Pool(initializer=_init,
                                    initargs=(_current, _total),
                                    processes=processes)
        try:
            result = pool.map(_wrapped_func, func_and_args, *args, **kwargs)
        except BaseException:
            # Stop the workers right away instead of leaking them.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()

    if should_print_progress:
        sys.stdout.write('\r')
        sys.stdout.flush()
    return result
57