# Copyright 2022 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

import re
import unittest
from unittest import mock

from mobly.snippet import callback_event
from mobly.snippet import callback_handler_base
from mobly.snippet import errors

MOCK_CALLBACK_ID = '2-1'
MOCK_RAW_EVENT = {
    'callbackId': '2-1',
    'name': 'AsyncTaskResult',
    'time': 20460228696,
    'data': {
        'exampleData': "Here's a simple event.",
        'successful': True,
        'secretNumber': 12,
    },
}


class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
  """Fake callback handler class for unit tests.

  Overrides `callEventWaitAndGetRpc` and `callEventGetAllRpc` so that both
  delegate to `self.mock_rpc_func`. Tests stub return values on that mock and
  inspect the calls made to it.
  """

  def __init__(
      self,
      callback_id=None,
      event_client=None,
      ret_value=None,
      method_name=None,
      device=None,
      rpc_max_timeout_sec=120,
      default_timeout_sec=120,
  ):
    """Initializes a fake callback handler object used for unit tests."""
    super().__init__(
        callback_id,
        event_client,
        ret_value,
        method_name,
        device,
        rpc_max_timeout_sec,
        default_timeout_sec,
    )
    # Tests replace attributes on this mock to control what the RPCs return.
    self.mock_rpc_func = mock.Mock()

  def callEventWaitAndGetRpc(self, *args, **kwargs):
    """See base class."""
    return self.mock_rpc_func.callEventWaitAndGetRpc(*args, **kwargs)

  def callEventGetAllRpc(self, *args, **kwargs):
    """See base class."""
    return self.mock_rpc_func.callEventGetAllRpc(*args, **kwargs)


class CallbackHandlerBaseTest(unittest.TestCase):
  """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

  def assert_event_correct(self, actual_event, expected_raw_event_dict):
    """Asserts that `actual_event` matches the event built from a raw dict.

    Args:
      actual_event: the event object returned by the handler under test.
      expected_raw_event_dict: a raw event dict, converted with
        `callback_event.from_dict` to build the expected event.
    """
    expected_event = callback_event.from_dict(expected_raw_event_dict)
    # Compared via str(); presumably the event type does not define __eq__.
    # TODO(review): confirm.
    self.assertEqual(str(actual_event), str(expected_event))

  def test_default_timeout_too_large(self):
    err_msg = (
        'The max timeout of a single RPC must be no smaller than '
        'the default timeout of the callback handler. '
        'Got rpc_max_timeout_sec=10, default_timeout_sec=20.'
    )
    # The message is literal text, so escape it before regex matching.
    with self.assertRaisesRegex(ValueError, re.escape(err_msg)):
      _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)

  def test_timeout_property(self):
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=20, default_timeout_sec=10
    )
    self.assertEqual(handler.rpc_max_timeout_sec, 20)
    self.assertEqual(handler.default_timeout_sec, 10)
    with self.assertRaises(AttributeError):
      handler.rpc_max_timeout_sec = 5

    with self.assertRaises(AttributeError):
      handler.default_timeout_sec = 5

  def test_callback_id_property(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
    with self.assertRaises(AttributeError):
      handler.callback_id = 'ha'

  def test_event_dict_to_snippet_event(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    event = handler.waitAndGet('ha', timeout=10)
    self.assert_event_correct(event, MOCK_RAW_EVENT)
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha', 10
    )

  def test_wait_and_get_timeout_default(self):
    handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    _ = handler.waitAndGet('ha')
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        mock.ANY, mock.ANY, 5
    )

  def test_wait_and_get_timeout_exceed_threshold(self):
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        f'Specified timeout {big_timeout_sec} is longer than max timeout '
        f'{rpc_max_timeout_sec}.'
    )
    with self.assertRaisesRegex(
        errors.CallbackHandlerBaseError, re.escape(expected_msg)
    ):
      handler.waitAndGet('ha', big_timeout_sec)

  def test_wait_for_event(self):
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_wait_for_event_negative(self):
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
        ' satisfies the predicate "some_condition".'
    )

    def some_condition(_):
      return False

    with self.assertRaisesRegex(
        errors.CallbackHandlerTimeoutError, re.escape(expected_msg)
    ):
      handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)

  def test_wait_for_event_max_timeout(self):
    """waitForEvent should not raise the timeout exceed threshold error."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    # This line should not raise.
    event = handler.waitForEvent(
        'AsyncTaskResult', some_condition, timeout=big_timeout_sec
    )
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_get_all(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT]
    )

    all_events = list(handler.getAll('ha'))
    # Without a length check, an empty result would pass the loop vacuously.
    self.assertEqual(len(all_events), 2)
    for event in all_events:
      self.assert_event_correct(event, MOCK_RAW_EVENT)

    handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha'
    )


if __name__ == '__main__':
  unittest.main()