#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""

import types
from typing import Optional
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (
    CommandHelper,
    Context,
    ClientInfo,
    alias_deprecated_command,
)


class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands, self._variables, 'The header', 'The footer'
        )

    def test_help_contents(self) -> None:
        """help() starts with the header and mentions everything else."""
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        # Every variable name and every command name must be listed.
        for var_name in self._variables:
            self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        """repr() of the helper is exactly its help() text."""
        self.assertEqual(repr(self._helper), self._helper.help())


# Minimal schema: one service whose Unary RPC is reachable in the console
# variables as the.pkg.Service.Unary.
_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

  message AnotherMessage {
    string payload = 1;
  }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # NOTE(review): positional order appears to be (name, client object,
        # rpc client); the object() is read back as self._info.client and the
        # pw_rpc.Client as self._info.rpc_client -- confirm against ClientInfo.
        self._info = ClientInfo(
            'the_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [
                    pw_rpc.Channel(1, lambda _: None),
                    pw_rpc.Channel(2, lambda _: None),
                ],
                self._protos.modules(),
            ),
        )

    def test_sets_expected_variables(self) -> None:
        """variables() exposes the console helpers and the named client."""
        variables = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        ).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        # The builtin help() stays reachable under a different name.
        self.assertIs(variables['python_help'], help)
        self.assertIs(variables['Status'], pw_status.Status)
        self.assertIs(variables['the_client'], self._info.client)

    def test_set_target_switches_between_clients(self) -> None:
        """set_target() rebinds RPC stubs to the newly selected client."""
        client_1_channel = self._info.rpc_client.channel(1).channel

        # Channel ID 99 is distinct from the first client's channels (1, 2).
        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [client_2_channel],
                self._protos.modules(),
            ),
        )

        context = Context(
            [self._info, info_2],
            default_client=self._info.client,
            protos=self._protos,
        )

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_1_channel,
        )

        context.set_target(info_2.client)

        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_2_channel,
        )

    def test_default_client_must_be_in_clients(self) -> None:
        """A default_client not among the ClientInfos raises ValueError."""
        with self.assertRaises(ValueError):
            Context(
                [self._info],
                default_client='something else',
                protos=self._protos,
            )

    def test_set_target_invalid_channel(self) -> None:
        """Selecting a channel ID the client does not have raises KeyError."""
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        """set_target() with a channel ID rebinds stubs to that channel."""
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )
        # The same variables dict is checked before and after set_target(),
        # so the stubs it holds must be updated in place.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        """set_target() rejects the pw_rpc.Client and accepts the client."""
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        """variables()['set_target'] dispatches to a subclass override."""
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: Optional[int] = None,
            ) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        ).variables()
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)


class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        """Missing intermediate packages on the old path are created."""
        variables = {'abc': types.SimpleNamespace(command=lambda: 123)}
        alias_deprecated_command(variables, 'xyz.one.two.three', 'abc.command')

        self.assertEqual(variables['xyz'].one.two.three(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        """The alias is attached beneath an already-present top-level name."""
        variables = {
            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
            'one': types.SimpleNamespace(),
        }
        alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')

        self.assertEqual(variables['one'].two.OldCmd(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        """Aliasing to a nonexistent new command raises AttributeError."""
        variables = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        with self.assertRaises(AttributeError):
            alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')


if __name__ == '__main__':
    unittest.main()