# Copyright 2022 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the base class to handle Mobly Snippet Lib's callback events."""
import abc
import time

from mobly.snippet import callback_event
from mobly.snippet import errors


class CallbackHandlerBase(abc.ABC):
  """Base class for handling Mobly Snippet Lib's callback events.

  All the events handled by a callback handler are originally triggered by one
  async RPC call. All the events are tagged with a callback_id specific to a
  call to an async RPC method defined on the server side.

  The raw message representing an event looks like:

  .. code-block:: python

    {
      'callbackId': <string, callbackId>,
      'name': <string, name of the event>,
      'time': <long, epoch time of when the event was created on the
        server side>,
      'data': <dict, extra data from the callback on the server side>
    }

  Each message is then used to create a CallbackEvent object on the client
  side.

  Attributes:
    ret_value: any, the direct return value of the async RPC call.
  """

  def __init__(
      self,
      callback_id,
      event_client,
      ret_value,
      method_name,
      device,
      rpc_max_timeout_sec,
      default_timeout_sec=120,
  ):
    """Initializes a callback handler base object.

    Args:
      callback_id: str, the callback ID which associates with a group of
        callback events.
      event_client: SnippetClientV2, the client object used to send RPC to the
        server and receive response.
      ret_value: any, the direct return value of the async RPC call.
      method_name: str, the name of the executed Async snippet function.
      device: DeviceController, the device object associated with this handler.
      rpc_max_timeout_sec: float, maximum time for sending a single RPC call.
      default_timeout_sec: float, the default timeout for this handler. It
        must be no longer than rpc_max_timeout_sec.

    Raises:
      ValueError: if rpc_max_timeout_sec is smaller than default_timeout_sec.
    """
    self._id = callback_id
    self.ret_value = ret_value
    self._device = device
    self._event_client = event_client
    self._method_name = method_name

    if rpc_max_timeout_sec < default_timeout_sec:
      raise ValueError(
          'The max timeout of a single RPC must be no smaller '
          'than the default timeout of the callback handler. '
          f'Got rpc_max_timeout_sec={rpc_max_timeout_sec}, '
          f'default_timeout_sec={default_timeout_sec}.'
      )
    self._rpc_max_timeout_sec = rpc_max_timeout_sec
    self._default_timeout_sec = default_timeout_sec

  @property
  def rpc_max_timeout_sec(self):
    """Maximum time for sending a single RPC call."""
    return self._rpc_max_timeout_sec

  @property
  def default_timeout_sec(self):
    """Default timeout used by this callback handler."""
    return self._default_timeout_sec

  @property
  def callback_id(self):
    """The callback ID which associates a group of callback events."""
    return self._id

  @abc.abstractmethod
  def callEventWaitAndGetRpc(self, callback_id, event_name, timeout_sec):
    """Calls snippet lib's RPC to wait for a callback event.

    Override this method to use this class with various snippet lib
    implementations.

    This function waits and gets a CallbackEvent with the specified identifier
    from the server. It will raise a timeout error if the expected event does
    not occur within the time limit.

    Args:
      callback_id: str, the callback identifier.
      event_name: str, the callback name.
      timeout_sec: float, the number of seconds to wait for the event. It is
        already checked that this argument is no longer than the max timeout
        of a single RPC.

    Returns:
      The event dictionary.

    Raises:
      errors.CallbackHandlerTimeoutError: Raised if the expected event does not
        occur within the time limit.
    """

  @abc.abstractmethod
  def callEventGetAllRpc(self, callback_id, event_name):
    """Calls snippet lib's RPC to get all existing snippet events.

    Override this method to use this class with various snippet lib
    implementations.

    This function gets all existing events in the server with the specified
    identifier without waiting.

    Args:
      callback_id: str, the callback identifier.
      event_name: str, the callback name.

    Returns:
      A list of event dictionaries.
    """

  def waitAndGet(self, event_name, timeout=None):
    """Waits and gets a CallbackEvent with the specified identifier.

    It will raise a timeout error if the expected event does not occur within
    the time limit. The wait is performed by a single RPC, so the timeout may
    not exceed rpc_max_timeout_sec.

    Args:
      event_name: str, the name of the event to get.
      timeout: float, the number of seconds to wait before giving up. If None,
        it will be set to self.default_timeout_sec.

    Returns:
      CallbackEvent, the oldest entry of the specified event.

    Raises:
      errors.CallbackHandlerBaseError: If the specified timeout is longer than
        the max timeout supported.
      errors.CallbackHandlerTimeoutError: The expected event does not occur
        within the time limit.
    """
    if timeout is None:
      timeout = self.default_timeout_sec

    if timeout > self.rpc_max_timeout_sec:
      raise errors.CallbackHandlerBaseError(
          self._device,
          f'Specified timeout {timeout} is longer than max timeout '
          f'{self.rpc_max_timeout_sec}.',
      )

    raw_event = self.callEventWaitAndGetRpc(self._id, event_name, timeout)
    return callback_event.from_dict(raw_event)

  def waitForEvent(self, event_name, predicate, timeout=None):
    """Waits for an event of the specific name that satisfies the predicate.

    This call will block until the expected event has been received or time
    out. The timeout may be longer than rpc_max_timeout_sec; in that case the
    wait is split into several consecutive RPCs, each no longer than
    rpc_max_timeout_sec.

    The predicate function defines the condition the event is expected to
    satisfy. It takes an event and returns True if the condition is
    satisfied, False otherwise.

    Note all events of the same name that are received but don't satisfy
    the predicate will be discarded and not be available for further
    consumption.

    Args:
      event_name: str, the name of the event to wait for.
      predicate: function, a function that takes an event (CallbackEvent) and
        returns a bool.
      timeout: float, the number of seconds to wait before giving up. If None,
        it will be set to self.default_timeout_sec.

    Returns:
      CallbackEvent, the event that satisfies the predicate if received.

    Raises:
      errors.CallbackHandlerTimeoutError: raised if no event that satisfies the
        predicate is received after timeout seconds.
    """
    if timeout is None:
      timeout = self.default_timeout_sec

    deadline = time.perf_counter() + timeout
    while time.perf_counter() <= deadline:
      single_rpc_timeout = deadline - time.perf_counter()
      if single_rpc_timeout < 0:
        break

      # A single RPC may not wait longer than rpc_max_timeout_sec, so a long
      # overall timeout is served by several consecutive, capped waits.
      single_rpc_timeout = min(single_rpc_timeout, self.rpc_max_timeout_sec)
      try:
        event = self.waitAndGet(event_name, single_rpc_timeout)
      except errors.CallbackHandlerTimeoutError:
        # Only this capped wait timed out; keep waiting until the overall
        # deadline (checked by the loop condition). A timeout error with a
        # more specific message is raised below once the deadline passes.
        continue
      if predicate(event):
        return event

    # Not every callable has __name__ (e.g. functools.partial objects); fall
    # back to repr so the timeout error is not masked by an AttributeError.
    predicate_name = getattr(predicate, '__name__', repr(predicate))
    raise errors.CallbackHandlerTimeoutError(
        self._device,
        f'Timed out after {timeout}s waiting for an "{event_name}" event that '
        f'satisfies the predicate "{predicate_name}".',
    )

  def getAll(self, event_name):
    """Gets all existing events in the server with the specified identifier.

    This is a non-blocking call.

    Args:
      event_name: str, the name of the event to get.

    Returns:
      A list of CallbackEvent, each representing an event from the Server side.
    """
    raw_events = self.callEventGetAllRpc(self._id, event_name)
    return [callback_event.from_dict(msg) for msg in raw_events]