• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
"""Tests the pw_rpc.console_tools module."""
16
17import unittest
18from unittest import mock
19
20import pw_status
21
22from pw_protobuf_compiler import python_protos
23import pw_rpc
24from pw_rpc import callback_client
25from pw_rpc.console_tools import CommandHelper, Context, ClientInfo, Watchdog
26
27
class TestWatchdog(unittest.TestCase):
    """Tests the Watchdog class."""
    def setUp(self) -> None:
        self._reset = mock.Mock()
        self._expiration = mock.Mock()
        self._while_expired = mock.Mock()

        # 99999 is presumably the timeout in seconds: large enough that the
        # real timer never fires mid-test. Tests simulate expiration with
        # _trigger_timeout() instead.
        self._watchdog = Watchdog(self._reset, self._expiration,
                                  self._while_expired, 99999)

    def _trigger_timeout(self) -> None:
        """Simulates the watchdog timer expiring, without waiting for it."""
        # Don't wait for the timeout -- that's too flaky. Call the internal
        # timeout function instead.
        self._watchdog._timeout_expired()  # pylint: disable=protected-access

    def test_expiration_callbacks(self) -> None:
        """First timeout calls the expiration callback; later ones call
        the while-expired callback."""
        self._watchdog.start()

        # Starting the watchdog must not invoke either callback. Use the
        # assert_* form: Mock.not_called() would just create and call a child
        # mock and never fail.
        self._expiration.assert_not_called()
        self._while_expired.assert_not_called()

        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_not_called()

        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_called_once_with()

        self._trigger_timeout()

        self._expiration.assert_called_once_with()
        self._while_expired.assert_called()

    def test_reset_not_called_unless_expires(self) -> None:
        """Resetting a watchdog that has not expired invokes no callbacks."""
        self._watchdog.start()
        self._watchdog.reset()

        self._reset.assert_not_called()
        self._expiration.assert_not_called()
        self._while_expired.assert_not_called()

    def test_reset_called_if_expired(self) -> None:
        """Resetting an expired watchdog invokes the reset callback."""
        self._watchdog.start()
        self._trigger_timeout()

        self._watchdog.reset()

        self._trigger_timeout()

        self._reset.assert_called_once_with()
        self._expiration.assert_called()
81
82
class TestCommandHelper(unittest.TestCase):
    """Tests the CommandHelper class."""
    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(self._commands, self._variables,
                                     'The header', 'The footer')

    def test_help_contents(self) -> None:
        """help() begins with the header and mentions the footer, every
        variable name, and every command name."""
        text = self._helper.help()

        self.assertTrue(text.startswith('The header'))
        self.assertIn('The footer', text)

        # Variable names are checked first, then command names.
        for name in [*self._variables, *self._commands]:
            self.assertIn(name, text)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is exactly its help text."""
        rendered = repr(self._helper)
        self.assertEqual(rendered, self._helper.help())
104
105
# Proto source compiled by TestConsoleContext.setUp. It declares package
# `the.pkg`, a message with a nested message type, and a service with a single
# unary RPC that the tests reach as `the.pkg.Service.Unary`.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""
124
125
class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # One client with two channels (IDs 1 and 2) using no-op callbacks.
        self._info = ClientInfo(
            'the_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(), [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ], self._protos.modules()))

    def test_sets_expected_variables(self) -> None:
        """variables() exposes console helpers and each client by name."""
        variables = Context([self._info],
                            default_client=self._info.client,
                            protos=self._protos).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        """set_target() rebinds RPC stubs to the selected client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(),
                                       [client_2_channel],
                                       self._protos.modules()))

        context = Context([self._info, info_2],
                          default_client=self._info.client,
                          protos=self._protos)

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_1_channel)

        context.set_target(info_2.client)

        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_2_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        """Context() rejects a default_client absent from client_info."""
        with self.assertRaises(ValueError):
            Context([self._info],
                    default_client='something else',
                    protos=self._protos)

    def test_set_target_invalid_channel(self) -> None:
        """set_target() raises KeyError for an unknown channel ID."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target() can select a channel other than the default one."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)
        # Fetched once: the same variables object must reflect later
        # set_target() calls.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target() takes ClientInfo.client, not the pw_rpc.Client."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # The client object itself must be accepted without raising.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """variables()['set_target'] dispatches to a subclass override."""
        from typing import Optional  # pylint: disable=import-outside-toplevel

        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(self,
                           unused_selected_client,
                           unused_channel_id: Optional[int] = None) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(client_info=[self._info],
                                   default_client=self._info.client,
                                   protos=self._protos).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)
230
231
if __name__ == '__main__':
    # Run this module's TestCase classes when executed directly as a script.
    unittest.main()
234