#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""

import types
from typing import Optional
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (CommandHelper, Context, ClientInfo,
                                          alias_deprecated_command)


class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""
    def setUp(self) -> None:
        """Builds a helper with two commands, two variables, header, footer."""
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(self._commands, self._variables,
                                     'The header', 'The footer')

    def test_help_contents(self) -> None:
        """help() opens with the header and mentions footer and all names."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        for var_name in self._variables:
            self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is exactly the text returned by help()."""
        self.assertEqual(repr(self._helper), self._helper.help())


# Minimal proto compiled by TestConsoleContext.setUp: it defines the
# the.pkg.Service service with a single Unary RPC that the tests inspect.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

  message AnotherMessage {
    string payload = 1;
  }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""
    def setUp(self) -> None:
        """Compiles _PROTO and registers one client with channels 1 and 2."""
        self._protos = python_protos.Library.from_strings(_PROTO)

        self._info = ClientInfo(
            'the_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(), [
                pw_rpc.Channel(1, lambda _: None),
                pw_rpc.Channel(2, lambda _: None),
            ], self._protos.modules()))

    def test_sets_expected_variables(self) -> None:
        """variables() exposes set_target, help helpers, Status and clients."""
        variables = Context([self._info],
                            default_client=self._info.client,
                            protos=self._protos).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        # The builtin help() stays reachable under a different name.
        self.assertIs(variables['python_help'], help)
        self.assertIs(pw_status.Status, variables['Status'])
        # Each client is exposed under the name given in its ClientInfo.
        self.assertIs(self._info.client, variables['the_client'])

    def test_set_target_switches_between_clients(self) -> None:
        """set_target retargets the RPC namespace to another client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client', object(),
            pw_rpc.Client.from_modules(callback_client.Impl(),
                                       [client_2_channel],
                                       self._protos.modules()))

        context = Context([self._info, info_2],
                          default_client=self._info.client,
                          protos=self._protos)

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_1_channel)

        context.set_target(info_2.client)

        self.assertIs(context.variables()['the'].pkg.Service.Unary.channel,
                      client_2_channel)

    def test_default_client_must_be_in_clients(self) -> None:
        """Context raises ValueError for an unregistered default_client."""
        with self.assertRaises(ValueError):
            Context([self._info],
                    default_client='something else',
                    protos=self._protos)

    def test_set_target_invalid_channel(self) -> None:
        """set_target raises KeyError for a channel ID the client lacks."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target with an explicit channel ID selects that channel."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)
        # Captured once: the test expects this mapping to reflect later
        # set_target calls without calling variables() again.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target accepts the registered client, not its rpc_client."""
        context = Context([self._info],
                          default_client=self._info.client,
                          protos=self._protos)

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """The exposed set_target variable dispatches to subclass overrides."""
        called_derived_set_target = False

        class DerivedContext(Context):
            # The None default requires an explicit Optional: implicit
            # Optional is disallowed by PEP 484 and rejected by mypy.
            def set_target(self,
                           unused_selected_client,
                           unused_channel_id: Optional[int] = None) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(client_info=[self._info],
                                   default_client=self._info.client,
                                   protos=self._protos).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)


class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""
    def test_wraps_command_to_new_package(self) -> None:
        """Aliasing creates the missing intermediate namespaces."""
        variables = {'abc': types.SimpleNamespace(command=lambda: 123)}
        alias_deprecated_command(variables, 'xyz.one.two.three', 'abc.command')

        self.assertEqual(variables['xyz'].one.two.three(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        """Aliasing attaches under an already-present top-level namespace."""
        variables = {
            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
            'one': types.SimpleNamespace(),
        }
        alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')

        self.assertEqual(variables['one'].two.OldCmd(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        """AttributeError is raised when the aliased target is missing."""
        variables = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        with self.assertRaises(AttributeError):
            alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')


if __name__ == '__main__':
    unittest.main()