1"""Support for running coroutines in parallel with staggered start times."""
2
3__all__ = 'staggered_race',
4
5import contextlib
6import typing
7
8from . import events
9from . import exceptions as exceptions_mod
10from . import locks
11from . import tasks
12
13
14async def staggered_race(
15        coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
16        delay: typing.Optional[float],
17        *,
18        loop: events.AbstractEventLoop = None,
19) -> typing.Tuple[
20    typing.Any,
21    typing.Optional[int],
22    typing.List[typing.Optional[Exception]]
23]:
24    """Run coroutines with staggered start times and take the first to finish.
25
26    This method takes an iterable of coroutine functions. The first one is
27    started immediately. From then on, whenever the immediately preceding one
28    fails (raises an exception), or when *delay* seconds has passed, the next
29    coroutine is started. This continues until one of the coroutines complete
30    successfully, in which case all others are cancelled, or until all
31    coroutines fail.
32
33    The coroutines provided should be well-behaved in the following way:
34
35    * They should only ``return`` if completed successfully.
36
37    * They should always raise an exception if they did not complete
38      successfully. In particular, if they handle cancellation, they should
39      probably reraise, like this::
40
41        try:
42            # do work
43        except asyncio.CancelledError:
44            # undo partially completed work
45            raise
46
47    Args:
48        coro_fns: an iterable of coroutine functions, i.e. callables that
49            return a coroutine object when called. Use ``functools.partial`` or
50            lambdas to pass arguments.
51
52        delay: amount of time, in seconds, between starting coroutines. If
53            ``None``, the coroutines will run sequentially.
54
55        loop: the event loop to use.
56
57    Returns:
58        tuple *(winner_result, winner_index, exceptions)* where
59
60        - *winner_result*: the result of the winning coroutine, or ``None``
61          if no coroutines won.
62
63        - *winner_index*: the index of the winning coroutine in
64          ``coro_fns``, or ``None`` if no coroutines won. If the winning
65          coroutine may return None on success, *winner_index* can be used
66          to definitively determine whether any coroutine won.
67
68        - *exceptions*: list of exceptions returned by the coroutines.
69          ``len(exceptions)`` is equal to the number of coroutines actually
70          started, and the order is the same as in ``coro_fns``. The winning
71          coroutine's entry is ``None``.
72
73    """
    # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
    loop = loop or events.get_running_loop()
    enum_coro_fns = enumerate(coro_fns)
    winner_result = None
    winner_index = None
    exceptions = []
    running_tasks = []

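    # Each run_one_coro() invocation handles one entry of coro_fns: it waits
    # for its predecessor to fail (or for *delay* to elapse), schedules the
    # next invocation as a task, then awaits its own coroutine.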
    async def run_one_coro(
            previous_failed: typing.Optional[locks.Event]) -> None:
        # Wait for the previous coroutine to fail, or for delay seconds to pass
        if previous_failed is not None:
            with contextlib.suppress(exceptions_mod.TimeoutError):
                # Use asyncio.wait_for() instead of asyncio.wait() here, so
                # that if we get cancelled at this point, Event.wait() is also
                # cancelled, otherwise there will be a "Task was destroyed but
                # it is pending!" warning later.
                await tasks.wait_for(previous_failed.wait(), delay)
        # Get the next coroutine to run
        try:
            this_index, coro_fn = next(enum_coro_fns)
        except StopIteration:
            return
        # Start task that will run the next coroutine
        this_failed = locks.Event()
        next_task = loop.create_task(run_one_coro(this_failed))
        running_tasks.append(next_task)
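        # Invariant: running_tasks now holds one task per coroutine started
        # so far (indices 0..this_index) plus the successor scheduled above.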
        assert len(running_tasks) == this_index + 2
        # Prepare place to put this coroutine's exception if it does not win
        exceptions.append(None)
        assert len(exceptions) == this_index + 1

        try:
            result = await coro_fn()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException as e:
            exceptions[this_index] = e
            this_failed.set()  # Kickstart the next coroutine
        else:
            # Store winner's results
            nonlocal winner_index, winner_result
            assert winner_index is None
            winner_index = this_index
            winner_result = result
            # Cancel all other tasks. We take care to not cancel the current
            # task as well. If we do so, then since there is no `await` after
            # this point and CancelledError is usually thrown at one, we will
            # encounter a curious corner case where the current task will end
            # up as done() == True, cancelled() == False, exception() ==
            # asyncio.CancelledError. This behavior is specified in
            # https://bugs.python.org/issue30048
            for i, t in enumerate(running_tasks):
                if i != this_index:
                    t.cancel()

    first_task = loop.create_task(run_one_coro(None))
    running_tasks.append(first_task)
    try:
        # Wait for a growing list of tasks to all finish: poor man's version of
        # curio's TaskGroup or trio's nursery
        done_count = 0
        while done_count != len(running_tasks):
            done, _ = await tasks.wait(running_tasks)
            done_count = len(done)
            # If run_one_coro raises an unhandled exception, it's probably a
            # programming error, and I want to see it.
            if __debug__:
                for d in done:
                    if d.done() and not d.cancelled() and d.exception():
                        raise d.exception()
        return winner_result, winner_index, exceptions
    finally:
        # Make sure no tasks are left running if we leave this function
        for t in running_tasks:
            t.cancel()
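
# The block below is an illustrative usage sketch added for exposition; it is
# not part of the original module. It assumes the file is executed with
# ``python -m asyncio.staggered`` so the relative imports above resolve, and
# the helper names (_attempt, _demo) are made up for the example.
if __name__ == '__main__':
    import asyncio
    import functools

    async def _attempt(label, wait):
        # Pretend to do some work; the slower attempt loses and is cancelled.
        await asyncio.sleep(wait)
        return label

    async def _demo():
        # The second attempt starts 0.1 s after the first and finishes first,
        # so it wins; the first attempt is cancelled and its CancelledError
        # is recorded in the exceptions list.
        winner_result, winner_index, exceptions = await staggered_race(
            [functools.partial(_attempt, 'slow', 0.5),
             functools.partial(_attempt, 'fast', 0.1)],
            delay=0.1,
        )
        print(winner_result, winner_index, exceptions)

    asyncio.run(_demo())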