• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""
15
16import unittest
17from unittest import mock
18
19from mobly.snippet import callback_event
20from mobly.snippet import callback_handler_base
21from mobly.snippet import errors
22
# Callback ID given to the fake handler; the raw event below reuses it so the
# two fixtures cannot drift apart.
MOCK_CALLBACK_ID = '2-1'
# Raw event dict used as the return value of the mocked event RPCs and as the
# input to `callback_event.from_dict` when building expected events.
MOCK_RAW_EVENT = {
    'callbackId': MOCK_CALLBACK_ID,
    'name': 'AsyncTaskResult',
    'time': 20460228696,
    'data': {
        'exampleData': "Here's a simple event.",
        'successful': True,
        'secretNumber': 12,
    },
}
34
35
class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
  """Fake callback handler for unit tests.

  The RPC hooks `callEventWaitAndGetRpc` and `callEventGetAllRpc` are
  forwarded to `self.mock_rpc_func`, a `mock.Mock`. Tests can therefore stub
  what each RPC returns and assert on the calls the base class makes.
  """

  def __init__(
      self,
      callback_id=None,
      event_client=None,
      ret_value=None,
      method_name=None,
      device=None,
      rpc_max_timeout_sec=120,
      default_timeout_sec=120,
  ):
    """Initializes a fake callback handler object used for unit tests.

    Every argument is forwarded unchanged, in order, to
    `CallbackHandlerBase.__init__`. The collaborators default to None so a
    test only has to pass the values it exercises.
    """
    super().__init__(
        callback_id,
        event_client,
        ret_value,
        method_name,
        device,
        rpc_max_timeout_sec,
        default_timeout_sec,
    )
    # Receives every RPC forwarded by the overrides below; tests set or
    # replace its attributes to control the RPC results.
    self.mock_rpc_func = mock.Mock()

  def callEventWaitAndGetRpc(self, *args, **kwargs):
    """See base class.

    Delegates to `self.mock_rpc_func.callEventWaitAndGetRpc`.
    """
    return self.mock_rpc_func.callEventWaitAndGetRpc(*args, **kwargs)

  def callEventGetAllRpc(self, *args, **kwargs):
    """See base class.

    Delegates to `self.mock_rpc_func.callEventGetAllRpc`.
    """
    return self.mock_rpc_func.callEventGetAllRpc(*args, **kwargs)
68
69
class CallbackHandlerBaseTest(unittest.TestCase):
  """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

  def assert_event_correct(self, actual_event, expected_raw_event_dict):
    """Asserts that an event matches the one built from a raw event dict.

    Args:
      actual_event: the event object produced by the handler under test.
      expected_raw_event_dict: the raw dict the event is expected to come
        from; it is converted with `callback_event.from_dict`.
    """
    expected_event = callback_event.from_dict(expected_raw_event_dict)
    # Events are compared through their string forms.
    self.assertEqual(str(actual_event), str(expected_event))

  def test_default_timeout_too_large(self):
    """A default timeout above the per-RPC max timeout is rejected."""
    err_msg = (
        'The max timeout of a single RPC must be no smaller than '
        'the default timeout of the callback handler. '
        'Got rpc_max_timeout_sec=10, default_timeout_sec=20.'
    )
    # Match the message literally; as a regex every '.' would be a wildcard.
    with self.assertRaises(ValueError) as ctx:
      _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)
    self.assertIn(err_msg, str(ctx.exception))

  def test_timeout_property(self):
    """Both timeouts are readable but cannot be reassigned."""
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=20, default_timeout_sec=10
    )
    self.assertEqual(handler.rpc_max_timeout_sec, 20)
    self.assertEqual(handler.default_timeout_sec, 10)
    with self.assertRaises(AttributeError):
      handler.rpc_max_timeout_sec = 5

    with self.assertRaises(AttributeError):
      handler.default_timeout_sec = 5

  def test_callback_id_property(self):
    """The callback ID is readable but cannot be reassigned."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
    with self.assertRaises(AttributeError):
      handler.callback_id = 'ha'

  def test_event_dict_to_snippet_event(self):
    """waitAndGet forwards its arguments to the RPC and converts the result."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    event = handler.waitAndGet('ha', timeout=10)
    self.assert_event_correct(event, MOCK_RAW_EVENT)
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha', 10
    )

  def test_wait_and_get_timeout_default(self):
    """waitAndGet uses default_timeout_sec when no timeout is given."""
    handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )
    _ = handler.waitAndGet('ha')
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        mock.ANY, mock.ANY, 5
    )

  def test_wait_and_get_timeout_exceed_threshold(self):
    """waitAndGet rejects a timeout larger than rpc_max_timeout_sec."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        f'Specified timeout {big_timeout_sec} is longer than max timeout '
        f'{rpc_max_timeout_sec}.'
    )
    with self.assertRaises(errors.CallbackHandlerBaseError) as ctx:
      handler.waitAndGet('ha', big_timeout_sec)
    self.assertIn(expected_msg, str(ctx.exception))

  def test_wait_for_event(self):
    """waitForEvent returns an event that satisfies the predicate."""
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_wait_for_event_negative(self):
    """waitForEvent times out when no event satisfies the predicate."""
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
        ' satisfies the predicate "some_condition".'
    )

    def some_condition(_):
      return False

    with self.assertRaises(errors.CallbackHandlerTimeoutError) as ctx:
      handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
    self.assertIn(expected_msg, str(ctx.exception))

  def test_wait_for_event_max_timeout(self):
    """waitForEvent should not raise the timeout exceed threshold error."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    # This line should not raise.
    event = handler.waitForEvent(
        'AsyncTaskResult', some_condition, timeout=big_timeout_sec
    )
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_get_all(self):
    """getAll converts every raw event returned by the RPC."""
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT]
    )

    all_events = list(handler.getAll('ha'))
    # Check the count first, so an empty result cannot pass the loop below
    # vacuously.
    self.assertEqual(len(all_events), 2)
    for event in all_events:
      self.assert_event_correct(event, MOCK_RAW_EVENT)

    handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha'
    )
209
210
if __name__ == '__main__':
  # Lets this test module be run directly as a script.
  unittest.main()
213