• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021-2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# -----------------------------------------------------------------------------
16# Imports
17# -----------------------------------------------------------------------------
18import asyncio
19import logging
20import traceback
21import collections
22import sys
23from typing import Awaitable, Set, TypeVar
24from functools import wraps
25from pyee import EventEmitter
26
27from .colors import color
28
29# -----------------------------------------------------------------------------
30# Logging
31# -----------------------------------------------------------------------------
logger = logging.getLogger(name=__name__)
33
34
35# -----------------------------------------------------------------------------
def setup_event_forwarding(emitter, forwarder, event_name):
    """
    Relay every `event_name` event emitted by `emitter` to `forwarder`.

    A listener is attached to `emitter` which calls
    `forwarder.emit(event_name, *args, **kwargs)` with exactly the arguments
    the original event carried.
    """

    def _relay(*event_args, **event_kwargs):
        forwarder.emit(event_name, *event_args, **event_kwargs)

    emitter.on(event_name, _relay)
41
42
43# -----------------------------------------------------------------------------
def composite_listener(cls):
    """
    Class decorator that installs hidden register/deregister hooks on `cls`.

    The hooks are stored as `cls._bumble_register_composite` and
    `cls._bumble_deregister_composite`. Each takes `(self, emitter)` and walks
    every name in `dir(cls)` that starts with `on_`, subscribing (or
    unsubscribing) the bound attribute as a listener for the event whose name
    is the remainder after the `on_` prefix.
    """
    # pylint: disable=protected-access
    prefix = 'on_'

    def _handlers(instance):
        # Lazily yield (event name, bound handler) pairs, in dir() order
        return (
            (attr_name[len(prefix):], getattr(instance, attr_name))
            for attr_name in dir(cls)
            if attr_name.startswith(prefix)
        )

    def _register(self, emitter):
        for event_name, handler in _handlers(self):
            emitter.on(event_name, handler)

    def _deregister(self, emitter):
        for event_name, handler in _handlers(self):
            emitter.remove_listener(event_name, handler)

    cls._bumble_register_composite = _register
    cls._bumble_deregister_composite = _deregister
    return cls
65
66
67# -----------------------------------------------------------------------------
_T = TypeVar('_T')


class AbortableEventEmitter(EventEmitter):
    """Event emitter that can abort a pending awaitable when an event is emitted."""

    def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
        """
        Set a coroutine or future to abort when an event occurs.

        `awaitable` is wrapped with `asyncio.ensure_future`, and that future is
        what gets returned. If it is not already done, a listener for `event` is
        added. When the event is emitted before the future completes:
        - a Task is cancelled, with a message on Python >= 3.9;
        - any other future gets an `asyncio.CancelledError` set as its exception.
        The listener is removed as soon as the future is done, however it finished.
        """
        pending = asyncio.ensure_future(awaitable)
        if not pending.done():

            def abort_pending(*_):
                # Too late to abort: the future already finished
                if pending.done():
                    return
                reason = f'abort: {event} event occurred.'
                if not isinstance(pending, asyncio.Task):
                    pending.set_exception(asyncio.CancelledError(reason))
                elif sys.version_info >= (3, 9, 0):
                    pending.cancel(reason)
                else:
                    # Task.cancel() only accepts a message from Python 3.9 on
                    pending.cancel()

            def stop_listening(_):
                self.remove_listener(event, abort_pending)

            self.on(event, abort_pending)
            pending.add_done_callback(stop_listening)

        return pending
99
100
101# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
    """
    Emitter with a single, replaceable `listener` object.

    Assigning to `listener` unsubscribes the previous listener's handlers and
    subscribes the new listener's. This goes through the
    `_bumble_register_composite` / `_bumble_deregister_composite` hooks that
    the `composite_listener` decorator installs on a class.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        return self._listener

    @listener.setter
    def listener(self, listener):
        # pylint: disable=protected-access
        previous = self._listener
        if previous is not None:
            # Deregister the *previous* listener's handlers (not the incoming one)
            for cls in previous.__class__.mro():
                # Only a class that was itself decorated owns the hook in its
                # own __dict__; hasattr() would also match every subclass that
                # merely inherits it and run the same hook repeatedly.
                if '_bumble_deregister_composite' in cls.__dict__:
                    cls._bumble_deregister_composite(previous, self)
        self._listener = listener
        if listener is not None:
            # Register the new listener's handlers, once per decorated class
            # NOTE(review): if a decorated class derives from another decorated
            # class, both hooks run and overlap on the base's on_* methods.
            for cls in listener.__class__.mro():
                if '_bumble_register_composite' in cls.__dict__:
                    cls._bumble_register_composite(listener, self)
125
126
127# -----------------------------------------------------------------------------
class AsyncRunner:
    """Helpers for running coroutines from synchronous code without awaiting them."""

    class WorkQueue:
        """
        FIFO of coroutines that are awaited one at a time, in enqueue order.

        When `create_task` is true, a consumer task running `run()` is created
        on the first `enqueue()`. Otherwise the caller is responsible for
        running `run()`.
        """

        def __init__(self, create_task=True):
            self.queue = None  # asyncio.Queue, created lazily by _get_queue()
            self.task = None  # consumer task, when created by enqueue()
            self.create_task = create_task

        def _get_queue(self):
            # Created on first use rather than in __init__: `default_queue`
            # below is built while the class body executes, outside any
            # running event loop.
            if self.queue is None:
                self.queue = asyncio.Queue()
            return self.queue

        def enqueue(self, coroutine):
            """Schedule `coroutine` to be awaited after everything already queued."""
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Enqueue the work (lazily creating the queue)
            self._get_queue().put_nowait(coroutine)

        async def run(self):
            """Await queued coroutines forever; exceptions are logged, not raised."""
            # Also creates the queue here, so run() can be started before the
            # first enqueue() when create_task=False.
            queue = self._get_queue()
            while True:
                item = await queue.get()
                try:
                    await item
                except Exception as error:
                    logger.warning(
                        f'{color("!!! Exception in work queue:", "red")} {error}'
                    )

    # Shared default queue
    default_queue = WorkQueue()

    # Shared set of running tasks (strong references for spawn())
    running_tasks: Set[Awaitable] = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and returns None
        immediately.
        - With `queue=None`, the coroutine runs in its own task, started through
          `AsyncRunner.spawn` so it is kept alive until done. Any exception is
          logged with its traceback.
        - Otherwise the coroutine is passed to `queue.enqueue()`.
        """

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)
                if queue is None:
                    # Create a task to run the coroutine
                    async def run():
                        try:
                            await coroutine
                        except Exception:
                            logger.warning(
                                f'{color("!!! Exception in wrapper:", "red")} '
                                f'{traceback.format_exc()}'
                            )

                    # spawn() keeps a strong reference to the task; the result of
                    # a bare asyncio.create_task() could be garbage-collected
                    # before the coroutine finishes.
                    AsyncRunner.spawn(run())
                else:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)

            return wrapper

        return decorator

    @staticmethod
    def spawn(coroutine):
        """
        Spawn a task to run a coroutine in a "fire and forget" mode.

        Using this method instead of just calling `asyncio.create_task(coroutine)`
        is necessary when you don't keep a reference to the task, because `asyncio`
        only keeps weak references to alive tasks.

        Returns the created task; callers are free to ignore it.
        """
        task = asyncio.create_task(coroutine)
        AsyncRunner.running_tasks.add(task)
        task.add_done_callback(AsyncRunner.running_tasks.discard)
        return task
205
206
207# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
    """
    Asyncio pipe with flow control. When writing to the pipe, the source is
    paused (by calling a function passed in when the pipe is created) if the
    amount of queued data exceeds a specified threshold.

    Packets are delivered to `write_to_sink` in the order they were written, by
    a pump task started with `start()` and cancelled with `stop()`.
    """

    def __init__(
        self,
        pause_source,
        resume_source,
        write_to_sink=None,
        drain_sink=None,
        threshold=0,
    ):
        """
        Args:
            pause_source: called with no arguments to pause the source.
            resume_source: called with no arguments to resume the source.
            write_to_sink: called with each packet; pumping is disabled while None.
            drain_sink: optional async callable, awaited after each packet is written.
            threshold: queued byte count above which the source is paused.
        """
        self.pause_source = pause_source
        self.resume_source = resume_source
        self.write_to_sink = write_to_sink
        self.drain_sink = drain_sink
        self.threshold = threshold
        self.queue = collections.deque()  # Queue of packets, oldest on the left
        self.queued_bytes = 0  # Number of bytes in the queue
        self.ready_to_pump = asyncio.Event()
        self.paused = False  # Pipe paused by the user (pause()/resume())
        self.source_paused = False  # Whether pause_source() is in effect
        self.pump_task = None

    def start(self):
        """Start the pump task (if not already running)."""
        if self.pump_task is None:
            self.pump_task = asyncio.create_task(self.pump())

        self.check_pump()

    def stop(self):
        """Cancel the pump task, if any. Queued packets are kept."""
        if self.pump_task is not None:
            self.pump_task.cancel()
            self.pump_task = None

    def write(self, packet):
        """Queue `packet` for the sink, pausing the source if over the threshold."""
        self.queued_bytes += len(packet)
        self.queue.append(packet)

        # Pause the source if we're over the threshold
        if self.queued_bytes > self.threshold and not self.source_paused:
            logger.debug(f'pausing source (queued={self.queued_bytes})')
            self.pause_source()
            self.source_paused = True

        self.check_pump()

    def pause(self):
        """Stop pumping to the sink and pause the source."""
        if not self.paused:
            self.paused = True
            if not self.source_paused:
                self.pause_source()
                self.source_paused = True
            self.check_pump()

    def resume(self):
        """Resume pumping to the sink and resume the source."""
        if self.paused:
            self.paused = False
            if self.source_paused:
                self.resume_source()
                self.source_paused = False
            self.check_pump()

    def can_pump(self):
        """Return True when there is a packet, the pipe is not paused, and a sink."""
        return bool(self.queue) and not self.paused and self.write_to_sink is not None

    def check_pump(self):
        """Set or clear `ready_to_pump` according to `can_pump()`."""
        if self.can_pump():
            self.ready_to_pump.set()
        else:
            self.ready_to_pump.clear()

    async def pump(self):
        """Forward queued packets to the sink, oldest first, until cancelled."""
        while True:
            # Wait until we can try to pump packets
            await self.ready_to_pump.wait()

            # Try to pump a packet
            if self.can_pump():
                # FIFO: write() appends on the right, so consume from the left
                packet = self.queue.popleft()
                self.write_to_sink(packet)
                self.queued_bytes -= len(packet)

                # Drain the sink if we can
                if self.drain_sink:
                    await self.drain_sink()

                # Check if we can accept more. pause() may have been called while
                # draining; in that case leave the source paused for resume().
                if (
                    self.queued_bytes <= self.threshold
                    and self.source_paused
                    and not self.paused
                ):
                    logger.debug(f'resuming source (queued={self.queued_bytes})')
                    self.source_paused = False
                    self.resume_source()

            self.check_pump()
305