# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import traceback
import collections
import sys
from typing import Awaitable, Set, TypeVar
from functools import wraps
from pyee import EventEmitter

from .colors import color

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
def setup_event_forwarding(emitter, forwarder, event_name):
    """
    Re-emit every `event_name` event emitted by `emitter` on `forwarder`.

    All positional and keyword arguments of the original event are passed
    through unchanged.
    """

    def emit(*args, **kwargs):
        forwarder.emit(event_name, *args, **kwargs)

    emitter.on(event_name, emit)


# -----------------------------------------------------------------------------
def composite_listener(cls):
    """
    Decorator that adds a `register` and `deregister` method to a class, which
    registers/deregisters all methods named `on_<event_name>` as a listener for
    the <event_name> event with an emitter.

    The two methods are attached as `_bumble_register_composite` and
    `_bumble_deregister_composite`, both with the signature `(self, emitter)`.
    The `on_` method names are collected with `dir(cls)`, so they include
    methods inherited by the decorated class. The handler that gets registered
    is `getattr(self, name)`, i.e. the bound method of the instance, so
    overrides in subclasses are the ones used.
    """
    # pylint: disable=protected-access

    def register(self, emitter):
        for method_name in dir(cls):
            if method_name.startswith('on_'):
                emitter.on(method_name[3:], getattr(self, method_name))

    def deregister(self, emitter):
        for method_name in dir(cls):
            if method_name.startswith('on_'):
                emitter.remove_listener(method_name[3:], getattr(self, method_name))

    cls._bumble_register_composite = register
    cls._bumble_deregister_composite = deregister
    return cls


# -----------------------------------------------------------------------------
_T = TypeVar('_T')


class AbortableEventEmitter(EventEmitter):
    def abort_on(self, event: str, awaitable: Awaitable[_T]) -> Awaitable[_T]:
        """
        Set a coroutine or future to abort when an event occurs.

        `awaitable` is wrapped with `asyncio.ensure_future`. If `event` is
        emitted before the resulting future is done, the future is aborted:
        tasks are cancelled, and plain futures get an `asyncio.CancelledError`
        set as their exception. The temporary event listener is removed once
        the future is done.

        Returns the future produced by `asyncio.ensure_future`.
        """
        future = asyncio.ensure_future(awaitable)
        if future.done():
            return future

        def on_event(*_):
            if future.done():
                return
            msg = f'abort: {event} event occurred.'
            if isinstance(future, asyncio.Task):
                # python < 3.9 does not support passing a message on `Task.cancel`
                if sys.version_info < (3, 9, 0):
                    future.cancel()
                else:
                    future.cancel(msg)
            else:
                future.set_exception(asyncio.CancelledError(msg))

        def on_done(_):
            # Detach the abort handler whichever way the future finished.
            self.remove_listener(event, on_event)

        self.on(event, on_event)
        future.add_done_callback(on_done)
        return future


# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
    """
    Event emitter with a single, replaceable `listener` object.

    Assigning to `listener` first deregisters the previous listener's
    `on_<event>` handlers, then registers those of the new listener. The
    handlers are the ones installed by `@composite_listener`.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        return self._listener

    @listener.setter
    def listener(self, listener):
        # pylint: disable=protected-access
        if self._listener:
            # Call the deregistration methods for each class in the MRO that
            # itself defines them. Checking `cls.__dict__` rather than
            # `hasattr` avoids calling an inherited hook a second time, which
            # would try to remove the same handlers twice. The *previous*
            # listener is passed, since its handlers are the registered ones.
            for cls in self._listener.__class__.mro():
                if '_bumble_deregister_composite' in cls.__dict__:
                    cls._bumble_deregister_composite(self._listener, self)
        self._listener = listener
        if listener:
            # Call the registration methods for each class in the MRO that
            # itself defines them.
            for cls in listener.__class__.mro():
                if '_bumble_register_composite' in cls.__dict__:
                    cls._bumble_register_composite(listener, self)


# -----------------------------------------------------------------------------
class AsyncRunner:
    """Helpers for running coroutines from synchronous code."""

    class WorkQueue:
        """
        Queue of coroutines that `run()` awaits one at a time, in FIFO order.

        When `create_task` is True, a task running `run()` is created on the
        first `enqueue()`. Otherwise no task is created here, and `run()`
        presumably has to be driven by the caller.
        """

        def __init__(self, create_task=True):
            self.queue = None
            self.task = None
            self.create_task = create_task

        def enqueue(self, coroutine):
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Lazy-create the coroutine queue
            if self.queue is None:
                self.queue = asyncio.Queue()

            # Enqueue the work
            self.queue.put_nowait(coroutine)

        async def run(self):
            # The queue is normally created by enqueue(). Create it here too so
            # that run() can be started before anything has been enqueued.
            if self.queue is None:
                self.queue = asyncio.Queue()

            while True:
                item = await self.queue.get()
                try:
                    await item
                except Exception as error:
                    logger.warning(
                        f'{color("!!! Exception in work queue:", "red")} {error}'
                    )

    # Shared default queue
    default_queue = WorkQueue()

    # Shared set of running tasks
    running_tasks: Set[Awaitable] = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and schedules it:
          - with `queue=None`, it runs in its own task, started via
            `AsyncRunner.spawn` so that a strong reference is kept until it
            finishes; any exception it raises is logged with its traceback;
          - otherwise, it is handed to `queue.enqueue()`.
        The wrapper itself returns None.
        """

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)
                if queue is None:
                    # Create a task to run the coroutine
                    async def run():
                        try:
                            await coroutine
                        except Exception:
                            logger.warning(
                                f'{color("!!! Exception in wrapper:", "red")} '
                                f'{traceback.format_exc()}'
                            )

                    # Use spawn() rather than a bare create_task(): asyncio
                    # only keeps weak references to tasks, so an unreferenced
                    # task could disappear before it completes.
                    AsyncRunner.spawn(run())
                else:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)

            return wrapper

        return decorator

    @staticmethod
    def spawn(coroutine):
        """
        Spawn a task to run a coroutine in a "fire and forget" mode.

        Using this method instead of just calling `asyncio.create_task(coroutine)`
        is necessary when you don't keep a reference to the task, because `asyncio`
        only keeps weak references to alive tasks.

        The task is kept in `AsyncRunner.running_tasks` until it is done. It is
        also returned, for callers that do want to keep or await it.
        """
        task = asyncio.create_task(coroutine)
        AsyncRunner.running_tasks.add(task)
        task.add_done_callback(AsyncRunner.running_tasks.discard)
        return task


# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
    """
    Asyncio pipe with flow control. When writing to the pipe, the source is
    paused (by calling a function passed in when the pipe is created) if the
    amount of queued data exceeds a specified threshold.

    Packets are handed to `write_to_sink` in the order they were written. After
    each packet, `drain_sink` (if any) is awaited. Once the queued byte count
    is back at or below `threshold`, the source is resumed, unless the pipe
    itself is paused.
    """

    def __init__(
        self,
        pause_source,
        resume_source,
        write_to_sink=None,
        drain_sink=None,
        threshold=0,
    ):
        self.pause_source = pause_source
        self.resume_source = resume_source
        self.write_to_sink = write_to_sink
        self.drain_sink = drain_sink
        self.threshold = threshold
        self.queue = collections.deque()  # Queue of packets (oldest on the left)
        self.queued_bytes = 0  # Number of bytes in the queue
        self.ready_to_pump = asyncio.Event()
        self.paused = False
        self.source_paused = False
        self.pump_task = None

    def start(self):
        """Create the pump task if it is not running, then re-evaluate pumping."""
        if self.pump_task is None:
            self.pump_task = asyncio.create_task(self.pump())

        self.check_pump()

    def stop(self):
        """Cancel the pump task, if any. Queued packets are left in the queue."""
        if self.pump_task is not None:
            self.pump_task.cancel()
            self.pump_task = None

    def write(self, packet):
        """Queue `packet`, pausing the source if over the threshold."""
        self.queued_bytes += len(packet)
        self.queue.append(packet)

        # Pause the source if we're over the threshold
        if self.queued_bytes > self.threshold and not self.source_paused:
            logger.debug(f'pausing source (queued={self.queued_bytes})')
            self.pause_source()
            self.source_paused = True

        self.check_pump()

    def pause(self):
        """Stop delivering to the sink and pause the source if not already paused."""
        if not self.paused:
            self.paused = True
            if not self.source_paused:
                self.pause_source()
                self.source_paused = True
            self.check_pump()

    def resume(self):
        """Resume delivering to the sink and resume the source if it is paused."""
        if self.paused:
            self.paused = False
            if self.source_paused:
                self.resume_source()
                self.source_paused = False
            self.check_pump()

    def can_pump(self):
        """Return True if there is a packet to send, we're not paused, and a sink exists."""
        return bool(self.queue) and not self.paused and self.write_to_sink is not None

    def check_pump(self):
        """Set or clear the `ready_to_pump` event according to `can_pump()`."""
        if self.can_pump():
            self.ready_to_pump.set()
        else:
            self.ready_to_pump.clear()

    async def pump(self):
        """Forever: wait until pumping is possible, then send one packet."""
        while True:
            # Wait until we can try to pump packets
            await self.ready_to_pump.wait()

            # Try to pump a packet
            if self.can_pump():
                # Take the oldest packet so the sink sees packets in FIFO order.
                packet = self.queue.popleft()
                self.write_to_sink(packet)
                self.queued_bytes -= len(packet)

                # Drain the sink if we can
                if self.drain_sink:
                    await self.drain_sink()

                # Check if we can accept more. If pause() was called while
                # draining, leave the source paused; resume() will resume it.
                if (
                    self.queued_bytes <= self.threshold
                    and self.source_paused
                    and not self.paused
                ):
                    logger.debug(f'resuming source (queued={self.queued_bytes})')
                    self.source_paused = False
                    self.resume_source()

            self.check_pump()