• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
"""Tests encoding and decoding of pw_rpc packets."""
16
17import unittest
18
19from pw_status import Status
20
21from pw_rpc.internal.packet_pb2 import PacketType, RpcPacket
22from pw_rpc import packets
23
# Reference REQUEST packet addressed to channel 1, service 2, method 3. Its
# payload is a serialized RpcPacket(status=321); presumably any serialized
# message would do and RpcPacket is simply a convenient one to hand.
_TEST_REQUEST = RpcPacket(
    channel_id=1,
    service_id=2,
    method_id=3,
    type=PacketType.REQUEST,
    payload=RpcPacket(status=321).SerializeToString(),
)
31
32
class PacketsTest(unittest.TestCase):
    """Round-trip checks for the pw_rpc packets encode/decode helpers."""

    @staticmethod
    def _parse(data):
        """Returns a new RpcPacket populated from the serialized ``data``."""
        decoded = RpcPacket()
        decoded.ParseFromString(data)
        return decoded

    @staticmethod
    def _response_packet():
        """Builds a RESPONSE twin of _TEST_REQUEST (same ids and payload)."""
        return RpcPacket(
            type=PacketType.RESPONSE,
            channel_id=1,
            service_id=2,
            method_id=3,
            payload=RpcPacket(status=321).SerializeToString(),
        )

    def test_encode_request(self):
        """encode_request should yield a REQUEST equal to _TEST_REQUEST."""
        encoded = packets.encode_request((1, 2, 3), RpcPacket(status=321))
        self.assertEqual(_TEST_REQUEST, self._parse(encoded))

    def test_encode_response(self):
        """encode_response should yield the matching RESPONSE packet."""
        encoded = packets.encode_response((1, 2, 3), RpcPacket(status=321))
        self.assertEqual(self._response_packet(), self._parse(encoded))

    def test_encode_cancel(self):
        """encode_cancel should yield a CLIENT_ERROR carrying CANCELLED."""
        decoded = self._parse(packets.encode_cancel((9, 8, 7)))
        expected = RpcPacket(
            type=PacketType.CLIENT_ERROR,
            channel_id=9,
            service_id=8,
            method_id=7,
            status=Status.CANCELLED.value,
        )
        self.assertEqual(expected, decoded)

    def test_encode_client_error(self):
        """encode_client_error copies the request's ids into a CLIENT_ERROR."""
        decoded = self._parse(
            packets.encode_client_error(_TEST_REQUEST, Status.NOT_FOUND)
        )
        expected = RpcPacket(
            type=PacketType.CLIENT_ERROR,
            channel_id=1,
            service_id=2,
            method_id=3,
            status=Status.NOT_FOUND.value,
        )
        self.assertEqual(expected, decoded)

    def test_decode(self):
        """decode should invert SerializeToString for a request packet."""
        wire = _TEST_REQUEST.SerializeToString()
        self.assertEqual(_TEST_REQUEST, packets.decode(wire))

    def test_for_server(self):
        """Requests are routed to the server; responses are not."""
        self.assertTrue(packets.for_server(_TEST_REQUEST))
        self.assertFalse(packets.for_server(self._response_packet()))
111
112
if __name__ == '__main__':
    # Executing this file directly runs the test cases defined in it.
    unittest.main()
115