#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for encoding and decoding packets with pw_rpc.packets."""

import unittest

from pw_status import Status

from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc import packets


def _payload():
    """Return a fresh message used as the payload in request/response tests.

    A new instance is built on every call so that no test can observe a
    payload object that another test (or the code under test) touched.
    """
    return RpcPacket(status=321)


def _parse(data):
    """Parse serialized packet data into a new RpcPacket.

    The encode tests parse with the protobuf API directly rather than
    packets.decode so that they do not depend on the decoder under test.
    """
    packet = RpcPacket()
    packet.ParseFromString(data)
    return packet


# Expected REQUEST packet for channel 1, service 2, method 3.
_TEST_REQUEST = RpcPacket(type=PacketType.REQUEST,
                          channel_id=1,
                          service_id=2,
                          method_id=3,
                          payload=_payload().SerializeToString())

# Expected RESPONSE packet with the same IDs and payload as _TEST_REQUEST.
_TEST_RESPONSE = RpcPacket(type=PacketType.RESPONSE,
                           channel_id=1,
                           service_id=2,
                           method_id=3,
                           payload=_payload().SerializeToString())


class PacketsTest(unittest.TestCase):
    """Tests for packet encoding and decoding."""

    def test_encode_request(self):
        """encode_request produces a REQUEST packet with the given IDs."""
        data = packets.encode_request((1, 2, 3), _payload())

        self.assertEqual(_TEST_REQUEST, _parse(data))

    def test_encode_response(self):
        """encode_response produces a RESPONSE packet with the given IDs."""
        data = packets.encode_response((1, 2, 3), _payload())

        self.assertEqual(_TEST_RESPONSE, _parse(data))

    def test_encode_cancel(self):
        """encode_cancel produces a CANCEL_SERVER_STREAM packet, no payload."""
        data = packets.encode_cancel((9, 8, 7))

        self.assertEqual(
            _parse(data),
            RpcPacket(type=PacketType.CANCEL_SERVER_STREAM,
                      channel_id=9,
                      service_id=8,
                      method_id=7))

    def test_encode_client_error(self):
        """encode_client_error echoes the request's IDs with a status code."""
        data = packets.encode_client_error(_TEST_REQUEST, Status.NOT_FOUND)

        self.assertEqual(
            _parse(data),
            RpcPacket(type=PacketType.CLIENT_ERROR,
                      channel_id=1,
                      service_id=2,
                      method_id=3,
                      status=Status.NOT_FOUND.value))

    def test_decode(self):
        """decode round-trips a serialized packet."""
        self.assertEqual(_TEST_REQUEST,
                         packets.decode(_TEST_REQUEST.SerializeToString()))

    def test_for_server(self):
        """for_server accepts REQUEST packets and rejects RESPONSE packets."""
        self.assertTrue(packets.for_server(_TEST_REQUEST))
        self.assertFalse(packets.for_server(_TEST_RESPONSE))


if __name__ == '__main__':
    unittest.main()