• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""
16
import types
from typing import Optional
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (CommandHelper, Context, ClientInfo,
                                          alias_deprecated_command)
27
28
class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""
    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(self._commands, self._variables,
                                     'The header', 'The footer')

    def test_help_contents(self) -> None:
        """help() leads with the header and mentions footer, vars and cmds."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        # subTest reports exactly which variable or command name is missing.
        for name in [*self._variables, *self._commands]:
            with self.subTest(name=name):
                self.assertIn(name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is the same text that help() returns."""
        self.assertEqual(repr(self._helper), self._helper.help())
50
51
# Proto source compiled by TestConsoleContext.setUp via
# python_protos.Library.from_strings. The tests reach the RPC defined here as
# variables['the'].pkg.Service.Unary (package `the.pkg`, service `Service`).
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""
70
71
class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # One client with channels 1 and 2. The `object()` is an opaque
        # placeholder client; tests compare it by identity.
        self._info = ClientInfo(
            'the_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(), [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ], self._protos.modules()))

    def _make_context(self) -> 'Context':
        """Returns a Context over self._info that defaults to its client."""
        return Context([self._info],
                       default_client=self._info.client,
                       protos=self._protos)

    def test_sets_expected_variables(self) -> None:
        """variables() exposes helpers, Status and the named client."""
        variables = self._make_context().variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        # The builtin help() stays reachable under another name.
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        """set_target() rebinds RPCs to the newly selected client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(),
                                       [client_2_channel],
                                       self._protos.modules()))

        context = Context([self._info, info_2],
                          default_client=self._info.client,
                          protos=self._protos)

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_1_channel)

        context.set_target(info_2.client)

        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_2_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        """A default_client not among the ClientInfos is rejected."""
        with self.assertRaises(ValueError):
            Context([self._info],
                    default_client='something else',
                    protos=self._protos)

    def test_set_target_invalid_channel(self) -> None:
        """Selecting a channel ID the client lacks raises KeyError."""
        context = self._make_context()

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target() can select a channel other than the first one."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = self._make_context()
        # Fetched once, before set_target(); the assertions below check that
        # this same mapping reflects the newly selected channel.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target() takes the ClientInfo client, not the pw_rpc.Client."""
        context = self._make_context()

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # Must not raise.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """The 'set_target' variable dispatches to a subclass override."""
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(self,
                           unused_selected_client,
                           unused_channel_id: Optional[int] = None) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(client_info=[self._info],
                                   default_client=self._info.client,
                                   protos=self._protos).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)
176
177
class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""
    def test_wraps_command_to_new_package(self) -> None:
        # The alias's top-level package 'xyz' is absent beforehand.
        namespace = dict(abc=types.SimpleNamespace(command=lambda: 123))
        alias_deprecated_command(namespace, 'xyz.one.two.three', 'abc.command')

        old_command = namespace['xyz'].one.two.three
        self.assertEqual(old_command(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        # Here the alias's top-level package 'one' is already present.
        namespace = dict(
            abc=types.SimpleNamespace(NewCmd=lambda: 456),
            one=types.SimpleNamespace(),
        )
        alias_deprecated_command(namespace, 'one.two.OldCmd', 'abc.NewCmd')

        old_command = namespace['one'].two.OldCmd
        self.assertEqual(old_command(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        # 'abc' exists but has no NewCmd attribute to alias.
        namespace = dict(abc=types.SimpleNamespace(),
                         one=types.SimpleNamespace())

        self.assertRaises(AttributeError, alias_deprecated_command, namespace,
                          'one.two.OldCmd', 'abc.NewCmd')
202
203
if __name__ == '__main__':
    # Executing this file directly runs every TestCase it defines.
    unittest.main()
206