• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3#   Copyright 2019 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17from abc import ABC, abstractmethod
18from concurrent.futures import ThreadPoolExecutor
19from datetime import datetime, timedelta
20import logging
21from queue import SimpleQueue, Empty
22
23from mobly import asserts
24
25from google.protobuf import text_format
26
27from grpc import RpcError
28
29from blueberry.tests.gd.cert.closable import Closable
30
31
class IEventStream(ABC):
    """
    Minimal interface for a source of events that can be asserted on.

    The module-level NOT_FOR_YOU_* helpers only rely on get_event_queue() and
    call .get(timeout=...) on the object it returns.
    """

    @abstractmethod
    def get_event_queue(self):
        """Return the queue that received events are put on."""
37
38
class FilteringEventStream(IEventStream):
    """
    An event stream fed by another stream, keeping only transformed events.

    Every event delivered by the underlying stream is passed through filter_fn.
    Events for which filter_fn returns None are dropped; any other result is put
    on this object's own queue.
    """

    def __init__(self, stream, filter_fn):
        """
        :param stream: source stream; must provide
                       register_callback(callback, matcher_fn) and
                       unregister_callback(callback, matcher_fn), as
                       EventStream does
        :param filter_fn: maps an event to the value to enqueue, or to None to
                          drop it; if falsy, events are enqueued unchanged
        """
        self.filter_fn = filter_fn if filter_fn else lambda x: x
        self.event_queue = SimpleQueue()
        self.stream = stream
        # Keep a reference to the matcher: unregister_callback() removes the
        # exact (callback, matcher_fn) pair that was registered, so an
        # anonymous lambda could never be unregistered.
        self._matcher_fn = lambda packet: self.filter_fn(packet) is not None

        self.stream.register_callback(self.__event_callback, self._matcher_fn)

    def __event_callback(self, event):
        # Only reached when the matcher accepted the event; filter_fn is
        # evaluated again here to obtain the value to enqueue.
        self.event_queue.put(self.filter_fn(event))

    def get_event_queue(self):
        """Return the queue holding the filtered (transformed) events."""
        return self.event_queue

    def unregister(self):
        """Stop receiving events from the underlying stream."""
        # The underlying stream exposes unregister_callback(), not
        # unregister(); pass the same pair used at registration time.
        self.stream.unregister_callback(self.__event_callback, self._matcher_fn)
56
57
def pretty_print(proto_event):
    """
    Render a protobuf message on one line, prefixed with its type name.

    :param proto_event: a protobuf message
    :return: "<TypeName> <message in one-line text format>"
    """
    type_name = type(proto_event).__name__
    body = text_format.MessageToString(proto_event, as_one_line=True)
    return f"{type_name} {body}"
60
61
# Default wait bound, in seconds, for the assert helpers in this module (used
# through timedelta(seconds=...) defaults) and for EventStream.close().
DEFAULT_TIMEOUT_SECONDS = 30
63
64
class EventStream(IEventStream, Closable):
    """
    A class that streams events from a gRPC stream, which you can assert on.

    A background task, submitted to a ThreadPoolExecutor owned by this object,
    iterates server_stream_call. Each event is put on event_queue and then
    passed to every registered callback whose matcher accepts it.

    Don't use these asserts directly, use the ones from cert.truth.
    """

    def __init__(self, server_stream_call):
        """
        :param server_stream_call: the gRPC server-streaming call to consume.
               It is iterated for events and must also provide cancel() and
               cancelled(), which close() and the event loop call.
        :raise ValueError if server_stream_call is None
        """
        if server_stream_call is None:
            raise ValueError("server_stream_call cannot be None")

        self.server_stream_call = server_stream_call
        self.event_queue = SimpleQueue()
        # (callback, matcher_fn) tuples. Written by register_callback() /
        # unregister_callback() on the caller's thread and read by the event
        # loop on the executor thread.
        self.handlers = []
        self.executor = ThreadPoolExecutor()
        # Start consuming the stream right away; close() waits on this future.
        self.future = self.executor.submit(EventStream.__event_loop, self)

    def get_event_queue(self):
        """Return the queue that every received event is put on."""
        return self.event_queue

    def close(self):
        """
        Stop the gRPC lambda so that event_callback will not be invoked after
        the method returns.

        This object will be useless after this call as there is no way to
        restart the gRPC callback. You would have to create a new EventStream

        :raise None on success, or the same exception as __event_loop(), or
               concurrent.futures.TimeoutError if underlying stream failed to
               terminate within DEFAULT_TIMEOUT_SECONDS
        """
        # Try to cancel the execution, don't care the result, non-blocking
        self.server_stream_call.cancel()
        try:
            # cancelling gRPC stream should cause __event_loop() to quit
            # same exception will be raised by future.result() or
            # concurrent.futures.TimeoutError will be raised after timeout
            self.future.result(timeout=DEFAULT_TIMEOUT_SECONDS)
        finally:
            # Make sure we force shutdown the executor regardless of the result
            self.executor.shutdown(wait=False)

    def register_callback(self, callback, matcher_fn=None):
        """
        Register a callback to handle events. Event will be handled by callback
        if matcher_fn(event) returns True

        callback and matcher are registered as a tuple. Hence the same callback
        with different matcher are considered two different handler units. Same
        matcher, but different callback are also considered different handling
        unit

        Callback will be invoked on a ThreadPoolExecutor owned by this
        EventStream

        :param callback: Will be called as callback(event)
        :param matcher_fn: A boolean function that returns True or False when
                           calling matcher_fn(event), if None, all event will
                           be matched
        :raise ValueError if callback is None
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.append((callback, matcher_fn))

    def unregister_callback(self, callback, matcher_fn=None):
        """
        Unregister callback and matcher_fn from the event stream. Both objects
        must match exactly the ones when calling register_callback()

        :param callback: callback used in register_callback()
        :param matcher_fn: matcher_fn used in register_callback()
        :raises ValueError when callback is None or when the
                (callback, matcher_fn) tuple is not found
        """
        if callback is None:
            raise ValueError("callback must not be None")
        self.handlers.remove((callback, matcher_fn))

    def __event_loop(self):
        """
        Main loop for consuming the gRPC stream events.
        Blocks until the stream ends or the call is cancelled.

        :raise RpcError if the stream fails for any reason other than
               cancellation; exceptions raised by callbacks or matchers also
               propagate and end the loop
        """
        try:
            for event in self.server_stream_call:
                self.event_queue.put(event)
                # Dispatch over a snapshot: handlers can be (un)registered from
                # another thread, or by a callback, while we iterate, and
                # mutating a list during iteration can silently skip entries.
                for callback, matcher_fn in tuple(self.handlers):
                    if not matcher_fn or matcher_fn(event):
                        callback(event)
        except RpcError:
            # Underlying gRPC stream should run indefinitely until cancelled
            # Hence any other reason besides CANCELLED is raised as an error
            if self.server_stream_call.cancelled():
                logging.debug("Cancelled")
            else:
                raise

    def assert_event_occurs(self, match_fn, at_least_times=1, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at least |at_least_times| instances of events happen where
        match_fn(event) returns True within timeout period

        :param match_fn: returns True/False on match_fn(event)
        :param timeout: a timedelta object
        :param at_least_times: how many times at least a matching event should
                               happen
        :return: None
        """
        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)

    def assert_event_occurs_at_most(self, match_fn, at_most_times, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
        """
        Assert at most |at_most_times| instances of events happen where
        match_fn(event) returns True within timeout period

        Waits until the timeout elapses or at_most_times + 1 matching events
        have been seen, whichever comes first. On failure the stream is closed
        before the test failure is raised.

        :param match_fn: returns True/False on match_fn(event)
        :param at_most_times: how many times at most a matching event should
                              happen
        :param timeout: a timedelta object
        :return: None
        """
        logging.debug("assert_event_occurs_at_most")
        event_list = []
        end_time = datetime.now() + timeout
        while len(event_list) <= at_most_times and datetime.now() < end_time:
            remaining = static_remaining_time_delta(end_time)
            logging.debug("Waiting for event iteration (%fs remaining)", remaining.total_seconds())
            try:
                current_event = self.event_queue.get(timeout=remaining.total_seconds())
            except Empty:
                continue
            if match_fn(current_event):
                event_list.append(current_event)
        logging.debug("Done waiting, got %d events", len(event_list))
        assert_true(
            self,
            len(event_list) <= at_most_times,
            msg=("Expected at most %d events, but got %d" % (at_most_times, len(event_list))))
198                continue
199        logging.debug("Done waiting, got %d events" % len(event_list))
200        assert_true(
201            self,
202            len(event_list) <= at_most_times,
203            msg=("Expected at most %d events, but got %d" % (at_most_times, len(event_list))))
204
205
206def static_remaining_time_delta(end_time):
207    remaining = end_time - datetime.now()
208    if remaining < timedelta(milliseconds=0):
209        remaining = timedelta(milliseconds=0)
210    return remaining
211
212
213def assert_true(istream, expr, msg, extras=None):
214    if not expr:
215        istream.close()
216        asserts.fail(msg, extras)
217
218
219def NOT_FOR_YOU_assert_event_occurs(istream,
220                                    match_fn,
221                                    at_least_times=1,
222                                    timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
223    logging.debug("assert_event_occurs %d %fs" % (at_least_times, timeout.total_seconds()))
224    event_list = []
225    end_time = datetime.now() + timeout
226    while len(event_list) < at_least_times and datetime.now() < end_time:
227        remaining = static_remaining_time_delta(end_time)
228        logging.debug("Waiting for event (%fs remaining)" % (remaining.total_seconds()))
229        try:
230            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
231            logging.debug("current_event: %s", current_event)
232            if match_fn(current_event):
233                event_list.append(current_event)
234        except Empty:
235            continue
236    logging.debug("Done waiting for event, received %d", len(event_list))
237
238    assert_true(
239        istream,
240        len(event_list) >= at_least_times,
241        msg=("Expected at least %d events, but got %d" % (at_least_times, len(event_list))))
242
243
244def NOT_FOR_YOU_assert_all_events_occur(istream,
245                                        match_fns,
246                                        order_matters,
247                                        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
248    logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
249    pending_matches = list(match_fns)
250    matched_order = []
251    end_time = datetime.now() + timeout
252    while len(pending_matches) > 0 and datetime.now() < end_time:
253        remaining = static_remaining_time_delta(end_time)
254        logging.debug("Waiting for event (%fs remaining)" % (remaining.total_seconds()))
255        try:
256            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
257            for match_fn in pending_matches:
258                if match_fn(current_event):
259                    pending_matches.remove(match_fn)
260                    matched_order.append(match_fn)
261        except Empty:
262            continue
263    logging.debug("Done waiting for event")
264    assert_true(
265        istream,
266        len(matched_order) == len(match_fns),
267        msg=("Expected at least %d events, but got %d" % (len(match_fns), len(matched_order))))
268    if order_matters:
269        correct_order = True
270        i = 0
271        while i < len(match_fns):
272            if match_fns[i] is not matched_order[i]:
273                correct_order = False
274                break
275            i += 1
276        assert_true(istream, correct_order, "Events not received in correct order %s %s" % (match_fns, matched_order))
277
278
279def NOT_FOR_YOU_assert_none_matching(istream, match_fn, timeout):
280    logging.debug("assert_none_matching %fs" % (timeout.total_seconds()))
281    event = None
282    end_time = datetime.now() + timeout
283    while event is None and datetime.now() < end_time:
284        remaining = static_remaining_time_delta(end_time)
285        logging.debug("Waiting for event (%fs remaining)" % (remaining.total_seconds()))
286        try:
287            current_event = istream.get_event_queue().get(timeout=remaining.total_seconds())
288            if match_fn(current_event):
289                event = current_event
290        except Empty:
291            continue
292    logging.debug("Done waiting for an event")
293    if event is None:
294        return  # Avoid an assert in MessageToString(None, ...)
295    assert_true(istream, event is None, msg='Expected None matching, but got {}'.format(pretty_print(event)))
296
297
298def NOT_FOR_YOU_assert_none(istream, timeout):
299    logging.debug("assert_none %fs" % (timeout.total_seconds()))
300    try:
301        event = istream.get_event_queue().get(timeout=timeout.total_seconds())
302        assert_true(istream, event is None, msg='Expected None, but got {}'.format(pretty_print(event)))
303    except Empty:
304        return
305