• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright 2020 The Pigweed Authors
3#
4# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5# use this file except in compliance with the License. You may obtain a copy of
6# the License at
7#
8#     https://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13# License for the specific language governing permissions and limitations under
14# the License.
15"""Tests compiling and importing Python protos on the fly."""
16
17from pathlib import Path
18import tempfile
19import unittest
20
21from pw_protobuf_compiler import python_protos
22from pw_protobuf_compiler.python_protos import bytes_repr, proto_repr
23
# proto3 source for package pw.protobuf_compiler.test1: two messages (one with
# a nested enum) and a service declaring all four RPC kinds (unary, server
# streaming, client streaming, bidirectional streaming).
PROTO_1 = """\
syntax = "proto3";

package pw.protobuf_compiler.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc Unary(SomeMessage) returns (AnotherMessage) {}
  rpc ServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc ClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc BidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""
51
# proto2 source for package pw.protobuf_compiler.test2: a message with an
# `optional` float field, an empty message, and two separate services.
# PROTO_3 declares the same package, so the package spans two modules.
PROTO_2 = """\
syntax = "proto2";

package pw.protobuf_compiler.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""
72
# proto3 source that reuses package pw.protobuf_compiler.test2 (shared with
# PROTO_2). It adds a top-level enum, a message with a repeated field, and a
# message with two levels of nested message types.
PROTO_3 = """\
syntax = "proto3";

package pw.protobuf_compiler.test2;

enum Greeting {
  YO = 0;
  HI = 1;
}

message Hello {
  repeated int64 value = 1;
  Greeting hi = 2;
}

message NestingMessage {
  message NestedMessage {
    message NestedNestedMessage {
      int32 nested_nested_field = 1;
    }

    NestedNestedMessage nested_nested_message = 1;
  }

  NestedMessage nested_message = 1;
}
"""
100
101
class TestCompileAndImport(unittest.TestCase):
    """Test compiling and importing."""

    def setUp(self):
        """Writes PROTO_1..PROTO_3 to test_1.proto..test_3.proto in a temp dir.

        The directory's cleanup is registered with addCleanup immediately
        after it is created. unittest does not call tearDown when setUp
        raises, so this guarantees removal even if writing the files here, or
        a subclass's setUp that runs after super().setUp(), fails.
        """
        self._proto_dir = tempfile.TemporaryDirectory(prefix='proto_test')
        self.addCleanup(self._proto_dir.cleanup)

        # Paths of the written .proto files, in PROTO_1..PROTO_3 order.
        self._protos = []
        for i, contents in enumerate([PROTO_1, PROTO_2, PROTO_3], 1):
            proto_path = Path(self._proto_dir.name, f'test_{i}.proto')
            proto_path.write_text(contents, encoding='utf-8')
            self._protos.append(proto_path)

    def test_compile_to_temp_dir_and_import(self):
        # Key each compiled module by its file descriptor name
        # (e.g. 'test_1.proto').
        modules = {
            m.DESCRIPTOR.name: m
            for m in python_protos.compile_and_import(self._protos)
        }
        self.assertEqual(3, len(modules))

        # Make sure the protobuf modules contain what we expect them to.
        mod = modules['test_1.proto']
        self.assertEqual(
            4, len(mod.DESCRIPTOR.services_by_name['PublicService'].methods)
        )

        mod = modules['test_2.proto']
        self.assertEqual(mod.Request(magic_number=1.5).magic_number, 1.5)
        self.assertEqual(2, len(mod.DESCRIPTOR.services_by_name))

        mod = modules['test_3.proto']
        self.assertEqual(mod.Hello(value=[123, 456]).value, [123, 456])
136
class TestProtoLibrary(TestCompileAndImport):
    """Tests the Library class."""

    def setUp(self):
        """Compiles the protos written by the base setUp into a Library."""
        super().setUp()
        self._library = python_protos.Library(
            python_protos.compile_and_import(self._protos)
        )

    def test_packages_can_access_messages(self):
        msg = self._library.packages.pw.protobuf_compiler.test1.SomeMessage
        self.assertEqual(msg(magic_number=123).magic_number, 123)

    def test_packages_finds_across_modules(self):
        # Request is defined in test_2.proto and YO in test_3.proto; both
        # files declare package pw.protobuf_compiler.test2.
        msg = self._library.packages.pw.protobuf_compiler.test2.Request
        self.assertEqual(msg(magic_number=50).magic_number, 50)

        val = self._library.packages.pw.protobuf_compiler.test2.YO
        self.assertEqual(val, 0)

    def test_packages_invalid_name(self):
        with self.assertRaises(AttributeError):
            _ = self._library.packages.nothing

        with self.assertRaises(AttributeError):
            _ = self._library.packages.pw.NOT_HERE

        with self.assertRaises(AttributeError):
            _ = self._library.packages.pw.protobuf_compiler.test1.NotARealMsg

    def test_access_modules_by_package(self):
        test1 = self._library.modules_by_package['pw.protobuf_compiler.test1']
        self.assertEqual(len(test1), 1)
        self.assertEqual(test1[0].AnotherMessage.Result.Value('FAILED'), 0)

        # test2 is declared by two separate .proto files.
        test2 = self._library.modules_by_package['pw.protobuf_compiler.test2']
        self.assertEqual(len(test2), 2)

    def test_access_modules_by_package_unknown(self):
        with self.assertRaises(KeyError):
            _ = self._library.modules_by_package['pw.not_real']

    def test_library_from_strings(self):
        # Replace the package to avoid conflicts with the other proto imports
        new_protos = [
            p.replace('pw.protobuf_compiler', 'proto.library.test')
            for p in [PROTO_1, PROTO_2, PROTO_3]
        ]

        # Build a Library from identical contents twice to make sure the same
        # protos can safely be imported multiple times. Only the second
        # Library is inspected, so the first result is intentionally unbound.
        python_protos.Library.from_strings(new_protos)
        library = python_protos.Library.from_strings(new_protos)

        msg = library.packages.proto.library.test.test2.Request
        self.assertEqual(msg(magic_number=50).magic_number, 50)

        val = library.packages.proto.library.test.test2.YO
        self.assertEqual(val, 0)

    def test_access_nested_packages_by_name(self):
        # Indexing with a dotted name at any level must return the same
        # object as the equivalent attribute chain.
        self.assertIs(
            self._library.packages['pw.protobuf_compiler.test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )
        self.assertIs(
            self._library.packages.pw['protobuf_compiler.test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )
        self.assertIs(
            self._library.packages.pw.protobuf_compiler['test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )

    def test_access_nested_packages_by_name_unknown_package(self):
        with self.assertRaises(KeyError):
            _ = self._library.packages['']

        with self.assertRaises(KeyError):
            _ = self._library.packages['.']

        # Names are relative to the package being indexed, not absolute.
        with self.assertRaises(KeyError):
            _ = self._library.packages['protobuf_compiler.test1']

        with self.assertRaises(KeyError):
            _ = self._library.packages.pw['pw.protobuf_compiler.test1']

        with self.assertRaises(KeyError):
            _ = self._library.packages.pw.protobuf_compiler['not here']

    def test_messages(self):
        # messages() must yield every message type, including nested ones.
        protos = self._library.packages.pw.protobuf_compiler
        self.assertEqual(
            set(self._library.messages()),
            {
                protos.test1.SomeMessage,
                protos.test1.AnotherMessage,
                protos.test2.Request,
                protos.test2.Response,
                protos.test2.Hello,
                protos.test2.NestingMessage,
                protos.test2.NestingMessage.NestedMessage,
                protos.test2.NestingMessage.NestedMessage.NestedNestedMessage,
            },
        )
242
243
# proto3 source (package pw.test3) used by the proto_repr tests. Message covers
# each field shape being printed: singular, `optional` and repeated variants of
# integers, bytes, strings and enums, plus message, repeated message, oneof and
# map<string, Nested> fields.
PROTO_FOR_REPR = """\
syntax = "proto3";

package pw.test3;

enum Enum {
  ZERO = 0;
  ONE = 1;
}

message Nested {
  repeated int64 value = 1;
  Enum an_enum = 2;
}

message Message {
  Nested message = 1;
  repeated Nested repeated_message = 2;

  fixed32 regular_int = 3;
  optional int64 optional_int = 4;
  repeated int32 repeated_int = 5;

  bytes regular_bytes = 6;
  optional bytes optional_bytes = 7;
  repeated bytes repeated_bytes = 8;

  string regular_string = 9;
  optional string optional_string = 10;
  repeated string repeated_string = 11;

  Enum regular_enum = 12;
  optional Enum optional_enum = 13;
  repeated Enum repeated_enum = 14;

  oneof oneof_test {
    string oneof_1 = 15;
    int32 oneof_2 = 16;
    Nested oneof_3 = 17;
  }

  map<string, Nested> mapping = 18;
}
"""
288
289
class TestProtoRepr(unittest.TestCase):
    """Tests printing protobufs."""

    def setUp(self):
        """Compiles PROTO_FOR_REPR and exposes its Enum/Nested/Message types."""
        protos = python_protos.Library.from_strings(PROTO_FOR_REPR)
        self.enum = protos.packages.pw.test3.Enum
        self.nested = protos.packages.pw.test3.Nested
        self.message = protos.packages.pw.test3.Message

    def test_empty(self):
        self.assertEqual('pw.test3.Nested()', proto_repr(self.nested()))
        self.assertEqual('pw.test3.Message()', proto_repr(self.message()))

    def test_int_fields(self):
        # Fields print in field-number order, not keyword-argument order.
        self.assertEqual(
            'pw.test3.Message('
            'regular_int=999, '
            'optional_int=-1, '
            'repeated_int=[0, 1, 2])',
            proto_repr(
                self.message(
                    repeated_int=[0, 1, 2], regular_int=999, optional_int=-1
                ),
                wrap=False,
            ),
        )

    def test_bytes_fields(self):
        self.assertEqual(
            'pw.test3.Message('
            r"regular_bytes=b'\xFE\xED\xBE\xEF', "
            r"optional_bytes=b'', "
            r"repeated_bytes=[b'Hello\'\'\''])",
            proto_repr(
                self.message(
                    regular_bytes=b'\xfe\xed\xbe\xef',
                    optional_bytes=b'',
                    repeated_bytes=[b"Hello'''"],
                ),
                wrap=False,
            ),
        )

    def test_string_fields(self):
        # repeated_string is a `string` field, so pass str values rather than
        # relying on protobuf implicitly decoding bytes as UTF-8.
        self.assertEqual(
            'pw.test3.Message('
            "regular_string='hi', "
            "optional_string='', "
            'repeated_string=["\'"])',
            proto_repr(
                self.message(
                    regular_string='hi',
                    optional_string='',
                    repeated_string=["'"],
                ),
                wrap=False,
            ),
        )

    def test_enum_fields(self):
        self.assertEqual(
            'pw.test3.Nested(an_enum=pw.test3.Enum.ONE)',
            proto_repr(self.nested(an_enum=1)),
        )
        self.assertEqual(
            'pw.test3.Message(optional_enum=pw.test3.Enum.ONE)',
            proto_repr(self.message(optional_enum=self.enum.ONE)),
        )
        self.assertEqual(
            'pw.test3.Message(repeated_enum='
            '[pw.test3.Enum.ONE, pw.test3.Enum.ONE, pw.test3.Enum.ZERO])',
            proto_repr(self.message(repeated_enum=[1, 1, 0]), wrap=False),
        )

    def test_message_fields(self):
        self.assertEqual(
            'pw.test3.Message(message=pw.test3.Nested(value=[123]))',
            proto_repr(self.message(message=self.nested(value=[123]))),
        )
        self.assertEqual(
            'pw.test3.Message('
            'repeated_message=[pw.test3.Nested(value=[123]), '
            'pw.test3.Nested()])',
            proto_repr(
                self.message(
                    repeated_message=[self.nested(value=[123]), self.nested()]
                ),
                wrap=False,
            ),
        )

    def test_optional_shown_if_set_to_default(self):
        self.assertEqual(
            "pw.test3.Message("
            "optional_int=0, optional_bytes=b'', optional_string='', "
            "optional_enum=pw.test3.Enum.ZERO)",
            proto_repr(
                self.message(
                    optional_int=0,
                    optional_bytes=b'',
                    optional_string='',
                    optional_enum=0,
                ),
                wrap=False,
            ),
        )

    def test_oneof(self):
        self.assertEqual(
            proto_repr(self.message(oneof_1='test')),
            "pw.test3.Message(oneof_1='test')",
        )
        self.assertEqual(
            proto_repr(self.message(oneof_2=123)),
            "pw.test3.Message(oneof_2=123)",
        )
        self.assertEqual(
            proto_repr(
                self.message(oneof_3=self.nested(an_enum=self.enum.ONE))
            ),
            'pw.test3.Message('
            'oneof_3=pw.test3.Nested(an_enum=pw.test3.Enum.ONE))',
        )

        # Setting another member of the oneof clears the previous one.
        msg = self.message(oneof_1='test')
        msg.oneof_2 = 99
        self.assertEqual(proto_repr(msg), "pw.test3.Message(oneof_2=99)")

    def test_map(self):
        msg = self.message()
        msg.mapping['zero'].MergeFrom(self.nested())
        msg.mapping['one'].MergeFrom(
            self.nested(an_enum=self.enum.ONE, value=[1])
        )

        result = proto_repr(msg, wrap=False)
        # Escape the literal dots and braces so the pattern only matches the
        # exact type name; entry order is not asserted, so the entries are
        # checked individually below.
        self.assertRegex(result, r'^pw\.test3\.Message\(mapping=\{.*\}\)$')
        self.assertIn("'zero': pw.test3.Nested()", result)
        self.assertIn(
            "'one': pw.test3.Nested(value=[1], an_enum=pw.test3.Enum.ONE)",
            result,
        )

    def test_bytes_repr(self):
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef'), r"b'\xFE\xED\xBE\xEF'"
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef123'),
            r"b'\xFE\xED\xBE\xEF\x31\x32\x33'",
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef1234'), r"b'\xFE\xED\xBE\xEF1234'"
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef12345'), r"b'\xFE\xED\xBE\xEF12345'"
        )

    def test_wrap_multiple_lines(self):
        self.assertEqual(
            """\
pw.test3.Message(
    optional_int=0,
    optional_bytes=b'',
    optional_string='',
    optional_enum=pw.test3.Enum.ZERO,
)""",
            proto_repr(
                self.message(
                    optional_int=0,
                    optional_bytes=b'',
                    optional_string='',
                    optional_enum=0,
                ),
                wrap=True,
            ),
        )

    def test_wrap_one_line(self):
        self.assertEqual(
            "pw.test3.Message(optional_int=0, optional_bytes=b'')",
            proto_repr(
                self.message(optional_int=0, optional_bytes=b''), wrap=True
            ),
        )
475
476
# Run all test cases in this module when executed directly as a script.
if __name__ == '__main__':
    unittest.main()
479