#!/usr/bin/env python3
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests compiling and importing Python protos on the fly."""

from pathlib import Path
import tempfile
import unittest

from pw_protobuf_compiler import python_protos
from pw_protobuf_compiler.python_protos import bytes_repr, proto_repr

PROTO_1 = """\
syntax = "proto3";

package pw.protobuf_compiler.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc Unary(SomeMessage) returns (AnotherMessage) {}
  rpc ServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc ClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc BidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""

PROTO_2 = """\
syntax = "proto2";

package pw.protobuf_compiler.test2;

message Request {
  optional float magic_number = 1;
}

message Response {
}

service Alpha {
  rpc Unary(Request) returns (Response) {}
}

service Bravo {
  rpc BidiStreaming(stream Request) returns (stream Response) {}
}
"""

# Deliberately shares the test2 package with PROTO_2 so that lookups by
# package must merge definitions across multiple modules.
PROTO_3 = """\
syntax = "proto3";

package pw.protobuf_compiler.test2;

enum Greeting {
  YO = 0;
  HI = 1;
}

message Hello {
  repeated int64 value = 1;
  Greeting hi = 2;
}

message NestingMessage {
  message NestedMessage {
    message NestedNestedMessage {
      int32 nested_nested_field = 1;
    }

    NestedNestedMessage nested_nested_message = 1;
  }

  NestedMessage nested_message = 1;
}
"""


class TestCompileAndImport(unittest.TestCase):
    """Test compiling and importing."""

    def setUp(self):
        self._proto_dir = tempfile.TemporaryDirectory(prefix='proto_test')
        # Register the cleanup immediately rather than relying on tearDown():
        # unittest skips tearDown() when setUp() (or a subclass's extended
        # setUp()) raises, but it always runs functions added via addCleanup().
        self.addCleanup(self._proto_dir.cleanup)

        self._protos = []

        for i, contents in enumerate([PROTO_1, PROTO_2, PROTO_3], 1):
            path = Path(self._proto_dir.name, f'test_{i}.proto')
            path.write_text(contents)
            self._protos.append(path)

    def test_compile_to_temp_dir_and_import(self):
        modules = {
            m.DESCRIPTOR.name: m
            for m in python_protos.compile_and_import(self._protos)
        }
        self.assertEqual(3, len(modules))

        # Make sure the protobuf modules contain what we expect them to.
        mod = modules['test_1.proto']
        self.assertEqual(
            4, len(mod.DESCRIPTOR.services_by_name['PublicService'].methods)
        )

        mod = modules['test_2.proto']
        self.assertEqual(mod.Request(magic_number=1.5).magic_number, 1.5)
        self.assertEqual(2, len(mod.DESCRIPTOR.services_by_name))

        mod = modules['test_3.proto']
        self.assertEqual(mod.Hello(value=[123, 456]).value, [123, 456])


class TestProtoLibrary(TestCompileAndImport):
    """Tests the Library class."""

    def setUp(self):
        # The parent setUp() registers the temp-dir cleanup, so the directory
        # is removed even if building the Library below raises.
        super().setUp()
        self._library = python_protos.Library(
            python_protos.compile_and_import(self._protos)
        )

    def test_packages_can_access_messages(self):
        msg = self._library.packages.pw.protobuf_compiler.test1.SomeMessage
        self.assertEqual(msg(magic_number=123).magic_number, 123)

    def test_packages_finds_across_modules(self):
        msg = self._library.packages.pw.protobuf_compiler.test2.Request
        self.assertEqual(msg(magic_number=50).magic_number, 50)

        # YO comes from PROTO_3, Request from PROTO_2; same package.
        val = self._library.packages.pw.protobuf_compiler.test2.YO
        self.assertEqual(val, 0)

    def test_packages_invalid_name(self):
        with self.assertRaises(AttributeError):
            _ = self._library.packages.nothing

        with self.assertRaises(AttributeError):
            _ = self._library.packages.pw.NOT_HERE

        with self.assertRaises(AttributeError):
            _ = self._library.packages.pw.protobuf_compiler.test1.NotARealMsg

    def test_access_modules_by_package(self):
        test1 = self._library.modules_by_package['pw.protobuf_compiler.test1']
        self.assertEqual(len(test1), 1)
        self.assertEqual(test1[0].AnotherMessage.Result.Value('FAILED'), 0)

        # PROTO_2 and PROTO_3 both declare pw.protobuf_compiler.test2.
        test2 = self._library.modules_by_package['pw.protobuf_compiler.test2']
        self.assertEqual(len(test2), 2)

    def test_access_modules_by_package_unknown(self):
        with self.assertRaises(KeyError):
            _ = self._library.modules_by_package['pw.not_real']

    def test_library_from_strings(self):
        # Replace the package to avoid conflicts with the other proto imports
        new_protos = [
            p.replace('pw.protobuf_compiler', 'proto.library.test')
            for p in [PROTO_1, PROTO_2, PROTO_3]
        ]

        library = python_protos.Library.from_strings(new_protos)

        # Make sure we can safely import the same proto contents multiple times.
        library = python_protos.Library.from_strings(new_protos)

        msg = library.packages.proto.library.test.test2.Request
        self.assertEqual(msg(magic_number=50).magic_number, 50)

        val = library.packages.proto.library.test.test2.YO
        self.assertEqual(val, 0)

    def test_access_nested_packages_by_name(self):
        self.assertIs(
            self._library.packages['pw.protobuf_compiler.test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )
        self.assertIs(
            self._library.packages.pw['protobuf_compiler.test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )
        self.assertIs(
            self._library.packages.pw.protobuf_compiler['test1'],
            self._library.packages.pw.protobuf_compiler.test1,
        )

    def test_access_nested_packages_by_name_unknown_package(self):
        with self.assertRaises(KeyError):
            _ = self._library.packages['']

        with self.assertRaises(KeyError):
            _ = self._library.packages['.']

        with self.assertRaises(KeyError):
            _ = self._library.packages['protobuf_compiler.test1']

        # Names are relative to the node being indexed, so repeating the
        # 'pw' prefix under packages.pw must fail.
        with self.assertRaises(KeyError):
            _ = self._library.packages.pw['pw.protobuf_compiler.test1']

        with self.assertRaises(KeyError):
            _ = self._library.packages.pw.protobuf_compiler['not here']

    def test_messages(self):
        protos = self._library.packages.pw.protobuf_compiler
        self.assertEqual(
            set(self._library.messages()),
            {
                protos.test1.SomeMessage,
                protos.test1.AnotherMessage,
                protos.test2.Request,
                protos.test2.Response,
                protos.test2.Hello,
                protos.test2.NestingMessage,
                protos.test2.NestingMessage.NestedMessage,
                protos.test2.NestingMessage.NestedMessage.NestedNestedMessage,
            },
        )


PROTO_FOR_REPR = """\
syntax = "proto3";

package pw.test3;

enum Enum {
  ZERO = 0;
  ONE = 1;
}

message Nested {
  repeated int64 value = 1;
  Enum an_enum = 2;
}

message Message {
  Nested message = 1;
  repeated Nested repeated_message = 2;

  fixed32 regular_int = 3;
  optional int64 optional_int = 4;
  repeated int32 repeated_int = 5;

  bytes regular_bytes = 6;
  optional bytes optional_bytes = 7;
  repeated bytes repeated_bytes = 8;

  string regular_string = 9;
  optional string optional_string = 10;
  repeated string repeated_string = 11;

  Enum regular_enum = 12;
  optional Enum optional_enum = 13;
  repeated Enum repeated_enum = 14;

  oneof oneof_test {
    string oneof_1 = 15;
    int32 oneof_2 = 16;
    Nested oneof_3 = 17;
  }

  map<string, Nested> mapping = 18;
}
"""


class TestProtoRepr(unittest.TestCase):
    """Tests printing protobufs."""

    def setUp(self):
        protos = python_protos.Library.from_strings(PROTO_FOR_REPR)
        self.enum = protos.packages.pw.test3.Enum
        self.nested = protos.packages.pw.test3.Nested
        self.message = protos.packages.pw.test3.Message

    def test_empty(self):
        self.assertEqual('pw.test3.Nested()', proto_repr(self.nested()))
        self.assertEqual('pw.test3.Message()', proto_repr(self.message()))

    def test_int_fields(self):
        self.assertEqual(
            'pw.test3.Message('
            'regular_int=999, '
            'optional_int=-1, '
            'repeated_int=[0, 1, 2])',
            proto_repr(
                self.message(
                    repeated_int=[0, 1, 2], regular_int=999, optional_int=-1
                ),
                wrap=False,
            ),
        )

    def test_bytes_fields(self):
        self.assertEqual(
            'pw.test3.Message('
            r"regular_bytes=b'\xFE\xED\xBE\xEF', "
            r"optional_bytes=b'', "
            r"repeated_bytes=[b'Hello\'\'\''])",
            proto_repr(
                self.message(
                    regular_bytes=b'\xfe\xed\xbe\xef',
                    optional_bytes=b'',
                    repeated_bytes=[b"Hello'''"],
                ),
                wrap=False,
            ),
        )

    def test_string_fields(self):
        # NOTE(review): a bytes value is assigned to the repeated *string*
        # field below and is expected to print as a str; confirm this
        # bytes-to-str coercion is what the test intends to cover.
        self.assertEqual(
            'pw.test3.Message('
            "regular_string='hi', "
            "optional_string='', "
            'repeated_string=["\'"])',
            proto_repr(
                self.message(
                    regular_string='hi',
                    optional_string='',
                    repeated_string=[b"'"],
                ),
                wrap=False,
            ),
        )

    def test_enum_fields(self):
        self.assertEqual(
            'pw.test3.Nested(an_enum=pw.test3.Enum.ONE)',
            proto_repr(self.nested(an_enum=1)),
        )
        self.assertEqual(
            'pw.test3.Message(optional_enum=pw.test3.Enum.ONE)',
            proto_repr(self.message(optional_enum=self.enum.ONE)),
        )
        self.assertEqual(
            'pw.test3.Message(repeated_enum='
            '[pw.test3.Enum.ONE, pw.test3.Enum.ONE, pw.test3.Enum.ZERO])',
            proto_repr(self.message(repeated_enum=[1, 1, 0]), wrap=False),
        )

    def test_message_fields(self):
        self.assertEqual(
            'pw.test3.Message(message=pw.test3.Nested(value=[123]))',
            proto_repr(self.message(message=self.nested(value=[123]))),
        )
        self.assertEqual(
            'pw.test3.Message('
            'repeated_message=[pw.test3.Nested(value=[123]), '
            'pw.test3.Nested()])',
            proto_repr(
                self.message(
                    repeated_message=[self.nested(value=[123]), self.nested()]
                ),
                wrap=False,
            ),
        )

    def test_optional_shown_if_set_to_default(self):
        self.assertEqual(
            "pw.test3.Message("
            "optional_int=0, optional_bytes=b'', optional_string='', "
            "optional_enum=pw.test3.Enum.ZERO)",
            proto_repr(
                self.message(
                    optional_int=0,
                    optional_bytes=b'',
                    optional_string='',
                    optional_enum=0,
                ),
                wrap=False,
            ),
        )

    def test_oneof(self):
        # Expected value first, matching the other tests in this file.
        self.assertEqual(
            "pw.test3.Message(oneof_1='test')",
            proto_repr(self.message(oneof_1='test')),
        )
        self.assertEqual(
            "pw.test3.Message(oneof_2=123)",
            proto_repr(self.message(oneof_2=123)),
        )
        self.assertEqual(
            'pw.test3.Message('
            'oneof_3=pw.test3.Nested(an_enum=pw.test3.Enum.ONE))',
            proto_repr(
                self.message(oneof_3=self.nested(an_enum=self.enum.ONE))
            ),
        )

        # Setting another member of the oneof replaces the previous one.
        msg = self.message(oneof_1='test')
        msg.oneof_2 = 99
        self.assertEqual("pw.test3.Message(oneof_2=99)", proto_repr(msg))

    def test_map(self):
        msg = self.message()
        msg.mapping['zero'].MergeFrom(self.nested())
        msg.mapping['one'].MergeFrom(
            self.nested(an_enum=self.enum.ONE, value=[1])
        )

        # Map entry order is not fixed, so check the overall shape with a
        # regex and each entry individually instead of one exact string.
        result = proto_repr(msg, wrap=False)
        self.assertRegex(result, r'^pw\.test3\.Message\(mapping=\{.*\}\)$')
        self.assertIn("'zero': pw.test3.Nested()", result)
        self.assertIn(
            "'one': pw.test3.Nested(value=[1], an_enum=pw.test3.Enum.ONE)",
            result,
        )

    def test_bytes_repr(self):
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef'), r"b'\xFE\xED\xBE\xEF'"
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef123'),
            r"b'\xFE\xED\xBE\xEF\x31\x32\x33'",
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef1234'), r"b'\xFE\xED\xBE\xEF1234'"
        )
        self.assertEqual(
            bytes_repr(b'\xfe\xed\xbe\xef12345'), r"b'\xFE\xED\xBE\xEF12345'"
        )

    def test_wrap_multiple_lines(self):
        self.assertEqual(
            """\
pw.test3.Message(
    optional_int=0,
    optional_bytes=b'',
    optional_string='',
    optional_enum=pw.test3.Enum.ZERO,
)""",
            proto_repr(
                self.message(
                    optional_int=0,
                    optional_bytes=b'',
                    optional_string='',
                    optional_enum=0,
                ),
                wrap=True,
            ),
        )

    def test_wrap_one_line(self):
        self.assertEqual(
            "pw.test3.Message(optional_int=0, optional_bytes=b'')",
            proto_repr(
                self.message(optional_int=0, optional_bytes=b''), wrap=True
            ),
        )


if __name__ == '__main__':
    unittest.main()