#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests pw_rpc.console_tools: Watchdog, CommandHelper, and Context."""

from typing import Optional
import unittest
from unittest import mock

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools import CommandHelper, Context, ClientInfo, Watchdog


class TestWatchdog(unittest.TestCase):
    """Tests the Watchdog class."""
    def setUp(self) -> None:
        self._reset = mock.Mock()
        self._expiration = mock.Mock()
        self._while_expired = mock.Mock()

        # The last argument is presumably the timeout; it is large so the real
        # timer never fires during a test. Expiration is driven manually via
        # _trigger_timeout() instead.
        self._watchdog = Watchdog(self._reset, self._expiration,
                                  self._while_expired, 99999)

    def _trigger_timeout(self) -> None:
        # Don't wait for the timeout -- that's too flaky. Call the internal
        # timeout function instead.
        self._watchdog._timeout_expired()  # pylint: disable=protected-access

    def test_expiration_callbacks(self) -> None:
        self._watchdog.start()

        # Nothing fires before the first timeout. Note: Mock.not_called() would
        # silently pass (it just calls an auto-created child mock), so the
        # real assertion method must be used here.
        self._expiration.assert_not_called()
        self._while_expired.assert_not_called()
        self._reset.assert_not_called()

        # First timeout: the expiration callback runs exactly once.
        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_not_called()

        # Subsequent timeouts while still expired only run while_expired.
        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_called_once_with()

        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_called()

    def test_reset_not_called_unless_expires(self) -> None:
        self._watchdog.start()
        self._watchdog.reset()

        self._reset.assert_not_called()
        self._expiration.assert_not_called()
        self._while_expired.assert_not_called()

    def test_reset_called_if_expired(self) -> None:
        self._watchdog.start()
        self._trigger_timeout()

        self._watchdog.reset()

        self._trigger_timeout()

        self._reset.assert_called_once_with()
        self._expiration.assert_called()


class TestCommandHelper(unittest.TestCase):
    """Tests the CommandHelper class."""
    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(self._commands, self._variables,
                                     'The header', 'The footer')

    def test_help_contents(self) -> None:
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        for var_name in self._variables:
            self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        self.assertEqual(repr(self._helper), self._helper.help())


# Minimal proto definitions used to build RPC clients for the Context tests:
# one service with a unary RPC under the "the.pkg" package.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

  message AnotherMessage {
    string payload = 1;
  }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # A client with two channels (IDs 1 and 2) whose outputs are discarded.
        self._info = ClientInfo(
            'the_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(), [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ], self._protos.modules()))

    def test_sets_expected_variables(self) -> None:
        variables = Context([self._info],
                            default_client=self._info.client,
                            protos=self._protos).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(),
                                       [client_2_channel],
                                       self._protos.modules()))

        context = Context([self._info, info_2],
                          default_client=self._info.client,
                          protos=self._protos)

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_1_channel)

        context.set_target(info_2.client)

        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_2_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        with self.assertRaises(ValueError):
            Context([self._info],
                    default_client='something else',
                    protos=self._protos)

    def test_set_target_invalid_channel(self) -> None:
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        # Passing the underlying pw_rpc.Client instead of the user-facing
        # client object is rejected.
        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(self,
                           unused_selected_client,
                           unused_channel_id: Optional[int] = None) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        # The 'set_target' console variable must dispatch to the subclass
        # override rather than the base implementation.
        variables = DerivedContext(client_info=[self._info],
                                   default_client=self._info.client,
                                   protos=self._protos).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)


if __name__ == '__main__':
    unittest.main()