• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2021 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests the pw_rpc.console_tools.console module."""
16
17import types
18from typing import Optional
19import unittest
20
21import pw_status
22
23from pw_protobuf_compiler import python_protos
24import pw_rpc
25from pw_rpc import callback_client
26from pw_rpc.console_tools.console import (
27    CommandHelper,
28    Context,
29    ClientInfo,
30    alias_deprecated_command,
31)
32
33
class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands, self._variables, 'The header', 'The footer'
        )

    def test_help_contents(self) -> None:
        """help() has the header, footer, and every variable/command name."""
        help_contents = self._helper.help()

        # Pass the text as the failure message so a mismatch is diagnosable.
        self.assertTrue(help_contents.startswith('The header'), help_contents)
        self.assertIn('The footer', help_contents)

        # subTest reports every missing name individually instead of stopping
        # at the first failed assertion.
        for var_name in self._variables:
            with self.subTest(variable=var_name):
                self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            with self.subTest(command=cmd_name):
                self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is exactly the help() text."""
        self.assertEqual(repr(self._helper), self._helper.help())
56
57
# Proto source compiled by TestConsoleContext.setUp through
# python_protos.Library.from_strings(). It declares package `the.pkg` with a
# `Service` whose `Unary` RPC takes SomeMessage and returns the nested
# SomeMessage.AnotherMessage; the tests reach that RPC through the console
# variables as `the.pkg.Service.Unary`.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""
76
77
class TestConsoleContext(unittest.TestCase):
    """Exercises console_tools.console.Context with real pw_rpc clients."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # One RPC client with channels 1 and 2; their output callbacks ignore
        # every packet they are given.
        rpc_client = pw_rpc.Client.from_modules(
            callback_client.Impl(),
            [
                pw_rpc.Channel(channel_id, lambda _: None)
                for channel_id in (1, 2)
            ],
            self._protos.modules(),
        )
        self._info = ClientInfo('the_client', object(), rpc_client)

    def _single_client_context(self) -> 'Context':
        """Builds a Context whose only client, and default, is self._info."""
        return Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

    def test_sets_expected_variables(self) -> None:
        variables = self._single_client_context().variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        first_channel = self._info.rpc_client.channel(1).channel

        second_channel = pw_rpc.Channel(99, lambda _: None)
        other_info = ClientInfo(
            'other_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [second_channel],
                self._protos.modules(),
            ),
        )

        context = Context(
            [self._info, other_info],
            default_client=self._info.client,
            protos=self._protos,
        )

        def unary_channel():
            # Look the RPC up afresh each time, as the console user would.
            return context.variables()['the'].pkg.Service.Unary.channel

        # The RPC service must follow whichever client is currently targeted.
        self.assertIs(unary_channel(), first_channel)
        context.set_target(other_info.client)
        self.assertIs(unary_channel(), second_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        self.assertRaises(
            ValueError,
            Context,
            [self._info],
            default_client='something else',
            protos=self._protos,
        )

    def test_set_target_invalid_channel(self) -> None:
        context = self._single_client_context()

        self.assertRaises(KeyError, context.set_target, self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        rpc_client = self._info.rpc_client
        first, second = (rpc_client.channel(i).channel for i in (1, 2))

        context = self._single_client_context()
        variables = context.variables()

        # The same variables mapping is re-traversed after set_target so the
        # test observes in-place retargeting.
        self.assertIs(variables['the'].pkg.Service.Unary.channel, first)
        context.set_target(self._info.client, 2)
        self.assertIs(variables['the'].pkg.Service.Unary.channel, second)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        context = self._single_client_context()

        # The underlying pw_rpc client is rejected ...
        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # ... while the registered client object is accepted.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        derived_set_target_called = False

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: Optional[int] = None,
            ) -> None:
                nonlocal derived_set_target_called
                derived_set_target_called = True

        context = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        )
        set_target_command = context.variables()['set_target']
        set_target_command(self._info.client)

        self.assertTrue(derived_set_target_called)
205
206
class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        # 'xyz' is absent beforehand; the alias call must build the path.
        namespace = {'abc': types.SimpleNamespace(command=lambda: 123)}

        alias_deprecated_command(namespace, 'xyz.one.two.three', 'abc.command')

        old_command = namespace['xyz'].one.two.three
        self.assertEqual(old_command(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        # 'one' already exists, so the alias is attached beneath it.
        namespace = dict(
            abc=types.SimpleNamespace(NewCmd=lambda: 456),
            one=types.SimpleNamespace(),
        )

        alias_deprecated_command(namespace, 'one.two.OldCmd', 'abc.NewCmd')

        old_command = namespace['one'].two.OldCmd
        self.assertEqual(old_command(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        # 'abc' has no NewCmd attribute, so aliasing it must fail.
        namespace = {key: types.SimpleNamespace() for key in ('abc', 'one')}

        with self.assertRaises(AttributeError):
            alias_deprecated_command(
                namespace, 'one.two.OldCmd', 'abc.NewCmd'
            )
231
232
if __name__ == '__main__':
    # Executed directly: hand control to unittest's command-line runner.
    unittest.main()
235