#!/usr/bin/env python3
#
#   Copyright 2019 - The Android Open Source Project
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
import logging
from queue import SimpleQueue, Empty

from mobly import asserts

from google.protobuf import text_format

from grpc import RpcError

from blueberry.tests.gd.cert.closable import Closable


class IEventStream(ABC):
    """
    Interface for objects that expose received events through a queue.
    """

    @abstractmethod
    def get_event_queue(self):
        """
        :return: the queue object that received events are put on
        """
        pass


class FilteringEventStream(IEventStream):
    """
    An event stream fed by another stream, keeping only the events for which
    filter_fn(event) returns something other than None. The value returned
    by filter_fn (not the raw event) is what gets queued.
    """

    def __init__(self, stream, filter_fn):
        """
        :param stream: parent stream; must provide
                       register_callback(callback, matcher_fn)
        :param filter_fn: called as filter_fn(event); a None result drops the
                          event. If filter_fn is falsy, events pass through
                          unchanged
        """
        self.filter_fn = filter_fn if filter_fn else lambda x: x
        self.event_queue = SimpleQueue()
        self.stream = stream
        # Handlers are stored as (callback, matcher_fn) tuples, so the very
        # same matcher object is required to unregister later. Keep it.
        self._matcher_fn = lambda packet: self.filter_fn(packet) is not None

        self.stream.register_callback(self.__event_callback, self._matcher_fn)

    def __event_callback(self, event):
        self.event_queue.put(self.filter_fn(event))

    def get_event_queue(self):
        return self.event_queue

    def unregister(self):
        """
        Stop receiving events from the parent stream.

        :raises ValueError if the parent no longer holds this registration
        """
        unregister_callback = getattr(self.stream, "unregister_callback", None)
        if unregister_callback is not None:
            unregister_callback(self.__event_callback, self._matcher_fn)
        else:
            # Parent streams that only offer unregister(callback)
            self.stream.unregister(self.__event_callback)


def pretty_print(proto_event):
    """
    Format a protobuf message as '<TypeName> <single-line text format>'.
    """
    return '{} {}'.format(type(proto_event).__name__, text_format.MessageToString(proto_event, as_one_line=True))


DEFAULT_TIMEOUT_SECONDS = 30


class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    Don't use these asserts directly, use the ones from cert.truth.
    """

    def __init__(self, server_stream_call):
        """
        :param server_stream_call: iterable gRPC stream call; it is consumed
                                   on a worker thread started here
        :raises ValueError when server_stream_call is None
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call cannot be None")

        self.server_stream_call = server_stream_call
        self.event_queue = SimpleQueue()
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        self.future = self.executor.submit(EventStream.__event_loop, self)

    def get_event_queue(self):
        return self.event_queue

    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after
        the method returns.

        This object will be useless after this call as there is no way to
        restart the gRPC callback. You would have to create a new EventStream

        :raise None on success, or the same exception as __event_loop(), or
               concurrent.futures.TimeoutError if underlying stream failed to
               terminate within DEFAULT_TIMEOUT_SECONDS
        """
        # Try to cancel the execution, don't care the result, non-blocking
        self.server_stream_call.cancel()
        try:
            # cancelling gRPC stream should cause __event_loop() to quit
            # same exception will be raised by future.result() or
            # concurrent.futures.TimeoutError will be raised after timeout
            self.future.result(timeout=DEFAULT_TIMEOUT_SECONDS)
        finally:
            # Make sure we force shutdown the executor regardless of the result
            self.executor.shutdown(wait=False)

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        :raises ValueError when callback is None
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def __event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until computation is cancelled
        :raise grpc.Error on failure
        """
        try:
            for event in self.server_stream_call:
                self.event_queue.put(event)
                # Dispatch over a snapshot: handlers may be registered or
                # unregistered (possibly by a callback itself) while we are
                # iterating, and mutating the live list would skip entries.
                for (callback, matcher_fn) in list(self.handlers):
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
        except RpcError:
            # Underlying gRPC stream should run indefinitely until cancelled
            # Hence any other reason besides CANCELLED is raised as an error
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
            else:
                raise

    def assert_event_occurs(self, match_fn, at_least_times=1, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at least |at_least_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param timeout: a timedelta object
        :param at_least_times: how many times at least a matching event should
                               happen
        :return:
        """
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(self, match_fn, at_most_times, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at most |at_most_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param at_most_times: how many times at most a matching event should
                              happen
        :param timeout: a timedelta object
        :return:
        """
        logging.debug("assert_event_occurs_at_most")
        event_list = []
        end_time = datetime.now() + timeout
        # Stop early once the limit is exceeded; otherwise wait out the timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)", remaining.total_seconds())
            try:
                current_event = self.event_queue.get(timeout=remaining.total_seconds())
                if match_fn(current_event):
                    event_list.append(current_event)
            except Empty:
                continue
        logging.debug("Done waiting, got %d events", len(event_list))
        assert_true(
            self,
            len(event_list) <= at_most_times,
            msg=("Expected at most %d events, but got %d" % (at_most_times, len(event_list))))


def static_remaining_time_delta(end_time):
    """
    :param end_time: a datetime deadline
    :return: timedelta left until end_time, clamped to zero once it has passed
    """
    remaining = end_time - datetime.now()
    if remaining < timedelta(milliseconds=0):
        remaining = timedelta(milliseconds=0)
    return remaining


def assert_true(istream, expr, msg, extras=None):
    """
    When expr is falsy, close istream and fail via mobly asserts.fail(msg, extras).
    Does nothing when expr is truthy.
    """
    if not expr:
        istream.close()
        asserts.fail(msg, extras)


def NOT_FOR_YOU_assert_event_occurs(istream,
                                    match_fn,
                                    at_least_times=1,
                                    timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that at least |at_least_times| events satisfying match_fn are read
    from istream's event queue before timeout elapses. Non-matching events
    read while waiting are discarded.

    :param istream: object providing get_event_queue() and close()
    :param match_fn: returns True/False on match_fn(event)
    :param at_least_times: minimum number of matching events required
    :param timeout: a timedelta object
    """
    logging.debug("assert_event_occurs %d %fs", at_least_times, timeout.total_seconds())
    event_list = []
    end_time = datetime.now() + timeout
    while len(event_list) < at_least_times and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
            logging.debug("current_event: %s", current_event)
            if match_fn(current_event):
                event_list.append(current_event)
        except Empty:
            continue
    logging.debug("Done waiting for event, received %d", len(event_list))

    assert_true(
        istream,
        len(event_list) >= at_least_times,
        msg=("Expected at least %d events, but got %d" % (at_least_times, len(event_list))))


def NOT_FOR_YOU_assert_all_events_occur(istream,
                                        match_fns,
                                        order_matters,
                                        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
    """
    Assert that every matcher in match_fns is satisfied by some event read
    from istream's event queue before timeout elapses. Each event satisfies
    at most one pending matcher (the first one, in match_fns order, that
    accepts it).

    :param istream: object providing get_event_queue() and close()
    :param match_fns: iterable of functions returning True/False on fn(event)
    :param order_matters: when truthy, matchers must have been satisfied in
                          the same order as they appear in match_fns
    :param timeout: a timedelta object
    """
    logging.debug("assert_all_events_occur %fs", timeout.total_seconds())
    expected_order = list(match_fns)
    pending_matches = list(expected_order)
    matched_order = []
    end_time = datetime.now() + timeout
    while len(pending_matches) > 0 and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
        except Empty:
            continue
        for match_fn in pending_matches:
            if match_fn(current_event):
                pending_matches.remove(match_fn)
                matched_order.append(match_fn)
                # The list was just mutated; continuing to iterate it would
                # silently skip the next matcher. One event, one matcher.
                break
    logging.debug("Done waiting for event")
    assert_true(
        istream,
        len(matched_order) == len(expected_order),
        msg=("Expected at least %d events, but got %d" % (len(expected_order), len(matched_order))))
    if order_matters:
        # Lengths are equal here, otherwise the assert above already failed
        correct_order = all(expected is actual for expected, actual in zip(expected_order, matched_order))
        assert_true(istream, correct_order, "Events not received in correct order %s %s" % (match_fns, matched_order))


def NOT_FOR_YOU_assert_none_matching(istream, match_fn, timeout):
    """
    Assert that no event satisfying match_fn is read from istream's event
    queue before timeout elapses. Returns as soon as the deadline passes
    without a match; fails on the first matching event.

    :param istream: object providing get_event_queue() and close()
    :param match_fn: returns True/False on match_fn(event)
    :param timeout: a timedelta object
    """
    logging.debug("assert_none_matching %fs", timeout.total_seconds())
    event = None
    end_time = datetime.now() + timeout
    while event is None and datetime.now() < end_time:
        remaining = static_remaining_time_delta(end_time)
        logging.debug("Waiting for event (%fs remaining)", remaining.total_seconds())
        try:
            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
            if match_fn(current_event):
                event = current_event
        except Empty:
            continue
    logging.debug("Done waiting for an event")
    if event is None:
        return  # Avoid an assert in MessageToString(None, ...)
    assert_true(istream, event is None, msg='Expected None matching, but got {}'.format(pretty_print(event)))


def NOT_FOR_YOU_assert_none(istream, timeout):
    """
    Assert that no event at all is read from istream's event queue within
    timeout.

    :param istream: object providing get_event_queue() and close()
    :param timeout: a timedelta object
    """
    logging.debug("assert_none %fs", timeout.total_seconds())
    try:
        event = istream.get_event_queue().get(timeout=timeout.total_seconds())
    except Empty:
        return
    if event is None:
        return  # Avoid an assert in MessageToString(None, ...)
    assert_true(istream, event is None, msg='Expected None, but got {}'.format(pretty_print(event)))