#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests the pw_rpc.console_tools.console module."""

import types
import unittest

import pw_status

from pw_protobuf_compiler import python_protos
import pw_rpc
from pw_rpc import callback_client
from pw_rpc.console_tools.console import (
    CommandHelper,
    Context,
    ClientInfo,
    alias_deprecated_command,
)


class TestCommandHelper(unittest.TestCase):
    """Tests console_tools.console.CommandHelper."""

    def setUp(self) -> None:
        self._commands = {'command_a': 'A', 'command_B': 'B'}
        self._variables = {'hello': 1, 'world': 2}
        self._helper = CommandHelper(
            self._commands, self._variables, 'The header', 'The footer'
        )

    def test_help_contents(self) -> None:
        help_contents = self._helper.help()

        self.assertTrue(help_contents.startswith('The header'))
        self.assertIn('The footer', help_contents)

        # subTest reports which specific name is missing from the help text.
        for var_name in self._variables:
            with self.subTest(variable=var_name):
                self.assertIn(var_name, help_contents)

        for cmd_name in self._commands:
            with self.subTest(command=cmd_name):
                self.assertIn(cmd_name, help_contents)

    def test_repr_is_help(self) -> None:
        self.assertEqual(repr(self._helper), self._helper.help())


_PROTO = """\
syntax = "proto3";

package the.pkg;

message SomeMessage {
  uint32 magic_number = 1;

    message AnotherMessage {
      string payload = 1;
    }

}

service Service {
  rpc Unary(SomeMessage) returns (SomeMessage.AnotherMessage);
}
"""


class TestConsoleContext(unittest.TestCase):
    """Tests console_tools.console.Context."""

    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(_PROTO)

        # A single client exposing two channels (IDs 1 and 2).
        self._info = ClientInfo(
            'the_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [
                    pw_rpc.Channel(1, lambda _: None),
                    pw_rpc.Channel(2, lambda _: None),
                ],
                self._protos.modules(),
            ),
        )

    def test_sets_expected_variables(self) -> None:
        variables = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        ).variables()

        self.assertIn('set_target', variables)

        self.assertIsInstance(variables['help'], CommandHelper)
        self.assertIs(variables['python_help'], help)
        self.assertIs(variables['Status'], pw_status.Status)
        self.assertIs(variables['the_client'], self._info.client)

    def test_set_target_switches_between_clients(self) -> None:
        client_1_channel = self._info.rpc_client.channel(1).channel

        client_2_channel = pw_rpc.Channel(99, lambda _: None)
        info_2 = ClientInfo(
            'other_client',
            object(),
            pw_rpc.Client.from_modules(
                callback_client.Impl(),
                [client_2_channel],
                self._protos.modules(),
            ),
        )

        context = Context(
            [self._info, info_2],
            default_client=self._info.client,
            protos=self._protos,
        )

        # Make sure the RPC service switches from one client to the other.
        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_1_channel,
        )

        context.set_target(info_2.client)

        self.assertIs(
            context.variables()['the'].pkg.Service.Unary.channel,
            client_2_channel,
        )

    def test_default_client_must_be_in_clients(self) -> None:
        with self.assertRaises(ValueError):
            Context(
                [self._info],
                default_client='something else',
                protos=self._protos,
            )

    def test_set_target_invalid_channel(self) -> None:
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_non_default_channel(self) -> None:
        channel_1 = self._info.rpc_client.channel(1).channel
        channel_2 = self._info.rpc_client.channel(2).channel

        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )
        # The same variables mapping is reused after set_target, so this also
        # checks that set_target updates the already-returned namespace.
        variables = context.variables()

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_1)

        context.set_target(self._info.client, 2)

        self.assertIs(variables['the'].pkg.Service.Unary.channel, channel_2)

        # An unknown channel ID is still rejected after switching channels.
        with self.assertRaises(KeyError):
            context.set_target(self._info.client, 100)

    def test_set_target_requires_client_object(self) -> None:
        context = Context(
            [self._info], default_client=self._info.client, protos=self._protos
        )

        with self.assertRaises(ValueError):
            context.set_target(self._info.rpc_client)

        # Passing the ClientInfo's client object must not raise.
        context.set_target(self._info.client)

    def test_derived_context(self) -> None:
        called_derived_set_target = False

        class DerivedContext(Context):
            def set_target(
                self,
                unused_selected_client,
                unused_channel_id: int | None = None,
            ) -> None:
                nonlocal called_derived_set_target
                called_derived_set_target = True

        variables = DerivedContext(
            client_info=[self._info],
            default_client=self._info.client,
            protos=self._protos,
        ).variables()

        # The exposed 'set_target' must dispatch to the subclass override.
        variables['set_target'](self._info.client)
        self.assertTrue(called_derived_set_target)


class TestAliasDeprecatedCommand(unittest.TestCase):
    """Tests console_tools.console.alias_deprecated_command."""

    def test_wraps_command_to_new_package(self) -> None:
        variables = {'abc': types.SimpleNamespace(command=lambda: 123)}
        alias_deprecated_command(variables, 'xyz.one.two.three', 'abc.command')

        self.assertEqual(variables['xyz'].one.two.three(), 123)

    def test_wraps_command_to_existing_package(self) -> None:
        variables = {
            'abc': types.SimpleNamespace(NewCmd=lambda: 456),
            'one': types.SimpleNamespace(),
        }
        alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')

        self.assertEqual(variables['one'].two.OldCmd(), 456)

    def test_error_if_new_command_does_not_exist(self) -> None:
        variables = {
            'abc': types.SimpleNamespace(),
            'one': types.SimpleNamespace(),
        }

        with self.assertRaises(AttributeError):
            alias_deprecated_command(variables, 'one.two.OldCmd', 'abc.NewCmd')


if __name__ == '__main__':
    unittest.main()