# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import traceback
from functools import wraps
from colors import color
from pyee import EventEmitter


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
def setup_event_forwarding(emitter, forwarder, event_name):
    """
    Re-emit every `event_name` event received by `emitter` on `forwarder`.

    All positional and keyword arguments of the original event are passed
    through unchanged.

    Returns the handler registered on `emitter`, so the forwarding can later
    be undone with `emitter.remove_listener(event_name, handler)`.
    """

    def emit(*args, **kwargs):
        forwarder.emit(event_name, *args, **kwargs)

    emitter.on(event_name, emit)
    return emit


# -----------------------------------------------------------------------------
def composite_listener(cls):
    """
    Decorator that adds a `register` and `deregister` method to a class, which
    registers/deregisters all methods named `on_<event_name>` as a listener for
    the <event_name> event with an emitter.

    The two methods are attached to the decorated class as
    `_bumble_register_composite` and `_bumble_deregister_composite`;
    `CompositeEventEmitter` calls them when its `listener` property is set.
    """

    def register(self, emitter):
        # `dir(cls)` also lists `on_*` methods the decorated class inherits;
        # `getattr(self, ...)` resolves any override on `self`'s own class.
        for method_name in dir(cls):
            if method_name.startswith('on_'):
                emitter.on(method_name[3:], getattr(self, method_name))

    def deregister(self, emitter):
        # `getattr` returns a fresh bound method, but bound methods compare
        # (and hash) equal when they wrap the same function and instance, so
        # this matches the handler added by `register`.
        for method_name in dir(cls):
            if method_name.startswith('on_'):
                emitter.remove_listener(method_name[3:], getattr(self, method_name))

    cls._bumble_register_composite = register
    cls._bumble_deregister_composite = deregister
    return cls


# -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter):
    """
    EventEmitter with one optional "composite" listener object.

    Setting `listener` to an instance of a class decorated (directly or via a
    base class) with `@composite_listener` registers its `on_<event_name>`
    methods as handlers. Setting it to another object, or to None, first
    removes the handlers of the previous listener.
    """

    def __init__(self):
        super().__init__()
        self._listener = None

    @property
    def listener(self):
        return self._listener

    @listener.setter
    def listener(self, listener):
        previous = self._listener
        if previous is not None:
            # Call the deregistration method of each class in the MRO that was
            # itself decorated. Checking the class's own __dict__ (instead of
            # hasattr) avoids calling an inherited hook once per subclass,
            # which would try to remove the same handlers twice.
            for cls in previous.__class__.mro():
                if '_bumble_deregister_composite' in cls.__dict__:
                    cls._bumble_deregister_composite(previous, self)

        self._listener = listener

        if listener is not None:
            # Call the registration method of each decorated class in the MRO
            for cls in listener.__class__.mro():
                if '_bumble_register_composite' in cls.__dict__:
                    cls._bumble_register_composite(listener, self)


# -----------------------------------------------------------------------------
class AsyncRunner:
    """
    Helpers for scheduling coroutines from synchronous code.
    """

    class WorkQueue:
        """
        FIFO queue of coroutines that are awaited one at a time, in order.

        When `create_task` is True, the first `enqueue()` starts a task running
        `run()`; otherwise the owner is responsible for running `run()`.
        """

        def __init__(self, create_task=True):
            self.queue = None
            self.task = None
            self.create_task = create_task

        def enqueue(self, coroutine):
            # Create a task now if we need to and haven't done so already
            if self.create_task and self.task is None:
                self.task = asyncio.create_task(self.run())

            # Lazy-create the coroutine queue
            if self.queue is None:
                self.queue = asyncio.Queue()

            # Enqueue the work
            self.queue.put_nowait(coroutine)

        async def run(self):
            # With create_task=False the owner may start run() before anything
            # has been enqueued, so the queue may not exist yet.
            if self.queue is None:
                self.queue = asyncio.Queue()

            while True:
                item = await self.queue.get()
                try:
                    await item
                except Exception:
                    # Log and keep serving the queue; one failing item must not
                    # stop the items behind it.
                    logger.warning(
                        f'{color("!!! Exception in work queue:", "red")} {traceback.format_exc()}'
                    )

    # Shared default queue
    default_queue = WorkQueue()

    # Strong references to the fire-and-forget tasks created by `run_in_task`.
    # The event loop only keeps weak references to tasks, so an otherwise
    # unreferenced task could be garbage-collected before it finishes.
    _running_tasks = set()

    @staticmethod
    def run_in_task(queue=None):
        """
        Function decorator used to adapt an async function into a sync function.

        Calling the decorated function creates the coroutine and schedules it
        without waiting for it, then returns None:
        - if `queue` is None, the coroutine is run in a new asyncio task (this
          requires a running event loop); exceptions it raises are logged with
          a traceback instead of being propagated;
        - otherwise the coroutine is passed to `queue.enqueue()`, and error
          handling is up to the queue (`WorkQueue` logs them).
        """

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                coroutine = func(*args, **kwargs)

                if queue is not None:
                    # Queue the coroutine to be awaited by the work queue
                    queue.enqueue(coroutine)
                    return

                # Create a task to run the coroutine
                async def run():
                    try:
                        await coroutine
                    except Exception:
                        logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}')

                task = asyncio.create_task(run())
                # Keep the task alive until it completes, then drop it
                AsyncRunner._running_tasks.add(task)
                task.add_done_callback(AsyncRunner._running_tasks.discard)

            return wrapper

        return decorator