#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for encoding and decoding pw_rpc packets (the pw_rpc.packets module)."""

import unittest

from pw_status import Status

from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
from pw_rpc import packets

# Expected REQUEST packet for (channel_id=1, service_id=2, method_id=3). The
# payload is an arbitrary serialized message; an RpcPacket is used only because
# it is already imported.
_TEST_REQUEST = RpcPacket(
    type=PacketType.REQUEST,
    channel_id=1,
    service_id=2,
    method_id=3,
    payload=RpcPacket(status=321).SerializeToString(),
)

# Expected RESPONSE packet with the same ids and payload as _TEST_REQUEST.
_TEST_RESPONSE = RpcPacket(
    type=PacketType.RESPONSE,
    channel_id=1,
    service_id=2,
    method_id=3,
    payload=RpcPacket(status=321).SerializeToString(),
)


def _parse(data: bytes) -> RpcPacket:
    """Deserializes encoded packet bytes into a fresh RpcPacket."""
    packet = RpcPacket()
    packet.ParseFromString(data)
    return packet


class PacketsTest(unittest.TestCase):
    """Tests for packet encoding and decoding."""

    def test_encode_request(self):
        """The (channel, service, method) tuple and payload become a REQUEST."""
        data = packets.encode_request((1, 2, 3), RpcPacket(status=321))

        self.assertEqual(_TEST_REQUEST, _parse(data))

    def test_encode_response(self):
        """The (channel, service, method) tuple and payload become a RESPONSE."""
        data = packets.encode_response((1, 2, 3), RpcPacket(status=321))

        self.assertEqual(_TEST_RESPONSE, _parse(data))

    def test_encode_cancel(self):
        """A cancel is encoded as a CLIENT_ERROR carrying CANCELLED status."""
        data = packets.encode_cancel((9, 8, 7))

        self.assertEqual(
            _parse(data),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=9,
                service_id=8,
                method_id=7,
                status=Status.CANCELLED.value,
            ),
        )

    def test_encode_client_error(self):
        """A client error copies the request's ids and drops its payload."""
        data = packets.encode_client_error(_TEST_REQUEST, Status.NOT_FOUND)

        self.assertEqual(
            _parse(data),
            RpcPacket(
                type=PacketType.CLIENT_ERROR,
                channel_id=1,
                service_id=2,
                method_id=3,
                status=Status.NOT_FOUND.value,
            ),
        )

    def test_decode(self):
        """Decoding serialized bytes round-trips to an equal packet."""
        self.assertEqual(
            _TEST_REQUEST, packets.decode(_TEST_REQUEST.SerializeToString())
        )

    def test_for_server(self):
        """REQUEST packets are server-bound; RESPONSE packets are not."""
        self.assertTrue(packets.for_server(_TEST_REQUEST))
        self.assertFalse(packets.for_server(_TEST_RESPONSE))


if __name__ == '__main__':
    unittest.main()