# Copyright 2022 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

import re
import unittest
from unittest import mock

from mobly.snippet import callback_event
from mobly.snippet import callback_handler_base
from mobly.snippet import errors

MOCK_CALLBACK_ID = '2-1'
MOCK_RAW_EVENT = {
    'callbackId': '2-1',
    'name': 'AsyncTaskResult',
    'time': 20460228696,
    'data': {
        'exampleData': "Here's a simple event.",
        'successful': True,
        'secretNumber': 12,
    },
}


class FakeCallbackHandler(callback_handler_base.CallbackHandlerBase):
  """Fake callback handler class for unit tests.

  The RPC hooks overridden from the base class forward to
  `self.mock_rpc_func`, so each test can stub their return values and
  inspect how they were called.
  """

  def __init__(
      self,
      callback_id=None,
      event_client=None,
      ret_value=None,
      method_name=None,
      device=None,
      rpc_max_timeout_sec=120,
      default_timeout_sec=120,
  ):
    """Initializes a fake callback handler object used for unit tests."""
    super().__init__(
        callback_id,
        event_client,
        ret_value,
        method_name,
        device,
        rpc_max_timeout_sec,
        default_timeout_sec,
    )
    # Created after the base initializer so tests that expect the base
    # constructor to raise never need it.
    self.mock_rpc_func = mock.Mock()

  def callEventWaitAndGetRpc(self, *args, **kwargs):
    """See base class."""
    return self.mock_rpc_func.callEventWaitAndGetRpc(*args, **kwargs)

  def callEventGetAllRpc(self, *args, **kwargs):
    """See base class."""
    return self.mock_rpc_func.callEventGetAllRpc(*args, **kwargs)


class CallbackHandlerBaseTest(unittest.TestCase):
  """Unit tests for mobly.snippet.callback_handler_base.CallbackHandlerBase."""

  def assert_event_correct(self, actual_event, expected_raw_event_dict):
    """Asserts `actual_event` equals the event built from a raw event dict.

    Args:
      actual_event: the event object returned by the handler under test.
      expected_raw_event_dict: the raw dict the event should be built from.
    """
    expected_event = callback_event.from_dict(expected_raw_event_dict)
    self.assertEqual(str(actual_event), str(expected_event))

  def test_default_timeout_too_large(self):
    err_msg = (
        'The max timeout of a single RPC must be no smaller than '
        'the default timeout of the callback handler. '
        'Got rpc_max_timeout_sec=10, default_timeout_sec=20.'
    )
    # Escape the message so '.' and other metacharacters match literally.
    with self.assertRaisesRegex(ValueError, re.escape(err_msg)):
      _ = FakeCallbackHandler(rpc_max_timeout_sec=10, default_timeout_sec=20)

  def test_timeout_property(self):
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=20, default_timeout_sec=10
    )
    self.assertEqual(handler.rpc_max_timeout_sec, 20)
    self.assertEqual(handler.default_timeout_sec, 10)
    with self.assertRaises(AttributeError):
      handler.rpc_max_timeout_sec = 5

    with self.assertRaises(AttributeError):
      handler.default_timeout_sec = 5

  def test_callback_id_property(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    self.assertEqual(handler.callback_id, MOCK_CALLBACK_ID)
    with self.assertRaises(AttributeError):
      handler.callback_id = 'ha'

  def test_event_dict_to_snippet_event(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    event = handler.waitAndGet('ha', timeout=10)

    self.assert_event_correct(event, MOCK_RAW_EVENT)
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha', 10
    )

  def test_wait_and_get_timeout_default(self):
    handler = FakeCallbackHandler(rpc_max_timeout_sec=20, default_timeout_sec=5)
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    _ = handler.waitAndGet('ha')
    handler.mock_rpc_func.callEventWaitAndGetRpc.assert_called_once_with(
        mock.ANY, mock.ANY, 5
    )

  def test_wait_and_get_timeout_exceed_threshold(self):
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        f'Specified timeout {big_timeout_sec} is longer than max timeout '
        f'{rpc_max_timeout_sec}.'
    )
    with self.assertRaisesRegex(
        errors.CallbackHandlerBaseError, re.escape(expected_msg)
    ):
      handler.waitAndGet('ha', big_timeout_sec)

  def test_wait_for_event(self):
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    event = handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_wait_for_event_negative(self):
    handler = FakeCallbackHandler()
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    expected_msg = (
        'Timed out after 0.01s waiting for an "AsyncTaskResult" event that'
        ' satisfies the predicate "some_condition".'
    )

    def some_condition(_):
      return False

    # Escape the message so the '.' in '0.01s' matches literally.
    with self.assertRaisesRegex(
        errors.CallbackHandlerTimeoutError, re.escape(expected_msg)
    ):
      handler.waitForEvent('AsyncTaskResult', some_condition, 0.01)

  def test_wait_for_event_max_timeout(self):
    """waitForEvent should not raise the timeout exceed threshold error."""
    rpc_max_timeout_sec = 5
    big_timeout_sec = 10
    handler = FakeCallbackHandler(
        rpc_max_timeout_sec=rpc_max_timeout_sec,
        default_timeout_sec=rpc_max_timeout_sec,
    )
    handler.mock_rpc_func.callEventWaitAndGetRpc = mock.Mock(
        return_value=MOCK_RAW_EVENT
    )

    def some_condition(event):
      return event.data['successful']

    # This line should not raise.
    event = handler.waitForEvent(
        'AsyncTaskResult', some_condition, timeout=big_timeout_sec
    )
    self.assert_event_correct(event, MOCK_RAW_EVENT)

  def test_get_all(self):
    handler = FakeCallbackHandler(callback_id=MOCK_CALLBACK_ID)
    handler.mock_rpc_func.callEventGetAllRpc = mock.Mock(
        return_value=[MOCK_RAW_EVENT, MOCK_RAW_EVENT]
    )

    all_events = handler.getAll('ha')

    # Check the count first so an empty result cannot pass the loop vacuously.
    self.assertEqual(len(all_events), 2)
    for event in all_events:
      self.assert_event_correct(event, MOCK_RAW_EVENT)

    handler.mock_rpc_func.callEventGetAllRpc.assert_called_once_with(
        MOCK_CALLBACK_ID, 'ha'
    )


if __name__ == '__main__':
  unittest.main()