1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Protocol Buffers - Google's data interchange format 5# Copyright 2008 Google Inc. All rights reserved. 6# https://developers.google.com/protocol-buffers/ 7# 8# Redistribution and use in source and binary forms, with or without 9# modification, are permitted provided that the following conditions are 10# met: 11# 12# * Redistributions of source code must retain the above copyright 13# notice, this list of conditions and the following disclaimer. 14# * Redistributions in binary form must reproduce the above 15# copyright notice, this list of conditions and the following disclaimer 16# in the documentation and/or other materials provided with the 17# distribution. 18# * Neither the name of Google Inc. nor the names of its 19# contributors may be used to endorse or promote products derived from 20# this software without specific prior written permission. 21# 22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 34"""Unittest for reflection.py, which also indirectly tests the output of the 35pure-Python protocol compiler. 36""" 37 38import copy 39import gc 40import operator 41import six 42import struct 43 44try: 45 import unittest2 as unittest #PY26 46except ImportError: 47 import unittest 48 49from google.protobuf import unittest_import_pb2 50from google.protobuf import unittest_mset_pb2 51from google.protobuf import unittest_pb2 52from google.protobuf import descriptor_pb2 53from google.protobuf import descriptor 54from google.protobuf import message 55from google.protobuf import reflection 56from google.protobuf import text_format 57from google.protobuf.internal import api_implementation 58from google.protobuf.internal import more_extensions_pb2 59from google.protobuf.internal import more_messages_pb2 60from google.protobuf.internal import message_set_extensions_pb2 61from google.protobuf.internal import wire_format 62from google.protobuf.internal import test_util 63from google.protobuf.internal import decoder 64 65 66class _MiniDecoder(object): 67 """Decodes a stream of values from a string. 68 69 Once upon a time we actually had a class called decoder.Decoder. Then we 70 got rid of it during a redesign that made decoding much, much faster overall. 71 But a couple tests in this file used it to check that the serialized form of 72 a message was correct. So, this class implements just the methods that were 73 used by said tests, so that we don't have to rewrite the tests. 74 """ 75 76 def __init__(self, bytes): 77 self._bytes = bytes 78 self._pos = 0 79 80 def ReadVarint(self): 81 result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) 82 return result 83 84 ReadInt32 = ReadVarint 85 ReadInt64 = ReadVarint 86 ReadUInt32 = ReadVarint 87 ReadUInt64 = ReadVarint 88 89 def ReadSInt64(self): 90 return wire_format.ZigZagDecode(self.ReadVarint()) 91 92 ReadSInt32 = ReadSInt64 93 94 def ReadFieldNumberAndWireType(self): 95 return wire_format.UnpackTag(self.ReadVarint()) 96 97 def ReadFloat(self): 98 result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] 99 self._pos += 4 100 return result 101 102 def ReadDouble(self): 103 result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] 104 self._pos += 8 105 return result 106 107 def EndOfStream(self): 108 return self._pos == len(self._bytes) 109 110 111class ReflectionTest(unittest.TestCase): 112 113 def assertListsEqual(self, values, others): 114 self.assertEqual(len(values), len(others)) 115 for i in range(len(values)): 116 self.assertEqual(values[i], others[i]) 117 118 def testScalarConstructor(self): 119 # Constructor with only scalar types should succeed. 120 proto = unittest_pb2.TestAllTypes( 121 optional_int32=24, 122 optional_double=54.321, 123 optional_string='optional_string', 124 optional_float=None) 125 126 self.assertEqual(24, proto.optional_int32) 127 self.assertEqual(54.321, proto.optional_double) 128 self.assertEqual('optional_string', proto.optional_string) 129 self.assertFalse(proto.HasField("optional_float")) 130 131 def testRepeatedScalarConstructor(self): 132 # Constructor with only repeated scalar types should succeed. 133 proto = unittest_pb2.TestAllTypes( 134 repeated_int32=[1, 2, 3, 4], 135 repeated_double=[1.23, 54.321], 136 repeated_bool=[True, False, False], 137 repeated_string=["optional_string"], 138 repeated_float=None) 139 140 self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) 141 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 142 self.assertEqual([True, False, False], list(proto.repeated_bool)) 143 self.assertEqual(["optional_string"], list(proto.repeated_string)) 144 self.assertEqual([], list(proto.repeated_float)) 145 146 def testRepeatedCompositeConstructor(self): 147 # Constructor with only repeated composite types should succeed. 148 proto = unittest_pb2.TestAllTypes( 149 repeated_nested_message=[ 150 unittest_pb2.TestAllTypes.NestedMessage( 151 bb=unittest_pb2.TestAllTypes.FOO), 152 unittest_pb2.TestAllTypes.NestedMessage( 153 bb=unittest_pb2.TestAllTypes.BAR)], 154 repeated_foreign_message=[ 155 unittest_pb2.ForeignMessage(c=-43), 156 unittest_pb2.ForeignMessage(c=45324), 157 unittest_pb2.ForeignMessage(c=12)], 158 repeatedgroup=[ 159 unittest_pb2.TestAllTypes.RepeatedGroup(), 160 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 161 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) 162 163 self.assertEqual( 164 [unittest_pb2.TestAllTypes.NestedMessage( 165 bb=unittest_pb2.TestAllTypes.FOO), 166 unittest_pb2.TestAllTypes.NestedMessage( 167 bb=unittest_pb2.TestAllTypes.BAR)], 168 list(proto.repeated_nested_message)) 169 self.assertEqual( 170 [unittest_pb2.ForeignMessage(c=-43), 171 unittest_pb2.ForeignMessage(c=45324), 172 unittest_pb2.ForeignMessage(c=12)], 173 list(proto.repeated_foreign_message)) 174 self.assertEqual( 175 [unittest_pb2.TestAllTypes.RepeatedGroup(), 176 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 177 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], 178 list(proto.repeatedgroup)) 179 180 def testMixedConstructor(self): 181 # Constructor with only mixed types should succeed. 182 proto = unittest_pb2.TestAllTypes( 183 optional_int32=24, 184 optional_string='optional_string', 185 repeated_double=[1.23, 54.321], 186 repeated_bool=[True, False, False], 187 repeated_nested_message=[ 188 unittest_pb2.TestAllTypes.NestedMessage( 189 bb=unittest_pb2.TestAllTypes.FOO), 190 unittest_pb2.TestAllTypes.NestedMessage( 191 bb=unittest_pb2.TestAllTypes.BAR)], 192 repeated_foreign_message=[ 193 unittest_pb2.ForeignMessage(c=-43), 194 unittest_pb2.ForeignMessage(c=45324), 195 unittest_pb2.ForeignMessage(c=12)], 196 optional_nested_message=None) 197 198 self.assertEqual(24, proto.optional_int32) 199 self.assertEqual('optional_string', proto.optional_string) 200 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 201 self.assertEqual([True, False, False], list(proto.repeated_bool)) 202 self.assertEqual( 203 [unittest_pb2.TestAllTypes.NestedMessage( 204 bb=unittest_pb2.TestAllTypes.FOO), 205 unittest_pb2.TestAllTypes.NestedMessage( 206 bb=unittest_pb2.TestAllTypes.BAR)], 207 list(proto.repeated_nested_message)) 208 self.assertEqual( 209 [unittest_pb2.ForeignMessage(c=-43), 210 unittest_pb2.ForeignMessage(c=45324), 211 unittest_pb2.ForeignMessage(c=12)], 212 list(proto.repeated_foreign_message)) 213 self.assertFalse(proto.HasField("optional_nested_message")) 214 215 def testConstructorTypeError(self): 216 self.assertRaises( 217 TypeError, unittest_pb2.TestAllTypes, optional_int32="foo") 218 self.assertRaises( 219 TypeError, unittest_pb2.TestAllTypes, optional_string=1234) 220 self.assertRaises( 221 TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234) 222 self.assertRaises( 223 TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234) 224 self.assertRaises( 225 TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"]) 226 self.assertRaises( 227 TypeError, unittest_pb2.TestAllTypes, repeated_string=1234) 228 self.assertRaises( 229 TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234]) 230 self.assertRaises( 231 TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234) 232 self.assertRaises( 233 TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234]) 234 235 def testConstructorInvalidatesCachedByteSize(self): 236 message = unittest_pb2.TestAllTypes(optional_int32 = 12) 237 self.assertEqual(2, message.ByteSize()) 238 239 message = unittest_pb2.TestAllTypes( 240 optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) 241 self.assertEqual(3, message.ByteSize()) 242 243 message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) 244 self.assertEqual(3, message.ByteSize()) 245 246 message = unittest_pb2.TestAllTypes( 247 repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) 248 self.assertEqual(3, message.ByteSize()) 249 250 def testSimpleHasBits(self): 251 # Test a scalar. 252 proto = unittest_pb2.TestAllTypes() 253 self.assertTrue(not proto.HasField('optional_int32')) 254 self.assertEqual(0, proto.optional_int32) 255 # HasField() shouldn't be true if all we've done is 256 # read the default value. 257 self.assertTrue(not proto.HasField('optional_int32')) 258 proto.optional_int32 = 1 259 # Setting a value however *should* set the "has" bit. 260 self.assertTrue(proto.HasField('optional_int32')) 261 proto.ClearField('optional_int32') 262 # And clearing that value should unset the "has" bit. 263 self.assertTrue(not proto.HasField('optional_int32')) 264 265 def testHasBitsWithSinglyNestedScalar(self): 266 # Helper used to test foreign messages and groups. 267 # 268 # composite_field_name should be the name of a non-repeated 269 # composite (i.e., foreign or group) field in TestAllTypes, 270 # and scalar_field_name should be the name of an integer-valued 271 # scalar field within that composite. 272 # 273 # I never thought I'd miss C++ macros and templates so much. :( 274 # This helper is semantically just: 275 # 276 # assert proto.composite_field.scalar_field == 0 277 # assert not proto.composite_field.HasField('scalar_field') 278 # assert not proto.HasField('composite_field') 279 # 280 # proto.composite_field.scalar_field = 10 281 # old_composite_field = proto.composite_field 282 # 283 # assert proto.composite_field.scalar_field == 10 284 # assert proto.composite_field.HasField('scalar_field') 285 # assert proto.HasField('composite_field') 286 # 287 # proto.ClearField('composite_field') 288 # 289 # assert not proto.composite_field.HasField('scalar_field') 290 # assert not proto.HasField('composite_field') 291 # assert proto.composite_field.scalar_field == 0 292 # 293 # # Now ensure that ClearField('composite_field') disconnected 294 # # the old field object from the object tree... 295 # assert old_composite_field is not proto.composite_field 296 # old_composite_field.scalar_field = 20 297 # assert not proto.composite_field.HasField('scalar_field') 298 # assert not proto.HasField('composite_field') 299 def TestCompositeHasBits(composite_field_name, scalar_field_name): 300 proto = unittest_pb2.TestAllTypes() 301 # First, check that we can get the scalar value, and see that it's the 302 # default (0), but that proto.HasField('omposite') and 303 # proto.composite.HasField('scalar') will still return False. 304 composite_field = getattr(proto, composite_field_name) 305 original_scalar_value = getattr(composite_field, scalar_field_name) 306 self.assertEqual(0, original_scalar_value) 307 # Assert that the composite object does not "have" the scalar. 308 self.assertTrue(not composite_field.HasField(scalar_field_name)) 309 # Assert that proto does not "have" the composite field. 310 self.assertTrue(not proto.HasField(composite_field_name)) 311 312 # Now set the scalar within the composite field. Ensure that the setting 313 # is reflected, and that proto.HasField('composite') and 314 # proto.composite.HasField('scalar') now both return True. 315 new_val = 20 316 setattr(composite_field, scalar_field_name, new_val) 317 self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) 318 # Hold on to a reference to the current composite_field object. 319 old_composite_field = composite_field 320 # Assert that the has methods now return true. 321 self.assertTrue(composite_field.HasField(scalar_field_name)) 322 self.assertTrue(proto.HasField(composite_field_name)) 323 324 # Now call the clear method... 325 proto.ClearField(composite_field_name) 326 327 # ...and ensure that the "has" bits are all back to False... 328 composite_field = getattr(proto, composite_field_name) 329 self.assertTrue(not composite_field.HasField(scalar_field_name)) 330 self.assertTrue(not proto.HasField(composite_field_name)) 331 # ...and ensure that the scalar field has returned to its default. 332 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 333 334 self.assertTrue(old_composite_field is not composite_field) 335 setattr(old_composite_field, scalar_field_name, new_val) 336 self.assertTrue(not composite_field.HasField(scalar_field_name)) 337 self.assertTrue(not proto.HasField(composite_field_name)) 338 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 339 340 # Test simple, single-level nesting when we set a scalar. 341 TestCompositeHasBits('optionalgroup', 'a') 342 TestCompositeHasBits('optional_nested_message', 'bb') 343 TestCompositeHasBits('optional_foreign_message', 'c') 344 TestCompositeHasBits('optional_import_message', 'd') 345 346 def testReferencesToNestedMessage(self): 347 proto = unittest_pb2.TestAllTypes() 348 nested = proto.optional_nested_message 349 del proto 350 # A previous version had a bug where this would raise an exception when 351 # hitting a now-dead weak reference. 352 nested.bb = 23 353 354 def testDisconnectingNestedMessageBeforeSettingField(self): 355 proto = unittest_pb2.TestAllTypes() 356 nested = proto.optional_nested_message 357 proto.ClearField('optional_nested_message') # Should disconnect from parent 358 self.assertTrue(nested is not proto.optional_nested_message) 359 nested.bb = 23 360 self.assertTrue(not proto.HasField('optional_nested_message')) 361 self.assertEqual(0, proto.optional_nested_message.bb) 362 363 def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): 364 proto = unittest_pb2.TestAllTypes() 365 nested = proto.optional_nested_message 366 proto.ClearField('optional_nested_message') 367 del proto 368 del nested 369 # Force a garbage collect so that the underlying CMessages are freed along 370 # with the Messages they point to. This is to make sure we're not deleting 371 # default message instances. 372 gc.collect() 373 proto = unittest_pb2.TestAllTypes() 374 nested = proto.optional_nested_message 375 376 def testDisconnectingNestedMessageAfterSettingField(self): 377 proto = unittest_pb2.TestAllTypes() 378 nested = proto.optional_nested_message 379 nested.bb = 5 380 self.assertTrue(proto.HasField('optional_nested_message')) 381 proto.ClearField('optional_nested_message') # Should disconnect from parent 382 self.assertEqual(5, nested.bb) 383 self.assertEqual(0, proto.optional_nested_message.bb) 384 self.assertTrue(nested is not proto.optional_nested_message) 385 nested.bb = 23 386 self.assertTrue(not proto.HasField('optional_nested_message')) 387 self.assertEqual(0, proto.optional_nested_message.bb) 388 389 def testDisconnectingNestedMessageBeforeGettingField(self): 390 proto = unittest_pb2.TestAllTypes() 391 self.assertTrue(not proto.HasField('optional_nested_message')) 392 proto.ClearField('optional_nested_message') 393 self.assertTrue(not proto.HasField('optional_nested_message')) 394 395 def testDisconnectingNestedMessageAfterMerge(self): 396 # This test exercises the code path that does not use ReleaseMessage(). 397 # The underlying fear is that if we use ReleaseMessage() incorrectly, 398 # we will have memory leaks. It's hard to check that that doesn't happen, 399 # but at least we can exercise that code path to make sure it works. 400 proto1 = unittest_pb2.TestAllTypes() 401 proto2 = unittest_pb2.TestAllTypes() 402 proto2.optional_nested_message.bb = 5 403 proto1.MergeFrom(proto2) 404 self.assertTrue(proto1.HasField('optional_nested_message')) 405 proto1.ClearField('optional_nested_message') 406 self.assertTrue(not proto1.HasField('optional_nested_message')) 407 408 def testDisconnectingLazyNestedMessage(self): 409 # This test exercises releasing a nested message that is lazy. This test 410 # only exercises real code in the C++ implementation as Python does not 411 # support lazy parsing, but the current C++ implementation results in 412 # memory corruption and a crash. 413 if api_implementation.Type() != 'python': 414 return 415 proto = unittest_pb2.TestAllTypes() 416 proto.optional_lazy_message.bb = 5 417 proto.ClearField('optional_lazy_message') 418 del proto 419 gc.collect() 420 421 def testHasBitsWhenModifyingRepeatedFields(self): 422 # Test nesting when we add an element to a repeated field in a submessage. 423 proto = unittest_pb2.TestNestedMessageHasBits() 424 proto.optional_nested_message.nestedmessage_repeated_int32.append(5) 425 self.assertEqual( 426 [5], proto.optional_nested_message.nestedmessage_repeated_int32) 427 self.assertTrue(proto.HasField('optional_nested_message')) 428 429 # Do the same test, but with a repeated composite field within the 430 # submessage. 431 proto.ClearField('optional_nested_message') 432 self.assertTrue(not proto.HasField('optional_nested_message')) 433 proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() 434 self.assertTrue(proto.HasField('optional_nested_message')) 435 436 def testHasBitsForManyLevelsOfNesting(self): 437 # Test nesting many levels deep. 438 recursive_proto = unittest_pb2.TestMutualRecursionA() 439 self.assertTrue(not recursive_proto.HasField('bb')) 440 self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) 441 self.assertTrue(not recursive_proto.HasField('bb')) 442 recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 443 self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) 444 self.assertTrue(recursive_proto.HasField('bb')) 445 self.assertTrue(recursive_proto.bb.HasField('a')) 446 self.assertTrue(recursive_proto.bb.a.HasField('bb')) 447 self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) 448 self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) 449 self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a')) 450 self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) 451 452 def testSingularListFields(self): 453 proto = unittest_pb2.TestAllTypes() 454 proto.optional_fixed32 = 1 455 proto.optional_int32 = 5 456 proto.optional_string = 'foo' 457 # Access sub-message but don't set it yet. 458 nested_message = proto.optional_nested_message 459 self.assertEqual( 460 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 461 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 462 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], 463 proto.ListFields()) 464 465 proto.optional_nested_message.bb = 123 466 self.assertEqual( 467 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 468 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 469 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), 470 (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], 471 nested_message) ], 472 proto.ListFields()) 473 474 def testRepeatedListFields(self): 475 proto = unittest_pb2.TestAllTypes() 476 proto.repeated_fixed32.append(1) 477 proto.repeated_int32.append(5) 478 proto.repeated_int32.append(11) 479 proto.repeated_string.extend(['foo', 'bar']) 480 proto.repeated_string.extend([]) 481 proto.repeated_string.append('baz') 482 proto.repeated_string.extend(str(x) for x in range(2)) 483 proto.optional_int32 = 21 484 proto.repeated_bool # Access but don't set anything; should not be listed. 485 self.assertEqual( 486 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), 487 (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), 488 (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), 489 (proto.DESCRIPTOR.fields_by_name['repeated_string' ], 490 ['foo', 'bar', 'baz', '0', '1']) ], 491 proto.ListFields()) 492 493 def testSingularListExtensions(self): 494 proto = unittest_pb2.TestAllExtensions() 495 proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 496 proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 497 proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' 498 self.assertEqual( 499 [ (unittest_pb2.optional_int32_extension , 5), 500 (unittest_pb2.optional_fixed32_extension, 1), 501 (unittest_pb2.optional_string_extension , 'foo') ], 502 proto.ListFields()) 503 504 def testRepeatedListExtensions(self): 505 proto = unittest_pb2.TestAllExtensions() 506 proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) 507 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) 508 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) 509 proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') 510 proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') 511 proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') 512 proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 513 self.assertEqual( 514 [ (unittest_pb2.optional_int32_extension , 21), 515 (unittest_pb2.repeated_int32_extension , [5, 11]), 516 (unittest_pb2.repeated_fixed32_extension, [1]), 517 (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], 518 proto.ListFields()) 519 520 def testListFieldsAndExtensions(self): 521 proto = unittest_pb2.TestFieldOrderings() 522 test_util.SetAllFieldsAndExtensions(proto) 523 unittest_pb2.my_extension_int 524 self.assertEqual( 525 [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), 526 (unittest_pb2.my_extension_int , 23), 527 (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), 528 (unittest_pb2.my_extension_string , 'bar'), 529 (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], 530 proto.ListFields()) 531 532 def testDefaultValues(self): 533 proto = unittest_pb2.TestAllTypes() 534 self.assertEqual(0, proto.optional_int32) 535 self.assertEqual(0, proto.optional_int64) 536 self.assertEqual(0, proto.optional_uint32) 537 self.assertEqual(0, proto.optional_uint64) 538 self.assertEqual(0, proto.optional_sint32) 539 self.assertEqual(0, proto.optional_sint64) 540 self.assertEqual(0, proto.optional_fixed32) 541 self.assertEqual(0, proto.optional_fixed64) 542 self.assertEqual(0, proto.optional_sfixed32) 543 self.assertEqual(0, proto.optional_sfixed64) 544 self.assertEqual(0.0, proto.optional_float) 545 self.assertEqual(0.0, proto.optional_double) 546 self.assertEqual(False, proto.optional_bool) 547 self.assertEqual('', proto.optional_string) 548 self.assertEqual(b'', proto.optional_bytes) 549 550 self.assertEqual(41, proto.default_int32) 551 self.assertEqual(42, proto.default_int64) 552 self.assertEqual(43, proto.default_uint32) 553 self.assertEqual(44, proto.default_uint64) 554 self.assertEqual(-45, proto.default_sint32) 555 self.assertEqual(46, proto.default_sint64) 556 self.assertEqual(47, proto.default_fixed32) 557 self.assertEqual(48, proto.default_fixed64) 558 self.assertEqual(49, proto.default_sfixed32) 559 self.assertEqual(-50, proto.default_sfixed64) 560 self.assertEqual(51.5, proto.default_float) 561 self.assertEqual(52e3, proto.default_double) 562 self.assertEqual(True, proto.default_bool) 563 self.assertEqual('hello', proto.default_string) 564 self.assertEqual(b'world', proto.default_bytes) 565 self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) 566 self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) 567 self.assertEqual(unittest_import_pb2.IMPORT_BAR, 568 proto.default_import_enum) 569 570 proto = unittest_pb2.TestExtremeDefaultValues() 571 self.assertEqual(u'\u1234', proto.utf8_string) 572 573 def testHasFieldWithUnknownFieldName(self): 574 proto = unittest_pb2.TestAllTypes() 575 self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') 576 577 def testClearFieldWithUnknownFieldName(self): 578 proto = unittest_pb2.TestAllTypes() 579 self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') 580 581 def testClearRemovesChildren(self): 582 # Make sure there aren't any implementation bugs that are only partially 583 # clearing the message (which can happen in the more complex C++ 584 # implementation which has parallel message lists). 585 proto = unittest_pb2.TestRequiredForeign() 586 for i in range(10): 587 proto.repeated_message.add() 588 proto2 = unittest_pb2.TestRequiredForeign() 589 proto.CopyFrom(proto2) 590 self.assertRaises(IndexError, lambda: proto.repeated_message[5]) 591 592 def testDisallowedAssignments(self): 593 # It's illegal to assign values directly to repeated fields 594 # or to nonrepeated composite fields. Ensure that this fails. 595 proto = unittest_pb2.TestAllTypes() 596 # Repeated fields. 597 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) 598 # Lists shouldn't work, either. 599 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) 600 # Composite fields. 601 self.assertRaises(AttributeError, setattr, proto, 602 'optional_nested_message', 23) 603 # Assignment to a repeated nested message field without specifying 604 # the index in the array of nested messages. 605 self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, 606 'bb', 34) 607 # Assignment to an attribute of a repeated field. 608 self.assertRaises(AttributeError, setattr, proto.repeated_float, 609 'some_attribute', 34) 610 # proto.nonexistent_field = 23 should fail as well. 611 self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) 612 613 def testSingleScalarTypeSafety(self): 614 proto = unittest_pb2.TestAllTypes() 615 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) 616 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') 617 self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) 618 self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) 619 620 def testIntegerTypes(self): 621 def TestGetAndDeserialize(field_name, value, expected_type): 622 proto = unittest_pb2.TestAllTypes() 623 setattr(proto, field_name, value) 624 self.assertIsInstance(getattr(proto, field_name), expected_type) 625 proto2 = unittest_pb2.TestAllTypes() 626 proto2.ParseFromString(proto.SerializeToString()) 627 self.assertIsInstance(getattr(proto2, field_name), expected_type) 628 629 TestGetAndDeserialize('optional_int32', 1, int) 630 TestGetAndDeserialize('optional_int32', 1 << 30, int) 631 TestGetAndDeserialize('optional_uint32', 1 << 30, int) 632 try: 633 integer_64 = long 634 except NameError: # Python3 635 integer_64 = int 636 if struct.calcsize('L') == 4: 637 # Python only has signed ints, so 32-bit python can't fit an uint32 638 # in an int. 639 TestGetAndDeserialize('optional_uint32', 1 << 31, long) 640 else: 641 # 64-bit python can fit uint32 inside an int 642 TestGetAndDeserialize('optional_uint32', 1 << 31, int) 643 TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) 644 TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) 645 TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) 646 TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) 647 648 def testSingleScalarBoundsChecking(self): 649 def TestMinAndMaxIntegers(field_name, expected_min, expected_max): 650 pb = unittest_pb2.TestAllTypes() 651 setattr(pb, field_name, expected_min) 652 self.assertEqual(expected_min, getattr(pb, field_name)) 653 setattr(pb, field_name, expected_max) 654 self.assertEqual(expected_max, getattr(pb, field_name)) 655 self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) 656 self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) 657 658 TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) 659 TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) 660 TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) 661 TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) 662 663 pb = unittest_pb2.TestAllTypes() 664 pb.optional_nested_enum = 1 665 self.assertEqual(1, pb.optional_nested_enum) 666 667 def testRepeatedScalarTypeSafety(self): 668 proto = unittest_pb2.TestAllTypes() 669 self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) 670 self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') 671 self.assertRaises(TypeError, proto.repeated_string, 10) 672 self.assertRaises(TypeError, proto.repeated_bytes, 10) 673 674 proto.repeated_int32.append(10) 675 proto.repeated_int32[0] = 23 676 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) 677 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') 678 679 # Repeated enums tests. 680 #proto.repeated_nested_enum.append(0) 681 682 def testSingleScalarGettersAndSetters(self): 683 proto = unittest_pb2.TestAllTypes() 684 self.assertEqual(0, proto.optional_int32) 685 proto.optional_int32 = 1 686 self.assertEqual(1, proto.optional_int32) 687 688 proto.optional_uint64 = 0xffffffffffff 689 self.assertEqual(0xffffffffffff, proto.optional_uint64) 690 proto.optional_uint64 = 0xffffffffffffffff 691 self.assertEqual(0xffffffffffffffff, proto.optional_uint64) 692 # TODO(robinson): Test all other scalar field types. 693 694 def testSingleScalarClearField(self): 695 proto = unittest_pb2.TestAllTypes() 696 # Should be allowed to clear something that's not there (a no-op). 697 proto.ClearField('optional_int32') 698 proto.optional_int32 = 1 699 self.assertTrue(proto.HasField('optional_int32')) 700 proto.ClearField('optional_int32') 701 self.assertEqual(0, proto.optional_int32) 702 self.assertTrue(not proto.HasField('optional_int32')) 703 # TODO(robinson): Test all other scalar field types. 704 705 def testEnums(self): 706 proto = unittest_pb2.TestAllTypes() 707 self.assertEqual(1, proto.FOO) 708 self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) 709 self.assertEqual(2, proto.BAR) 710 self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) 711 self.assertEqual(3, proto.BAZ) 712 self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) 713 714 def testEnum_Name(self): 715 self.assertEqual('FOREIGN_FOO', 716 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) 717 self.assertEqual('FOREIGN_BAR', 718 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) 719 self.assertEqual('FOREIGN_BAZ', 720 unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) 721 self.assertRaises(ValueError, 722 unittest_pb2.ForeignEnum.Name, 11312) 723 724 proto = unittest_pb2.TestAllTypes() 725 self.assertEqual('FOO', 726 proto.NestedEnum.Name(proto.FOO)) 727 self.assertEqual('FOO', 728 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) 729 self.assertEqual('BAR', 730 proto.NestedEnum.Name(proto.BAR)) 731 self.assertEqual('BAR', 732 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) 733 self.assertEqual('BAZ', 734 proto.NestedEnum.Name(proto.BAZ)) 735 self.assertEqual('BAZ', 736 unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) 737 self.assertRaises(ValueError, 738 proto.NestedEnum.Name, 11312) 739 self.assertRaises(ValueError, 740 unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) 741 742 def testEnum_Value(self): 743 self.assertEqual(unittest_pb2.FOREIGN_FOO, 744 unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) 745 self.assertEqual(unittest_pb2.FOREIGN_BAR, 746 unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) 747 self.assertEqual(unittest_pb2.FOREIGN_BAZ, 748 unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) 749 self.assertRaises(ValueError, 750 unittest_pb2.ForeignEnum.Value, 'FO') 751 752 proto = unittest_pb2.TestAllTypes() 753 self.assertEqual(proto.FOO, 754 proto.NestedEnum.Value('FOO')) 755 self.assertEqual(proto.FOO, 756 unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) 757 self.assertEqual(proto.BAR, 758 proto.NestedEnum.Value('BAR')) 759 self.assertEqual(proto.BAR, 760 unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) 761 self.assertEqual(proto.BAZ, 762 proto.NestedEnum.Value('BAZ')) 763 self.assertEqual(proto.BAZ, 764 unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) 765 self.assertRaises(ValueError, 766 proto.NestedEnum.Value, 'Foo') 767 self.assertRaises(ValueError, 768 unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') 769 770 def testEnum_KeysAndValues(self): 771 self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], 772 list(unittest_pb2.ForeignEnum.keys())) 773 self.assertEqual([4, 5, 6], 774 list(unittest_pb2.ForeignEnum.values())) 775 self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), 776 ('FOREIGN_BAZ', 6)], 777 list(unittest_pb2.ForeignEnum.items())) 778 779 proto = unittest_pb2.TestAllTypes() 780 self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], list(proto.NestedEnum.keys())) 781 self.assertEqual([1, 2, 3, -1], list(proto.NestedEnum.values())) 782 self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], 783 list(proto.NestedEnum.items())) 784 785 def testRepeatedScalars(self): 786 proto = unittest_pb2.TestAllTypes() 787 788 self.assertTrue(not proto.repeated_int32) 789 self.assertEqual(0, len(proto.repeated_int32)) 790 proto.repeated_int32.append(5) 791 proto.repeated_int32.append(10) 792 proto.repeated_int32.append(15) 793 self.assertTrue(proto.repeated_int32) 794 self.assertEqual(3, len(proto.repeated_int32)) 795 796 self.assertEqual([5, 10, 15], proto.repeated_int32) 797 798 # Test single retrieval. 799 self.assertEqual(5, proto.repeated_int32[0]) 800 self.assertEqual(15, proto.repeated_int32[-1]) 801 # Test out-of-bounds indices. 802 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 803 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 804 # Test incorrect types passed to __getitem__. 805 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 806 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 807 808 # Test single assignment. 809 proto.repeated_int32[1] = 20 810 self.assertEqual([5, 20, 15], proto.repeated_int32) 811 812 # Test insertion. 813 proto.repeated_int32.insert(1, 25) 814 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 815 816 # Test slice retrieval. 817 proto.repeated_int32.append(30) 818 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 819 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 820 821 # Test slice assignment with an iterator 822 proto.repeated_int32[1:4] = (i for i in range(3)) 823 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 824 825 # Test slice assignment. 826 proto.repeated_int32[1:4] = [35, 40, 45] 827 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 828 829 # Test that we can use the field as an iterator. 830 result = [] 831 for i in proto.repeated_int32: 832 result.append(i) 833 self.assertEqual([5, 35, 40, 45, 30], result) 834 835 # Test single deletion. 836 del proto.repeated_int32[2] 837 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 838 839 # Test slice deletion. 840 del proto.repeated_int32[2:] 841 self.assertEqual([5, 35], proto.repeated_int32) 842 843 # Test extending. 844 proto.repeated_int32.extend([3, 13]) 845 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 846 847 # Test clearing. 848 proto.ClearField('repeated_int32') 849 self.assertTrue(not proto.repeated_int32) 850 self.assertEqual(0, len(proto.repeated_int32)) 851 852 proto.repeated_int32.append(1) 853 self.assertEqual(1, proto.repeated_int32[-1]) 854 # Test assignment to a negative index. 855 proto.repeated_int32[-1] = 2 856 self.assertEqual(2, proto.repeated_int32[-1]) 857 858 # Test deletion at negative indices. 859 proto.repeated_int32[:] = [0, 1, 2, 3] 860 del proto.repeated_int32[-1] 861 self.assertEqual([0, 1, 2], proto.repeated_int32) 862 863 del proto.repeated_int32[-2] 864 self.assertEqual([0, 2], proto.repeated_int32) 865 866 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 867 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 868 869 del proto.repeated_int32[-2:-1] 870 self.assertEqual([2], proto.repeated_int32) 871 872 del proto.repeated_int32[100:10000] 873 self.assertEqual([2], proto.repeated_int32) 874 875 def testRepeatedScalarsRemove(self): 876 proto = unittest_pb2.TestAllTypes() 877 878 self.assertTrue(not proto.repeated_int32) 879 self.assertEqual(0, len(proto.repeated_int32)) 880 proto.repeated_int32.append(5) 881 proto.repeated_int32.append(10) 882 proto.repeated_int32.append(5) 883 proto.repeated_int32.append(5) 884 885 self.assertEqual(4, len(proto.repeated_int32)) 886 proto.repeated_int32.remove(5) 887 self.assertEqual(3, len(proto.repeated_int32)) 888 self.assertEqual(10, proto.repeated_int32[0]) 889 self.assertEqual(5, proto.repeated_int32[1]) 890 self.assertEqual(5, proto.repeated_int32[2]) 891 892 proto.repeated_int32.remove(5) 893 self.assertEqual(2, len(proto.repeated_int32)) 894 self.assertEqual(10, proto.repeated_int32[0]) 895 self.assertEqual(5, proto.repeated_int32[1]) 896 897 proto.repeated_int32.remove(10) 898 self.assertEqual(1, len(proto.repeated_int32)) 899 self.assertEqual(5, proto.repeated_int32[0]) 900 901 # Remove a non-existent element. 902 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 903 904 def testRepeatedComposites(self): 905 proto = unittest_pb2.TestAllTypes() 906 self.assertTrue(not proto.repeated_nested_message) 907 self.assertEqual(0, len(proto.repeated_nested_message)) 908 m0 = proto.repeated_nested_message.add() 909 m1 = proto.repeated_nested_message.add() 910 self.assertTrue(proto.repeated_nested_message) 911 self.assertEqual(2, len(proto.repeated_nested_message)) 912 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 913 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 914 915 # Test out-of-bounds indices. 916 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 917 1234) 918 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 919 -1234) 920 921 # Test incorrect types passed to __getitem__. 922 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 923 'foo') 924 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 925 None) 926 927 # Test slice retrieval. 928 m2 = proto.repeated_nested_message.add() 929 m3 = proto.repeated_nested_message.add() 930 m4 = proto.repeated_nested_message.add() 931 self.assertListsEqual( 932 [m1, m2, m3], proto.repeated_nested_message[1:4]) 933 self.assertListsEqual( 934 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 935 self.assertListsEqual( 936 [m0, m1], proto.repeated_nested_message[:2]) 937 self.assertListsEqual( 938 [m2, m3, m4], proto.repeated_nested_message[2:]) 939 self.assertEqual( 940 m0, proto.repeated_nested_message[0]) 941 self.assertListsEqual( 942 [m0], proto.repeated_nested_message[:1]) 943 944 # Test that we can use the field as an iterator. 945 result = [] 946 for i in proto.repeated_nested_message: 947 result.append(i) 948 self.assertListsEqual([m0, m1, m2, m3, m4], result) 949 950 # Test single deletion. 951 del proto.repeated_nested_message[2] 952 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 953 954 # Test slice deletion. 955 del proto.repeated_nested_message[2:] 956 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 957 958 # Test extending. 959 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 960 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 961 proto.repeated_nested_message.extend([n1,n2]) 962 self.assertEqual(4, len(proto.repeated_nested_message)) 963 self.assertEqual(n1, proto.repeated_nested_message[2]) 964 self.assertEqual(n2, proto.repeated_nested_message[3]) 965 966 # Test clearing. 967 proto.ClearField('repeated_nested_message') 968 self.assertTrue(not proto.repeated_nested_message) 969 self.assertEqual(0, len(proto.repeated_nested_message)) 970 971 # Test constructing an element while adding it. 972 proto.repeated_nested_message.add(bb=23) 973 self.assertEqual(1, len(proto.repeated_nested_message)) 974 self.assertEqual(23, proto.repeated_nested_message[0].bb) 975 976 def testRepeatedCompositeRemove(self): 977 proto = unittest_pb2.TestAllTypes() 978 979 self.assertEqual(0, len(proto.repeated_nested_message)) 980 m0 = proto.repeated_nested_message.add() 981 # Need to set some differentiating variable so m0 != m1 != m2: 982 m0.bb = len(proto.repeated_nested_message) 983 m1 = proto.repeated_nested_message.add() 984 m1.bb = len(proto.repeated_nested_message) 985 self.assertTrue(m0 != m1) 986 m2 = proto.repeated_nested_message.add() 987 m2.bb = len(proto.repeated_nested_message) 988 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 989 990 self.assertEqual(3, len(proto.repeated_nested_message)) 991 proto.repeated_nested_message.remove(m0) 992 self.assertEqual(2, len(proto.repeated_nested_message)) 993 self.assertEqual(m1, proto.repeated_nested_message[0]) 994 self.assertEqual(m2, proto.repeated_nested_message[1]) 995 996 # Removing m0 again or removing None should raise error 997 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 998 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 999 self.assertEqual(2, len(proto.repeated_nested_message)) 1000 1001 proto.repeated_nested_message.remove(m2) 1002 self.assertEqual(1, len(proto.repeated_nested_message)) 1003 self.assertEqual(m1, proto.repeated_nested_message[0]) 1004 1005 def testHandWrittenReflection(self): 1006 # Hand written extensions are only supported by the pure-Python 1007 # implementation of the API. 1008 if api_implementation.Type() != 'python': 1009 return 1010 1011 FieldDescriptor = descriptor.FieldDescriptor 1012 foo_field_descriptor = FieldDescriptor( 1013 name='foo_field', full_name='MyProto.foo_field', 1014 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1015 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1016 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1017 containing_type=None, message_type=None, enum_type=None, 1018 is_extension=False, extension_scope=None, 1019 options=descriptor_pb2.FieldOptions()) 1020 mydescriptor = descriptor.Descriptor( 1021 name='MyProto', full_name='MyProto', filename='ignored', 1022 containing_type=None, nested_types=[], enum_types=[], 1023 fields=[foo_field_descriptor], extensions=[], 1024 options=descriptor_pb2.MessageOptions()) 1025 class MyProtoClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1026 DESCRIPTOR = mydescriptor 1027 myproto_instance = MyProtoClass() 1028 self.assertEqual(0, myproto_instance.foo_field) 1029 self.assertTrue(not myproto_instance.HasField('foo_field')) 1030 myproto_instance.foo_field = 23 1031 self.assertEqual(23, myproto_instance.foo_field) 1032 self.assertTrue(myproto_instance.HasField('foo_field')) 1033 1034 def testDescriptorProtoSupport(self): 1035 # Hand written descriptors/reflection are only supported by the pure-Python 1036 # implementation of the API. 1037 if api_implementation.Type() != 'python': 1038 return 1039 1040 def AddDescriptorField(proto, field_name, field_type): 1041 AddDescriptorField.field_index += 1 1042 new_field = proto.field.add() 1043 new_field.name = field_name 1044 new_field.type = field_type 1045 new_field.number = AddDescriptorField.field_index 1046 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1047 1048 AddDescriptorField.field_index = 0 1049 1050 desc_proto = descriptor_pb2.DescriptorProto() 1051 desc_proto.name = 'Car' 1052 fdp = descriptor_pb2.FieldDescriptorProto 1053 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1054 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1055 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1056 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1057 # Add a repeated field 1058 AddDescriptorField.field_index += 1 1059 new_field = desc_proto.field.add() 1060 new_field.name = 'owners' 1061 new_field.type = fdp.TYPE_STRING 1062 new_field.number = AddDescriptorField.field_index 1063 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1064 1065 desc = descriptor.MakeDescriptor(desc_proto) 1066 self.assertTrue('name' in desc.fields_by_name) 1067 self.assertTrue('year' in desc.fields_by_name) 1068 self.assertTrue('automatic' in desc.fields_by_name) 1069 self.assertTrue('price' in desc.fields_by_name) 1070 self.assertTrue('owners' in desc.fields_by_name) 1071 1072 class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 1073 DESCRIPTOR = desc 1074 1075 prius = CarMessage() 1076 prius.name = 'prius' 1077 prius.year = 2010 1078 prius.automatic = True 1079 prius.price = 25134.75 1080 prius.owners.extend(['bob', 'susan']) 1081 1082 serialized_prius = prius.SerializeToString() 1083 new_prius = reflection.ParseMessage(desc, serialized_prius) 1084 self.assertTrue(new_prius is not prius) 1085 self.assertEqual(prius, new_prius) 1086 1087 # these are unnecessary assuming message equality works as advertised but 1088 # explicitly check to be safe since we're mucking about in metaclass foo 1089 self.assertEqual(prius.name, new_prius.name) 1090 self.assertEqual(prius.year, new_prius.year) 1091 self.assertEqual(prius.automatic, new_prius.automatic) 1092 self.assertEqual(prius.price, new_prius.price) 1093 self.assertEqual(prius.owners, new_prius.owners) 1094 1095 def testTopLevelExtensionsForOptionalScalar(self): 1096 extendee_proto = unittest_pb2.TestAllExtensions() 1097 extension = unittest_pb2.optional_int32_extension 1098 self.assertTrue(not extendee_proto.HasExtension(extension)) 1099 self.assertEqual(0, extendee_proto.Extensions[extension]) 1100 # As with normal scalar fields, just doing a read doesn't actually set the 1101 # "has" bit. 1102 self.assertTrue(not extendee_proto.HasExtension(extension)) 1103 # Actually set the thing. 1104 extendee_proto.Extensions[extension] = 23 1105 self.assertEqual(23, extendee_proto.Extensions[extension]) 1106 self.assertTrue(extendee_proto.HasExtension(extension)) 1107 # Ensure that clearing works as well. 1108 extendee_proto.ClearExtension(extension) 1109 self.assertEqual(0, extendee_proto.Extensions[extension]) 1110 self.assertTrue(not extendee_proto.HasExtension(extension)) 1111 1112 def testTopLevelExtensionsForRepeatedScalar(self): 1113 extendee_proto = unittest_pb2.TestAllExtensions() 1114 extension = unittest_pb2.repeated_string_extension 1115 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1116 extendee_proto.Extensions[extension].append('foo') 1117 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1118 string_list = extendee_proto.Extensions[extension] 1119 extendee_proto.ClearExtension(extension) 1120 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1121 self.assertTrue(string_list is not extendee_proto.Extensions[extension]) 1122 # Shouldn't be allowed to do Extensions[extension] = 'a' 1123 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1124 extension, 'a') 1125 1126 def testTopLevelExtensionsForOptionalMessage(self): 1127 extendee_proto = unittest_pb2.TestAllExtensions() 1128 extension = unittest_pb2.optional_foreign_message_extension 1129 self.assertTrue(not extendee_proto.HasExtension(extension)) 1130 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1131 # As with normal (non-extension) fields, merely reading from the 1132 # thing shouldn't set the "has" bit. 1133 self.assertTrue(not extendee_proto.HasExtension(extension)) 1134 extendee_proto.Extensions[extension].c = 23 1135 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1136 self.assertTrue(extendee_proto.HasExtension(extension)) 1137 # Save a reference here. 1138 foreign_message = extendee_proto.Extensions[extension] 1139 extendee_proto.ClearExtension(extension) 1140 self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) 1141 # Setting a field on foreign_message now shouldn't set 1142 # any "has" bits on extendee_proto. 1143 foreign_message.c = 42 1144 self.assertEqual(42, foreign_message.c) 1145 self.assertTrue(foreign_message.HasField('c')) 1146 self.assertTrue(not extendee_proto.HasExtension(extension)) 1147 # Shouldn't be allowed to do Extensions[extension] = 'a' 1148 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1149 extension, 'a') 1150 1151 def testTopLevelExtensionsForRepeatedMessage(self): 1152 extendee_proto = unittest_pb2.TestAllExtensions() 1153 extension = unittest_pb2.repeatedgroup_extension 1154 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1155 group = extendee_proto.Extensions[extension].add() 1156 group.a = 23 1157 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1158 group.a = 42 1159 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1160 group_list = extendee_proto.Extensions[extension] 1161 extendee_proto.ClearExtension(extension) 1162 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1163 self.assertTrue(group_list is not extendee_proto.Extensions[extension]) 1164 # Shouldn't be allowed to do Extensions[extension] = 'a' 1165 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1166 extension, 'a') 1167 1168 def testNestedExtensions(self): 1169 extendee_proto = unittest_pb2.TestAllExtensions() 1170 extension = unittest_pb2.TestRequired.single 1171 1172 # We just test the non-repeated case. 1173 self.assertTrue(not extendee_proto.HasExtension(extension)) 1174 required = extendee_proto.Extensions[extension] 1175 self.assertEqual(0, required.a) 1176 self.assertTrue(not extendee_proto.HasExtension(extension)) 1177 required.a = 23 1178 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1179 self.assertTrue(extendee_proto.HasExtension(extension)) 1180 extendee_proto.ClearExtension(extension) 1181 self.assertTrue(required is not extendee_proto.Extensions[extension]) 1182 self.assertTrue(not extendee_proto.HasExtension(extension)) 1183 1184 def testRegisteredExtensions(self): 1185 self.assertTrue('protobuf_unittest.optional_int32_extension' in 1186 unittest_pb2.TestAllExtensions._extensions_by_name) 1187 self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) 1188 # Make sure extensions haven't been registered into types that shouldn't 1189 # have any. 1190 self.assertEqual(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) 1191 1192 # If message A directly contains message B, and 1193 # a.HasField('b') is currently False, then mutating any 1194 # extension in B should change a.HasField('b') to True 1195 # (and so on up the object tree). 1196 def testHasBitsForAncestorsOfExtendedMessage(self): 1197 # Optional scalar extension. 1198 toplevel = more_extensions_pb2.TopLevelMessage() 1199 self.assertTrue(not toplevel.HasField('submessage')) 1200 self.assertEqual(0, toplevel.submessage.Extensions[ 1201 more_extensions_pb2.optional_int_extension]) 1202 self.assertTrue(not toplevel.HasField('submessage')) 1203 toplevel.submessage.Extensions[ 1204 more_extensions_pb2.optional_int_extension] = 23 1205 self.assertEqual(23, toplevel.submessage.Extensions[ 1206 more_extensions_pb2.optional_int_extension]) 1207 self.assertTrue(toplevel.HasField('submessage')) 1208 1209 # Repeated scalar extension. 1210 toplevel = more_extensions_pb2.TopLevelMessage() 1211 self.assertTrue(not toplevel.HasField('submessage')) 1212 self.assertEqual([], toplevel.submessage.Extensions[ 1213 more_extensions_pb2.repeated_int_extension]) 1214 self.assertTrue(not toplevel.HasField('submessage')) 1215 toplevel.submessage.Extensions[ 1216 more_extensions_pb2.repeated_int_extension].append(23) 1217 self.assertEqual([23], toplevel.submessage.Extensions[ 1218 more_extensions_pb2.repeated_int_extension]) 1219 self.assertTrue(toplevel.HasField('submessage')) 1220 1221 # Optional message extension. 1222 toplevel = more_extensions_pb2.TopLevelMessage() 1223 self.assertTrue(not toplevel.HasField('submessage')) 1224 self.assertEqual(0, toplevel.submessage.Extensions[ 1225 more_extensions_pb2.optional_message_extension].foreign_message_int) 1226 self.assertTrue(not toplevel.HasField('submessage')) 1227 toplevel.submessage.Extensions[ 1228 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1229 self.assertEqual(23, toplevel.submessage.Extensions[ 1230 more_extensions_pb2.optional_message_extension].foreign_message_int) 1231 self.assertTrue(toplevel.HasField('submessage')) 1232 1233 # Repeated message extension. 1234 toplevel = more_extensions_pb2.TopLevelMessage() 1235 self.assertTrue(not toplevel.HasField('submessage')) 1236 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1237 more_extensions_pb2.repeated_message_extension])) 1238 self.assertTrue(not toplevel.HasField('submessage')) 1239 foreign = toplevel.submessage.Extensions[ 1240 more_extensions_pb2.repeated_message_extension].add() 1241 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1242 more_extensions_pb2.repeated_message_extension][0]) 1243 self.assertTrue(toplevel.HasField('submessage')) 1244 1245 def testDisconnectionAfterClearingEmptyMessage(self): 1246 toplevel = more_extensions_pb2.TopLevelMessage() 1247 extendee_proto = toplevel.submessage 1248 extension = more_extensions_pb2.optional_message_extension 1249 extension_proto = extendee_proto.Extensions[extension] 1250 extendee_proto.ClearExtension(extension) 1251 extension_proto.foreign_message_int = 23 1252 1253 self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) 1254 1255 def testExtensionFailureModes(self): 1256 extendee_proto = unittest_pb2.TestAllExtensions() 1257 1258 # Try non-extension-handle arguments to HasExtension, 1259 # ClearExtension(), and Extensions[]... 1260 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1261 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1262 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1263 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1264 1265 # Try something that *is* an extension handle, just not for 1266 # this message... 1267 for unknown_handle in (more_extensions_pb2.optional_int_extension, 1268 more_extensions_pb2.optional_message_extension, 1269 more_extensions_pb2.repeated_int_extension, 1270 more_extensions_pb2.repeated_message_extension): 1271 self.assertRaises(KeyError, extendee_proto.HasExtension, 1272 unknown_handle) 1273 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1274 unknown_handle) 1275 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1276 unknown_handle) 1277 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1278 unknown_handle, 5) 1279 1280 # Try call HasExtension() with a valid handle, but for a 1281 # *repeated* field. (Just as with non-extension repeated 1282 # fields, Has*() isn't supported for extension repeated fields). 1283 self.assertRaises(KeyError, extendee_proto.HasExtension, 1284 unittest_pb2.repeated_string_extension) 1285 1286 def testStaticParseFrom(self): 1287 proto1 = unittest_pb2.TestAllTypes() 1288 test_util.SetAllFields(proto1) 1289 1290 string1 = proto1.SerializeToString() 1291 proto2 = unittest_pb2.TestAllTypes.FromString(string1) 1292 1293 # Messages should be equal. 1294 self.assertEqual(proto2, proto1) 1295 1296 def testMergeFromSingularField(self): 1297 # Test merge with just a singular field. 1298 proto1 = unittest_pb2.TestAllTypes() 1299 proto1.optional_int32 = 1 1300 1301 proto2 = unittest_pb2.TestAllTypes() 1302 # This shouldn't get overwritten. 1303 proto2.optional_string = 'value' 1304 1305 proto2.MergeFrom(proto1) 1306 self.assertEqual(1, proto2.optional_int32) 1307 self.assertEqual('value', proto2.optional_string) 1308 1309 def testMergeFromRepeatedField(self): 1310 # Test merge with just a repeated field. 1311 proto1 = unittest_pb2.TestAllTypes() 1312 proto1.repeated_int32.append(1) 1313 proto1.repeated_int32.append(2) 1314 1315 proto2 = unittest_pb2.TestAllTypes() 1316 proto2.repeated_int32.append(0) 1317 proto2.MergeFrom(proto1) 1318 1319 self.assertEqual(0, proto2.repeated_int32[0]) 1320 self.assertEqual(1, proto2.repeated_int32[1]) 1321 self.assertEqual(2, proto2.repeated_int32[2]) 1322 1323 def testMergeFromOptionalGroup(self): 1324 # Test merge with an optional group. 1325 proto1 = unittest_pb2.TestAllTypes() 1326 proto1.optionalgroup.a = 12 1327 proto2 = unittest_pb2.TestAllTypes() 1328 proto2.MergeFrom(proto1) 1329 self.assertEqual(12, proto2.optionalgroup.a) 1330 1331 def testMergeFromRepeatedNestedMessage(self): 1332 # Test merge with a repeated nested message. 1333 proto1 = unittest_pb2.TestAllTypes() 1334 m = proto1.repeated_nested_message.add() 1335 m.bb = 123 1336 m = proto1.repeated_nested_message.add() 1337 m.bb = 321 1338 1339 proto2 = unittest_pb2.TestAllTypes() 1340 m = proto2.repeated_nested_message.add() 1341 m.bb = 999 1342 proto2.MergeFrom(proto1) 1343 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 1344 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 1345 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 1346 1347 proto3 = unittest_pb2.TestAllTypes() 1348 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 1349 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 1350 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 1351 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 1352 1353 def testMergeFromAllFields(self): 1354 # With all fields set. 1355 proto1 = unittest_pb2.TestAllTypes() 1356 test_util.SetAllFields(proto1) 1357 proto2 = unittest_pb2.TestAllTypes() 1358 proto2.MergeFrom(proto1) 1359 1360 # Messages should be equal. 1361 self.assertEqual(proto2, proto1) 1362 1363 # Serialized string should be equal too. 1364 string1 = proto1.SerializeToString() 1365 string2 = proto2.SerializeToString() 1366 self.assertEqual(string1, string2) 1367 1368 def testMergeFromExtensionsSingular(self): 1369 proto1 = unittest_pb2.TestAllExtensions() 1370 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1371 1372 proto2 = unittest_pb2.TestAllExtensions() 1373 proto2.MergeFrom(proto1) 1374 self.assertEqual( 1375 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1376 1377 def testMergeFromExtensionsRepeated(self): 1378 proto1 = unittest_pb2.TestAllExtensions() 1379 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1380 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1381 1382 proto2 = unittest_pb2.TestAllExtensions() 1383 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1384 proto2.MergeFrom(proto1) 1385 self.assertEqual( 1386 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1387 self.assertEqual( 1388 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1389 self.assertEqual( 1390 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1391 self.assertEqual( 1392 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1393 1394 def testMergeFromExtensionsNestedMessage(self): 1395 proto1 = unittest_pb2.TestAllExtensions() 1396 ext1 = proto1.Extensions[ 1397 unittest_pb2.repeated_nested_message_extension] 1398 m = ext1.add() 1399 m.bb = 222 1400 m = ext1.add() 1401 m.bb = 333 1402 1403 proto2 = unittest_pb2.TestAllExtensions() 1404 ext2 = proto2.Extensions[ 1405 unittest_pb2.repeated_nested_message_extension] 1406 m = ext2.add() 1407 m.bb = 111 1408 1409 proto2.MergeFrom(proto1) 1410 ext2 = proto2.Extensions[ 1411 unittest_pb2.repeated_nested_message_extension] 1412 self.assertEqual(3, len(ext2)) 1413 self.assertEqual(111, ext2[0].bb) 1414 self.assertEqual(222, ext2[1].bb) 1415 self.assertEqual(333, ext2[2].bb) 1416 1417 def testMergeFromBug(self): 1418 message1 = unittest_pb2.TestAllTypes() 1419 message2 = unittest_pb2.TestAllTypes() 1420 1421 # Cause optional_nested_message to be instantiated within message1, even 1422 # though it is not considered to be "present". 1423 message1.optional_nested_message 1424 self.assertFalse(message1.HasField('optional_nested_message')) 1425 1426 # Merge into message2. This should not instantiate the field is message2. 1427 message2.MergeFrom(message1) 1428 self.assertFalse(message2.HasField('optional_nested_message')) 1429 1430 def testCopyFromSingularField(self): 1431 # Test copy with just a singular field. 1432 proto1 = unittest_pb2.TestAllTypes() 1433 proto1.optional_int32 = 1 1434 proto1.optional_string = 'important-text' 1435 1436 proto2 = unittest_pb2.TestAllTypes() 1437 proto2.optional_string = 'value' 1438 1439 proto2.CopyFrom(proto1) 1440 self.assertEqual(1, proto2.optional_int32) 1441 self.assertEqual('important-text', proto2.optional_string) 1442 1443 def testCopyFromRepeatedField(self): 1444 # Test copy with a repeated field. 1445 proto1 = unittest_pb2.TestAllTypes() 1446 proto1.repeated_int32.append(1) 1447 proto1.repeated_int32.append(2) 1448 1449 proto2 = unittest_pb2.TestAllTypes() 1450 proto2.repeated_int32.append(0) 1451 proto2.CopyFrom(proto1) 1452 1453 self.assertEqual(1, proto2.repeated_int32[0]) 1454 self.assertEqual(2, proto2.repeated_int32[1]) 1455 1456 def testCopyFromAllFields(self): 1457 # With all fields set. 1458 proto1 = unittest_pb2.TestAllTypes() 1459 test_util.SetAllFields(proto1) 1460 proto2 = unittest_pb2.TestAllTypes() 1461 proto2.CopyFrom(proto1) 1462 1463 # Messages should be equal. 1464 self.assertEqual(proto2, proto1) 1465 1466 # Serialized string should be equal too. 1467 string1 = proto1.SerializeToString() 1468 string2 = proto2.SerializeToString() 1469 self.assertEqual(string1, string2) 1470 1471 def testCopyFromSelf(self): 1472 proto1 = unittest_pb2.TestAllTypes() 1473 proto1.repeated_int32.append(1) 1474 proto1.optional_int32 = 2 1475 proto1.optional_string = 'important-text' 1476 1477 proto1.CopyFrom(proto1) 1478 self.assertEqual(1, proto1.repeated_int32[0]) 1479 self.assertEqual(2, proto1.optional_int32) 1480 self.assertEqual('important-text', proto1.optional_string) 1481 1482 def testCopyFromBadType(self): 1483 # The python implementation doesn't raise an exception in this 1484 # case. In theory it should. 1485 if api_implementation.Type() == 'python': 1486 return 1487 proto1 = unittest_pb2.TestAllTypes() 1488 proto2 = unittest_pb2.TestAllExtensions() 1489 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1490 1491 def testDeepCopy(self): 1492 proto1 = unittest_pb2.TestAllTypes() 1493 proto1.optional_int32 = 1 1494 proto2 = copy.deepcopy(proto1) 1495 self.assertEqual(1, proto2.optional_int32) 1496 1497 proto1.repeated_int32.append(2) 1498 proto1.repeated_int32.append(3) 1499 container = copy.deepcopy(proto1.repeated_int32) 1500 self.assertEqual([2, 3], container) 1501 1502 # TODO(anuraag): Implement deepcopy for repeated composite / extension dict 1503 1504 def testClear(self): 1505 proto = unittest_pb2.TestAllTypes() 1506 # C++ implementation does not support lazy fields right now so leave it 1507 # out for now. 1508 if api_implementation.Type() == 'python': 1509 test_util.SetAllFields(proto) 1510 else: 1511 test_util.SetAllNonLazyFields(proto) 1512 # Clear the message. 1513 proto.Clear() 1514 self.assertEqual(proto.ByteSize(), 0) 1515 empty_proto = unittest_pb2.TestAllTypes() 1516 self.assertEqual(proto, empty_proto) 1517 1518 # Test if extensions which were set are cleared. 1519 proto = unittest_pb2.TestAllExtensions() 1520 test_util.SetAllExtensions(proto) 1521 # Clear the message. 1522 proto.Clear() 1523 self.assertEqual(proto.ByteSize(), 0) 1524 empty_proto = unittest_pb2.TestAllExtensions() 1525 self.assertEqual(proto, empty_proto) 1526 1527 def testDisconnectingBeforeClear(self): 1528 proto = unittest_pb2.TestAllTypes() 1529 nested = proto.optional_nested_message 1530 proto.Clear() 1531 self.assertTrue(nested is not proto.optional_nested_message) 1532 nested.bb = 23 1533 self.assertTrue(not proto.HasField('optional_nested_message')) 1534 self.assertEqual(0, proto.optional_nested_message.bb) 1535 1536 proto = unittest_pb2.TestAllTypes() 1537 nested = proto.optional_nested_message 1538 nested.bb = 5 1539 foreign = proto.optional_foreign_message 1540 foreign.c = 6 1541 1542 proto.Clear() 1543 self.assertTrue(nested is not proto.optional_nested_message) 1544 self.assertTrue(foreign is not proto.optional_foreign_message) 1545 self.assertEqual(5, nested.bb) 1546 self.assertEqual(6, foreign.c) 1547 nested.bb = 15 1548 foreign.c = 16 1549 self.assertFalse(proto.HasField('optional_nested_message')) 1550 self.assertEqual(0, proto.optional_nested_message.bb) 1551 self.assertFalse(proto.HasField('optional_foreign_message')) 1552 self.assertEqual(0, proto.optional_foreign_message.c) 1553 1554 def testOneOf(self): 1555 proto = unittest_pb2.TestAllTypes() 1556 proto.oneof_uint32 = 10 1557 proto.oneof_nested_message.bb = 11 1558 self.assertEqual(11, proto.oneof_nested_message.bb) 1559 self.assertFalse(proto.HasField('oneof_uint32')) 1560 nested = proto.oneof_nested_message 1561 proto.oneof_string = 'abc' 1562 self.assertEqual('abc', proto.oneof_string) 1563 self.assertEqual(11, nested.bb) 1564 self.assertFalse(proto.HasField('oneof_nested_message')) 1565 1566 def assertInitialized(self, proto): 1567 self.assertTrue(proto.IsInitialized()) 1568 # Neither method should raise an exception. 1569 proto.SerializeToString() 1570 proto.SerializePartialToString() 1571 1572 def assertNotInitialized(self, proto): 1573 self.assertFalse(proto.IsInitialized()) 1574 self.assertRaises(message.EncodeError, proto.SerializeToString) 1575 # "Partial" serialization doesn't care if message is uninitialized. 1576 proto.SerializePartialToString() 1577 1578 def testIsInitialized(self): 1579 # Trivial cases - all optional fields and extensions. 1580 proto = unittest_pb2.TestAllTypes() 1581 self.assertInitialized(proto) 1582 proto = unittest_pb2.TestAllExtensions() 1583 self.assertInitialized(proto) 1584 1585 # The case of uninitialized required fields. 1586 proto = unittest_pb2.TestRequired() 1587 self.assertNotInitialized(proto) 1588 proto.a = proto.b = proto.c = 2 1589 self.assertInitialized(proto) 1590 1591 # The case of uninitialized submessage. 1592 proto = unittest_pb2.TestRequiredForeign() 1593 self.assertInitialized(proto) 1594 proto.optional_message.a = 1 1595 self.assertNotInitialized(proto) 1596 proto.optional_message.b = 0 1597 proto.optional_message.c = 0 1598 self.assertInitialized(proto) 1599 1600 # Uninitialized repeated submessage. 1601 message1 = proto.repeated_message.add() 1602 self.assertNotInitialized(proto) 1603 message1.a = message1.b = message1.c = 0 1604 self.assertInitialized(proto) 1605 1606 # Uninitialized repeated group in an extension. 1607 proto = unittest_pb2.TestAllExtensions() 1608 extension = unittest_pb2.TestRequired.multi 1609 message1 = proto.Extensions[extension].add() 1610 message2 = proto.Extensions[extension].add() 1611 self.assertNotInitialized(proto) 1612 message1.a = 1 1613 message1.b = 1 1614 message1.c = 1 1615 self.assertNotInitialized(proto) 1616 message2.a = 2 1617 message2.b = 2 1618 message2.c = 2 1619 self.assertInitialized(proto) 1620 1621 # Uninitialized nonrepeated message in an extension. 1622 proto = unittest_pb2.TestAllExtensions() 1623 extension = unittest_pb2.TestRequired.single 1624 proto.Extensions[extension].a = 1 1625 self.assertNotInitialized(proto) 1626 proto.Extensions[extension].b = 2 1627 proto.Extensions[extension].c = 3 1628 self.assertInitialized(proto) 1629 1630 # Try passing an errors list. 1631 errors = [] 1632 proto = unittest_pb2.TestRequired() 1633 self.assertFalse(proto.IsInitialized(errors)) 1634 self.assertEqual(errors, ['a', 'b', 'c']) 1635 1636 @unittest.skipIf( 1637 api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 1638 'Errors are only available from the most recent C++ implementation.') 1639 def testFileDescriptorErrors(self): 1640 file_name = 'test_file_descriptor_errors.proto' 1641 package_name = 'test_file_descriptor_errors.proto' 1642 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() 1643 file_descriptor_proto.name = file_name 1644 file_descriptor_proto.package = package_name 1645 m1 = file_descriptor_proto.message_type.add() 1646 m1.name = 'msg1' 1647 # Compiles the proto into the C++ descriptor pool 1648 descriptor.FileDescriptor( 1649 file_name, 1650 package_name, 1651 serialized_pb=file_descriptor_proto.SerializeToString()) 1652 # Add a FileDescriptorProto that has duplicate symbols 1653 another_file_name = 'another_test_file_descriptor_errors.proto' 1654 file_descriptor_proto.name = another_file_name 1655 m2 = file_descriptor_proto.message_type.add() 1656 m2.name = 'msg2' 1657 with self.assertRaises(TypeError) as cm: 1658 descriptor.FileDescriptor( 1659 another_file_name, 1660 package_name, 1661 serialized_pb=file_descriptor_proto.SerializeToString()) 1662 self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % 1663 getattr(cm.expected, '__name__', cm.expected)) 1664 self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) 1665 # Error message will say something about this definition being a 1666 # duplicate, though we don't check the message exactly to avoid a 1667 # dependency on the C++ logging code. 1668 self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) 1669 1670 def testStringUTF8Encoding(self): 1671 proto = unittest_pb2.TestAllTypes() 1672 1673 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 1674 self.assertRaises(TypeError, 1675 setattr, proto, 'optional_bytes', u'unicode object') 1676 1677 # Check that the default value is of python's 'unicode' type. 1678 self.assertEqual(type(proto.optional_string), six.text_type) 1679 1680 proto.optional_string = six.text_type('Testing') 1681 self.assertEqual(proto.optional_string, str('Testing')) 1682 1683 # Assign a value of type 'str' which can be encoded in UTF-8. 1684 proto.optional_string = str('Testing') 1685 self.assertEqual(proto.optional_string, six.text_type('Testing')) 1686 1687 # Try to assign a 'bytes' object which contains non-UTF-8. 1688 self.assertRaises(ValueError, 1689 setattr, proto, 'optional_string', b'a\x80a') 1690 # No exception: Assign already encoded UTF-8 bytes to a string field. 1691 utf8_bytes = u'Тест'.encode('utf-8') 1692 proto.optional_string = utf8_bytes 1693 # No exception: Assign the a non-ascii unicode object. 1694 proto.optional_string = u'Тест' 1695 # No exception thrown (normal str assignment containing ASCII). 1696 proto.optional_string = 'abc' 1697 1698 def testStringUTF8Serialization(self): 1699 proto = message_set_extensions_pb2.TestMessageSet() 1700 extension_message = message_set_extensions_pb2.TestMessageSetExtension2 1701 extension = extension_message.message_set_extension 1702 1703 test_utf8 = u'Тест' 1704 test_utf8_bytes = test_utf8.encode('utf-8') 1705 1706 # 'Test' in another language, using UTF-8 charset. 1707 proto.Extensions[extension].str = test_utf8 1708 1709 # Serialize using the MessageSet wire format (this is specified in the 1710 # .proto file). 1711 serialized = proto.SerializeToString() 1712 1713 # Check byte size. 1714 self.assertEqual(proto.ByteSize(), len(serialized)) 1715 1716 raw = unittest_mset_pb2.RawMessageSet() 1717 bytes_read = raw.MergeFromString(serialized) 1718 self.assertEqual(len(serialized), bytes_read) 1719 1720 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 1721 1722 self.assertEqual(1, len(raw.item)) 1723 # Check that the type_id is the same as the tag ID in the .proto file. 1724 self.assertEqual(raw.item[0].type_id, 98418634) 1725 1726 # Check the actual bytes on the wire. 1727 self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) 1728 bytes_read = message2.MergeFromString(raw.item[0].message) 1729 self.assertEqual(len(raw.item[0].message), bytes_read) 1730 1731 self.assertEqual(type(message2.str), six.text_type) 1732 self.assertEqual(message2.str, test_utf8) 1733 1734 # The pure Python API throws an exception on MergeFromString(), 1735 # if any of the string fields of the message can't be UTF-8 decoded. 1736 # The C++ implementation of the API has no way to check that on 1737 # MergeFromString and thus has no way to throw the exception. 1738 # 1739 # The pure Python API always returns objects of type 'unicode' (UTF-8 1740 # encoded), or 'bytes' (in 7 bit ASCII). 1741 badbytes = raw.item[0].message.replace( 1742 test_utf8_bytes, len(test_utf8_bytes) * b'\xff') 1743 1744 unicode_decode_failed = False 1745 try: 1746 message2.MergeFromString(badbytes) 1747 except UnicodeDecodeError: 1748 unicode_decode_failed = True 1749 string_field = message2.str 1750 self.assertTrue(unicode_decode_failed or type(string_field) is bytes) 1751 1752 def testBytesInTextFormat(self): 1753 proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 1754 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', 1755 six.text_type(proto)) 1756 1757 def testEmptyNestedMessage(self): 1758 proto = unittest_pb2.TestAllTypes() 1759 proto.optional_nested_message.MergeFrom( 1760 unittest_pb2.TestAllTypes.NestedMessage()) 1761 self.assertTrue(proto.HasField('optional_nested_message')) 1762 1763 proto = unittest_pb2.TestAllTypes() 1764 proto.optional_nested_message.CopyFrom( 1765 unittest_pb2.TestAllTypes.NestedMessage()) 1766 self.assertTrue(proto.HasField('optional_nested_message')) 1767 1768 proto = unittest_pb2.TestAllTypes() 1769 bytes_read = proto.optional_nested_message.MergeFromString(b'') 1770 self.assertEqual(0, bytes_read) 1771 self.assertTrue(proto.HasField('optional_nested_message')) 1772 1773 proto = unittest_pb2.TestAllTypes() 1774 proto.optional_nested_message.ParseFromString(b'') 1775 self.assertTrue(proto.HasField('optional_nested_message')) 1776 1777 serialized = proto.SerializeToString() 1778 proto2 = unittest_pb2.TestAllTypes() 1779 self.assertEqual( 1780 len(serialized), 1781 proto2.MergeFromString(serialized)) 1782 self.assertTrue(proto2.HasField('optional_nested_message')) 1783 1784 def testSetInParent(self): 1785 proto = unittest_pb2.TestAllTypes() 1786 self.assertFalse(proto.HasField('optionalgroup')) 1787 proto.optionalgroup.SetInParent() 1788 self.assertTrue(proto.HasField('optionalgroup')) 1789 1790 def testPackageInitializationImport(self): 1791 """Test that we can import nested messages from their __init__.py. 1792 1793 Such setup is not trivial since at the time of processing of __init__.py one 1794 can't refer to its submodules by name in code, so expressions like 1795 google.protobuf.internal.import_test_package.inner_pb2 1796 don't work. They do work in imports, so we have assign an alias at import 1797 and then use that alias in generated code. 1798 """ 1799 # We import here since it's the import that used to fail, and we want 1800 # the failure to have the right context. 1801 # pylint: disable=g-import-not-at-top 1802 from google.protobuf.internal import import_test_package 1803 # pylint: enable=g-import-not-at-top 1804 msg = import_test_package.myproto.Outer() 1805 # Just check the default value. 1806 self.assertEqual(57, msg.inner.value) 1807 1808# Since we had so many tests for protocol buffer equality, we broke these out 1809# into separate TestCase classes. 1810 1811 1812class TestAllTypesEqualityTest(unittest.TestCase): 1813 1814 def setUp(self): 1815 self.first_proto = unittest_pb2.TestAllTypes() 1816 self.second_proto = unittest_pb2.TestAllTypes() 1817 1818 def testNotHashable(self): 1819 self.assertRaises(TypeError, hash, self.first_proto) 1820 1821 def testSelfEquality(self): 1822 self.assertEqual(self.first_proto, self.first_proto) 1823 1824 def testEmptyProtosEqual(self): 1825 self.assertEqual(self.first_proto, self.second_proto) 1826 1827 1828class FullProtosEqualityTest(unittest.TestCase): 1829 1830 """Equality tests using completely-full protos as a starting point.""" 1831 1832 def setUp(self): 1833 self.first_proto = unittest_pb2.TestAllTypes() 1834 self.second_proto = unittest_pb2.TestAllTypes() 1835 test_util.SetAllFields(self.first_proto) 1836 test_util.SetAllFields(self.second_proto) 1837 1838 def testNotHashable(self): 1839 self.assertRaises(TypeError, hash, self.first_proto) 1840 1841 def testNoneNotEqual(self): 1842 self.assertNotEqual(self.first_proto, None) 1843 self.assertNotEqual(None, self.second_proto) 1844 1845 def testNotEqualToOtherMessage(self): 1846 third_proto = unittest_pb2.TestRequired() 1847 self.assertNotEqual(self.first_proto, third_proto) 1848 self.assertNotEqual(third_proto, self.second_proto) 1849 1850 def testAllFieldsFilledEquality(self): 1851 self.assertEqual(self.first_proto, self.second_proto) 1852 1853 def testNonRepeatedScalar(self): 1854 # Nonrepeated scalar field change should cause inequality. 1855 self.first_proto.optional_int32 += 1 1856 self.assertNotEqual(self.first_proto, self.second_proto) 1857 # ...as should clearing a field. 1858 self.first_proto.ClearField('optional_int32') 1859 self.assertNotEqual(self.first_proto, self.second_proto) 1860 1861 def testNonRepeatedComposite(self): 1862 # Change a nonrepeated composite field. 1863 self.first_proto.optional_nested_message.bb += 1 1864 self.assertNotEqual(self.first_proto, self.second_proto) 1865 self.first_proto.optional_nested_message.bb -= 1 1866 self.assertEqual(self.first_proto, self.second_proto) 1867 # Clear a field in the nested message. 1868 self.first_proto.optional_nested_message.ClearField('bb') 1869 self.assertNotEqual(self.first_proto, self.second_proto) 1870 self.first_proto.optional_nested_message.bb = ( 1871 self.second_proto.optional_nested_message.bb) 1872 self.assertEqual(self.first_proto, self.second_proto) 1873 # Remove the nested message entirely. 1874 self.first_proto.ClearField('optional_nested_message') 1875 self.assertNotEqual(self.first_proto, self.second_proto) 1876 1877 def testRepeatedScalar(self): 1878 # Change a repeated scalar field. 1879 self.first_proto.repeated_int32.append(5) 1880 self.assertNotEqual(self.first_proto, self.second_proto) 1881 self.first_proto.ClearField('repeated_int32') 1882 self.assertNotEqual(self.first_proto, self.second_proto) 1883 1884 def testRepeatedComposite(self): 1885 # Change value within a repeated composite field. 1886 self.first_proto.repeated_nested_message[0].bb += 1 1887 self.assertNotEqual(self.first_proto, self.second_proto) 1888 self.first_proto.repeated_nested_message[0].bb -= 1 1889 self.assertEqual(self.first_proto, self.second_proto) 1890 # Add a value to a repeated composite field. 1891 self.first_proto.repeated_nested_message.add() 1892 self.assertNotEqual(self.first_proto, self.second_proto) 1893 self.second_proto.repeated_nested_message.add() 1894 self.assertEqual(self.first_proto, self.second_proto) 1895 1896 def testNonRepeatedScalarHasBits(self): 1897 # Ensure that we test "has" bits as well as value for 1898 # nonrepeated scalar field. 1899 self.first_proto.ClearField('optional_int32') 1900 self.second_proto.optional_int32 = 0 1901 self.assertNotEqual(self.first_proto, self.second_proto) 1902 1903 def testNonRepeatedCompositeHasBits(self): 1904 # Ensure that we test "has" bits as well as value for 1905 # nonrepeated composite field. 1906 self.first_proto.ClearField('optional_nested_message') 1907 self.second_proto.optional_nested_message.ClearField('bb') 1908 self.assertNotEqual(self.first_proto, self.second_proto) 1909 self.first_proto.optional_nested_message.bb = 0 1910 self.first_proto.optional_nested_message.ClearField('bb') 1911 self.assertEqual(self.first_proto, self.second_proto) 1912 1913 1914class ExtensionEqualityTest(unittest.TestCase): 1915 1916 def testExtensionEquality(self): 1917 first_proto = unittest_pb2.TestAllExtensions() 1918 second_proto = unittest_pb2.TestAllExtensions() 1919 self.assertEqual(first_proto, second_proto) 1920 test_util.SetAllExtensions(first_proto) 1921 self.assertNotEqual(first_proto, second_proto) 1922 test_util.SetAllExtensions(second_proto) 1923 self.assertEqual(first_proto, second_proto) 1924 1925 # Ensure that we check value equality. 1926 first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1 1927 self.assertNotEqual(first_proto, second_proto) 1928 first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1 1929 self.assertEqual(first_proto, second_proto) 1930 1931 # Ensure that we also look at "has" bits. 1932 first_proto.ClearExtension(unittest_pb2.optional_int32_extension) 1933 second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 1934 self.assertNotEqual(first_proto, second_proto) 1935 first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 1936 self.assertEqual(first_proto, second_proto) 1937 1938 # Ensure that differences in cached values 1939 # don't matter if "has" bits are both false. 1940 first_proto = unittest_pb2.TestAllExtensions() 1941 second_proto = unittest_pb2.TestAllExtensions() 1942 self.assertEqual( 1943 0, first_proto.Extensions[unittest_pb2.optional_int32_extension]) 1944 self.assertEqual(first_proto, second_proto) 1945 1946 1947class MutualRecursionEqualityTest(unittest.TestCase): 1948 1949 def testEqualityWithMutualRecursion(self): 1950 first_proto = unittest_pb2.TestMutualRecursionA() 1951 second_proto = unittest_pb2.TestMutualRecursionA() 1952 self.assertEqual(first_proto, second_proto) 1953 first_proto.bb.a.bb.optional_int32 = 23 1954 self.assertNotEqual(first_proto, second_proto) 1955 second_proto.bb.a.bb.optional_int32 = 23 1956 self.assertEqual(first_proto, second_proto) 1957 1958 1959class ByteSizeTest(unittest.TestCase): 1960 1961 def setUp(self): 1962 self.proto = unittest_pb2.TestAllTypes() 1963 self.extended_proto = more_extensions_pb2.ExtendedMessage() 1964 self.packed_proto = unittest_pb2.TestPackedTypes() 1965 self.packed_extended_proto = unittest_pb2.TestPackedExtensions() 1966 1967 def Size(self): 1968 return self.proto.ByteSize() 1969 1970 def testEmptyMessage(self): 1971 self.assertEqual(0, self.proto.ByteSize()) 1972 1973 def testSizedOnKwargs(self): 1974 # Use a separate message to ensure testing right after creation. 1975 proto = unittest_pb2.TestAllTypes() 1976 self.assertEqual(0, proto.ByteSize()) 1977 proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1) 1978 # One byte for the tag, one to encode varint 1. 1979 self.assertEqual(2, proto_kwargs.ByteSize()) 1980 1981 def testVarints(self): 1982 def Test(i, expected_varint_size): 1983 self.proto.Clear() 1984 self.proto.optional_int64 = i 1985 # Add one to the varint size for the tag info 1986 # for tag 1. 1987 self.assertEqual(expected_varint_size + 1, self.Size()) 1988 Test(0, 1) 1989 Test(1, 1) 1990 for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)): 1991 Test((1 << i) - 1, num_bytes) 1992 Test(-1, 10) 1993 Test(-2, 10) 1994 Test(-(1 << 63), 10) 1995 1996 def testStrings(self): 1997 self.proto.optional_string = '' 1998 # Need one byte for tag info (tag #14), and one byte for length. 1999 self.assertEqual(2, self.Size()) 2000 2001 self.proto.optional_string = 'abc' 2002 # Need one byte for tag info (tag #14), and one byte for length. 2003 self.assertEqual(2 + len(self.proto.optional_string), self.Size()) 2004 2005 self.proto.optional_string = 'x' * 128 2006 # Need one byte for tag info (tag #14), and TWO bytes for length. 2007 self.assertEqual(3 + len(self.proto.optional_string), self.Size()) 2008 2009 def testOtherNumerics(self): 2010 self.proto.optional_fixed32 = 1234 2011 # One byte for tag and 4 bytes for fixed32. 2012 self.assertEqual(5, self.Size()) 2013 self.proto = unittest_pb2.TestAllTypes() 2014 2015 self.proto.optional_fixed64 = 1234 2016 # One byte for tag and 8 bytes for fixed64. 2017 self.assertEqual(9, self.Size()) 2018 self.proto = unittest_pb2.TestAllTypes() 2019 2020 self.proto.optional_float = 1.234 2021 # One byte for tag and 4 bytes for float. 2022 self.assertEqual(5, self.Size()) 2023 self.proto = unittest_pb2.TestAllTypes() 2024 2025 self.proto.optional_double = 1.234 2026 # One byte for tag and 8 bytes for float. 2027 self.assertEqual(9, self.Size()) 2028 self.proto = unittest_pb2.TestAllTypes() 2029 2030 self.proto.optional_sint32 = 64 2031 # One byte for tag and 2 bytes for zig-zag-encoded 64. 2032 self.assertEqual(3, self.Size()) 2033 self.proto = unittest_pb2.TestAllTypes() 2034 2035 def testComposites(self): 2036 # 3 bytes. 2037 self.proto.optional_nested_message.bb = (1 << 14) 2038 # Plus one byte for bb tag. 2039 # Plus 1 byte for optional_nested_message serialized size. 2040 # Plus two bytes for optional_nested_message tag. 2041 self.assertEqual(3 + 1 + 1 + 2, self.Size()) 2042 2043 def testGroups(self): 2044 # 4 bytes. 2045 self.proto.optionalgroup.a = (1 << 21) 2046 # Plus two bytes for |a| tag. 2047 # Plus 2 * two bytes for START_GROUP and END_GROUP tags. 2048 self.assertEqual(4 + 2 + 2*2, self.Size()) 2049 2050 def testRepeatedScalars(self): 2051 self.proto.repeated_int32.append(10) # 1 byte. 2052 self.proto.repeated_int32.append(128) # 2 bytes. 2053 # Also need 2 bytes for each entry for tag. 2054 self.assertEqual(1 + 2 + 2*2, self.Size()) 2055 2056 def testRepeatedScalarsExtend(self): 2057 self.proto.repeated_int32.extend([10, 128]) # 3 bytes. 2058 # Also need 2 bytes for each entry for tag. 2059 self.assertEqual(1 + 2 + 2*2, self.Size()) 2060 2061 def testRepeatedScalarsRemove(self): 2062 self.proto.repeated_int32.append(10) # 1 byte. 2063 self.proto.repeated_int32.append(128) # 2 bytes. 2064 # Also need 2 bytes for each entry for tag. 2065 self.assertEqual(1 + 2 + 2*2, self.Size()) 2066 self.proto.repeated_int32.remove(128) 2067 self.assertEqual(1 + 2, self.Size()) 2068 2069 def testRepeatedComposites(self): 2070 # Empty message. 2 bytes tag plus 1 byte length. 2071 foreign_message_0 = self.proto.repeated_nested_message.add() 2072 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2073 foreign_message_1 = self.proto.repeated_nested_message.add() 2074 foreign_message_1.bb = 7 2075 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2076 2077 def testRepeatedCompositesDelete(self): 2078 # Empty message. 2 bytes tag plus 1 byte length. 2079 foreign_message_0 = self.proto.repeated_nested_message.add() 2080 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2081 foreign_message_1 = self.proto.repeated_nested_message.add() 2082 foreign_message_1.bb = 9 2083 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2084 2085 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2086 del self.proto.repeated_nested_message[0] 2087 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2088 2089 # Now add a new message. 2090 foreign_message_2 = self.proto.repeated_nested_message.add() 2091 foreign_message_2.bb = 12 2092 2093 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2094 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2095 self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size()) 2096 2097 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2098 del self.proto.repeated_nested_message[1] 2099 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2100 2101 del self.proto.repeated_nested_message[0] 2102 self.assertEqual(0, self.Size()) 2103 2104 def testRepeatedGroups(self): 2105 # 2-byte START_GROUP plus 2-byte END_GROUP. 2106 group_0 = self.proto.repeatedgroup.add() 2107 # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| 2108 # plus 2-byte END_GROUP. 2109 group_1 = self.proto.repeatedgroup.add() 2110 group_1.a = 7 2111 self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) 2112 2113 def testExtensions(self): 2114 proto = unittest_pb2.TestAllExtensions() 2115 self.assertEqual(0, proto.ByteSize()) 2116 extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. 2117 proto.Extensions[extension] = 23 2118 # 1 byte for tag, 1 byte for value. 2119 self.assertEqual(2, proto.ByteSize()) 2120 2121 def testCacheInvalidationForNonrepeatedScalar(self): 2122 # Test non-extension. 2123 self.proto.optional_int32 = 1 2124 self.assertEqual(2, self.proto.ByteSize()) 2125 self.proto.optional_int32 = 128 2126 self.assertEqual(3, self.proto.ByteSize()) 2127 self.proto.ClearField('optional_int32') 2128 self.assertEqual(0, self.proto.ByteSize()) 2129 2130 # Test within extension. 2131 extension = more_extensions_pb2.optional_int_extension 2132 self.extended_proto.Extensions[extension] = 1 2133 self.assertEqual(2, self.extended_proto.ByteSize()) 2134 self.extended_proto.Extensions[extension] = 128 2135 self.assertEqual(3, self.extended_proto.ByteSize()) 2136 self.extended_proto.ClearExtension(extension) 2137 self.assertEqual(0, self.extended_proto.ByteSize()) 2138 2139 def testCacheInvalidationForRepeatedScalar(self): 2140 # Test non-extension. 2141 self.proto.repeated_int32.append(1) 2142 self.assertEqual(3, self.proto.ByteSize()) 2143 self.proto.repeated_int32.append(1) 2144 self.assertEqual(6, self.proto.ByteSize()) 2145 self.proto.repeated_int32[1] = 128 2146 self.assertEqual(7, self.proto.ByteSize()) 2147 self.proto.ClearField('repeated_int32') 2148 self.assertEqual(0, self.proto.ByteSize()) 2149 2150 # Test within extension. 2151 extension = more_extensions_pb2.repeated_int_extension 2152 repeated = self.extended_proto.Extensions[extension] 2153 repeated.append(1) 2154 self.assertEqual(2, self.extended_proto.ByteSize()) 2155 repeated.append(1) 2156 self.assertEqual(4, self.extended_proto.ByteSize()) 2157 repeated[1] = 128 2158 self.assertEqual(5, self.extended_proto.ByteSize()) 2159 self.extended_proto.ClearExtension(extension) 2160 self.assertEqual(0, self.extended_proto.ByteSize()) 2161 2162 def testCacheInvalidationForNonrepeatedMessage(self): 2163 # Test non-extension. 2164 self.proto.optional_foreign_message.c = 1 2165 self.assertEqual(5, self.proto.ByteSize()) 2166 self.proto.optional_foreign_message.c = 128 2167 self.assertEqual(6, self.proto.ByteSize()) 2168 self.proto.optional_foreign_message.ClearField('c') 2169 self.assertEqual(3, self.proto.ByteSize()) 2170 self.proto.ClearField('optional_foreign_message') 2171 self.assertEqual(0, self.proto.ByteSize()) 2172 2173 if api_implementation.Type() == 'python': 2174 # This is only possible in pure-Python implementation of the API. 2175 child = self.proto.optional_foreign_message 2176 self.proto.ClearField('optional_foreign_message') 2177 child.c = 128 2178 self.assertEqual(0, self.proto.ByteSize()) 2179 2180 # Test within extension. 2181 extension = more_extensions_pb2.optional_message_extension 2182 child = self.extended_proto.Extensions[extension] 2183 self.assertEqual(0, self.extended_proto.ByteSize()) 2184 child.foreign_message_int = 1 2185 self.assertEqual(4, self.extended_proto.ByteSize()) 2186 child.foreign_message_int = 128 2187 self.assertEqual(5, self.extended_proto.ByteSize()) 2188 self.extended_proto.ClearExtension(extension) 2189 self.assertEqual(0, self.extended_proto.ByteSize()) 2190 2191 def testCacheInvalidationForRepeatedMessage(self): 2192 # Test non-extension. 2193 child0 = self.proto.repeated_foreign_message.add() 2194 self.assertEqual(3, self.proto.ByteSize()) 2195 self.proto.repeated_foreign_message.add() 2196 self.assertEqual(6, self.proto.ByteSize()) 2197 child0.c = 1 2198 self.assertEqual(8, self.proto.ByteSize()) 2199 self.proto.ClearField('repeated_foreign_message') 2200 self.assertEqual(0, self.proto.ByteSize()) 2201 2202 # Test within extension. 2203 extension = more_extensions_pb2.repeated_message_extension 2204 child_list = self.extended_proto.Extensions[extension] 2205 child0 = child_list.add() 2206 self.assertEqual(2, self.extended_proto.ByteSize()) 2207 child_list.add() 2208 self.assertEqual(4, self.extended_proto.ByteSize()) 2209 child0.foreign_message_int = 1 2210 self.assertEqual(6, self.extended_proto.ByteSize()) 2211 child0.ClearField('foreign_message_int') 2212 self.assertEqual(4, self.extended_proto.ByteSize()) 2213 self.extended_proto.ClearExtension(extension) 2214 self.assertEqual(0, self.extended_proto.ByteSize()) 2215 2216 def testPackedRepeatedScalars(self): 2217 self.assertEqual(0, self.packed_proto.ByteSize()) 2218 2219 self.packed_proto.packed_int32.append(10) # 1 byte. 2220 self.packed_proto.packed_int32.append(128) # 2 bytes. 2221 # The tag is 2 bytes (the field number is 90), and the varint 2222 # storing the length is 1 byte. 2223 int_size = 1 + 2 + 3 2224 self.assertEqual(int_size, self.packed_proto.ByteSize()) 2225 2226 self.packed_proto.packed_double.append(4.2) # 8 bytes 2227 self.packed_proto.packed_double.append(3.25) # 8 bytes 2228 # 2 more tag bytes, 1 more length byte. 2229 double_size = 8 + 8 + 3 2230 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 2231 2232 self.packed_proto.ClearField('packed_int32') 2233 self.assertEqual(double_size, self.packed_proto.ByteSize()) 2234 2235 def testPackedExtensions(self): 2236 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 2237 extension = self.packed_extended_proto.Extensions[ 2238 unittest_pb2.packed_fixed32_extension] 2239 extension.extend([1, 2, 3, 4]) # 16 bytes 2240 # Tag is 3 bytes. 2241 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 2242 2243 2244# Issues to be sure to cover include: 2245# * Handling of unrecognized tags ("uninterpreted_bytes"). 2246# * Handling of MessageSets. 2247# * Consistent ordering of tags in the wire format, 2248# including ordering between extensions and non-extension 2249# fields. 2250# * Consistent serialization of negative numbers, especially 2251# negative int32s. 2252# * Handling of empty submessages (with and without "has" 2253# bits set). 2254 2255class SerializationTest(unittest.TestCase): 2256 2257 def testSerializeEmtpyMessage(self): 2258 first_proto = unittest_pb2.TestAllTypes() 2259 second_proto = unittest_pb2.TestAllTypes() 2260 serialized = first_proto.SerializeToString() 2261 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2262 self.assertEqual( 2263 len(serialized), 2264 second_proto.MergeFromString(serialized)) 2265 self.assertEqual(first_proto, second_proto) 2266 2267 def testSerializeAllFields(self): 2268 first_proto = unittest_pb2.TestAllTypes() 2269 second_proto = unittest_pb2.TestAllTypes() 2270 test_util.SetAllFields(first_proto) 2271 serialized = first_proto.SerializeToString() 2272 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2273 self.assertEqual( 2274 len(serialized), 2275 second_proto.MergeFromString(serialized)) 2276 self.assertEqual(first_proto, second_proto) 2277 2278 def testSerializeAllExtensions(self): 2279 first_proto = unittest_pb2.TestAllExtensions() 2280 second_proto = unittest_pb2.TestAllExtensions() 2281 test_util.SetAllExtensions(first_proto) 2282 serialized = first_proto.SerializeToString() 2283 self.assertEqual( 2284 len(serialized), 2285 second_proto.MergeFromString(serialized)) 2286 self.assertEqual(first_proto, second_proto) 2287 2288 def testSerializeWithOptionalGroup(self): 2289 first_proto = unittest_pb2.TestAllTypes() 2290 second_proto = unittest_pb2.TestAllTypes() 2291 first_proto.optionalgroup.a = 242 2292 serialized = first_proto.SerializeToString() 2293 self.assertEqual( 2294 len(serialized), 2295 second_proto.MergeFromString(serialized)) 2296 self.assertEqual(first_proto, second_proto) 2297 2298 def testSerializeNegativeValues(self): 2299 first_proto = unittest_pb2.TestAllTypes() 2300 2301 first_proto.optional_int32 = -1 2302 first_proto.optional_int64 = -(2 << 40) 2303 first_proto.optional_sint32 = -3 2304 first_proto.optional_sint64 = -(4 << 40) 2305 first_proto.optional_sfixed32 = -5 2306 first_proto.optional_sfixed64 = -(6 << 40) 2307 2308 second_proto = unittest_pb2.TestAllTypes.FromString( 2309 first_proto.SerializeToString()) 2310 2311 self.assertEqual(first_proto, second_proto) 2312 2313 def testParseTruncated(self): 2314 # This test is only applicable for the Python implementation of the API. 2315 if api_implementation.Type() != 'python': 2316 return 2317 2318 first_proto = unittest_pb2.TestAllTypes() 2319 test_util.SetAllFields(first_proto) 2320 serialized = first_proto.SerializeToString() 2321 2322 for truncation_point in range(len(serialized) + 1): 2323 try: 2324 second_proto = unittest_pb2.TestAllTypes() 2325 unknown_fields = unittest_pb2.TestEmptyMessage() 2326 pos = second_proto._InternalParse(serialized, 0, truncation_point) 2327 # If we didn't raise an error then we read exactly the amount expected. 2328 self.assertEqual(truncation_point, pos) 2329 2330 # Parsing to unknown fields should not throw if parsing to known fields 2331 # did not. 2332 try: 2333 pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) 2334 self.assertEqual(truncation_point, pos2) 2335 except message.DecodeError: 2336 self.fail('Parsing unknown fields failed when parsing known fields ' 2337 'did not.') 2338 except message.DecodeError: 2339 # Parsing unknown fields should also fail. 2340 self.assertRaises(message.DecodeError, unknown_fields._InternalParse, 2341 serialized, 0, truncation_point) 2342 2343 def testCanonicalSerializationOrder(self): 2344 proto = more_messages_pb2.OutOfOrderFields() 2345 # These are also their tag numbers. Even though we're setting these in 2346 # reverse-tag order AND they're listed in reverse tag-order in the .proto 2347 # file, they should nonetheless be serialized in tag order. 2348 proto.optional_sint32 = 5 2349 proto.Extensions[more_messages_pb2.optional_uint64] = 4 2350 proto.optional_uint32 = 3 2351 proto.Extensions[more_messages_pb2.optional_int64] = 2 2352 proto.optional_int32 = 1 2353 serialized = proto.SerializeToString() 2354 self.assertEqual(proto.ByteSize(), len(serialized)) 2355 d = _MiniDecoder(serialized) 2356 ReadTag = d.ReadFieldNumberAndWireType 2357 self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) 2358 self.assertEqual(1, d.ReadInt32()) 2359 self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) 2360 self.assertEqual(2, d.ReadInt64()) 2361 self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) 2362 self.assertEqual(3, d.ReadUInt32()) 2363 self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) 2364 self.assertEqual(4, d.ReadUInt64()) 2365 self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) 2366 self.assertEqual(5, d.ReadSInt32()) 2367 2368 def testCanonicalSerializationOrderSameAsCpp(self): 2369 # Copy of the same test we use for C++. 2370 proto = unittest_pb2.TestFieldOrderings() 2371 test_util.SetAllFieldsAndExtensions(proto) 2372 serialized = proto.SerializeToString() 2373 test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) 2374 2375 def testMergeFromStringWhenFieldsAlreadySet(self): 2376 first_proto = unittest_pb2.TestAllTypes() 2377 first_proto.repeated_string.append('foobar') 2378 first_proto.optional_int32 = 23 2379 first_proto.optional_nested_message.bb = 42 2380 serialized = first_proto.SerializeToString() 2381 2382 second_proto = unittest_pb2.TestAllTypes() 2383 second_proto.repeated_string.append('baz') 2384 second_proto.optional_int32 = 100 2385 second_proto.optional_nested_message.bb = 999 2386 2387 bytes_parsed = second_proto.MergeFromString(serialized) 2388 self.assertEqual(len(serialized), bytes_parsed) 2389 2390 # Ensure that we append to repeated fields. 2391 self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) 2392 # Ensure that we overwrite nonrepeatd scalars. 2393 self.assertEqual(23, second_proto.optional_int32) 2394 # Ensure that we recursively call MergeFromString() on 2395 # submessages. 2396 self.assertEqual(42, second_proto.optional_nested_message.bb) 2397 2398 def testMessageSetWireFormat(self): 2399 proto = message_set_extensions_pb2.TestMessageSet() 2400 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2401 extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 2402 extension1 = extension_message1.message_set_extension 2403 extension2 = extension_message2.message_set_extension 2404 extension3 = message_set_extensions_pb2.message_set_extension3 2405 proto.Extensions[extension1].i = 123 2406 proto.Extensions[extension2].str = 'foo' 2407 proto.Extensions[extension3].text = 'bar' 2408 2409 # Serialize using the MessageSet wire format (this is specified in the 2410 # .proto file). 2411 serialized = proto.SerializeToString() 2412 2413 raw = unittest_mset_pb2.RawMessageSet() 2414 self.assertEqual(False, 2415 raw.DESCRIPTOR.GetOptions().message_set_wire_format) 2416 self.assertEqual( 2417 len(serialized), 2418 raw.MergeFromString(serialized)) 2419 self.assertEqual(3, len(raw.item)) 2420 2421 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2422 self.assertEqual( 2423 len(raw.item[0].message), 2424 message1.MergeFromString(raw.item[0].message)) 2425 self.assertEqual(123, message1.i) 2426 2427 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 2428 self.assertEqual( 2429 len(raw.item[1].message), 2430 message2.MergeFromString(raw.item[1].message)) 2431 self.assertEqual('foo', message2.str) 2432 2433 message3 = message_set_extensions_pb2.TestMessageSetExtension3() 2434 self.assertEqual( 2435 len(raw.item[2].message), 2436 message3.MergeFromString(raw.item[2].message)) 2437 self.assertEqual('bar', message3.text) 2438 2439 # Deserialize using the MessageSet wire format. 2440 proto2 = message_set_extensions_pb2.TestMessageSet() 2441 self.assertEqual( 2442 len(serialized), 2443 proto2.MergeFromString(serialized)) 2444 self.assertEqual(123, proto2.Extensions[extension1].i) 2445 self.assertEqual('foo', proto2.Extensions[extension2].str) 2446 self.assertEqual('bar', proto2.Extensions[extension3].text) 2447 2448 # Check byte size. 2449 self.assertEqual(proto2.ByteSize(), len(serialized)) 2450 self.assertEqual(proto.ByteSize(), len(serialized)) 2451 2452 def testMessageSetWireFormatUnknownExtension(self): 2453 # Create a message using the message set wire format with an unknown 2454 # message. 2455 raw = unittest_mset_pb2.RawMessageSet() 2456 2457 # Add an item. 2458 item = raw.item.add() 2459 item.type_id = 98418603 2460 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2461 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2462 message1.i = 12345 2463 item.message = message1.SerializeToString() 2464 2465 # Add a second, unknown extension. 2466 item = raw.item.add() 2467 item.type_id = 98418604 2468 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2469 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2470 message1.i = 12346 2471 item.message = message1.SerializeToString() 2472 2473 # Add another unknown extension. 2474 item = raw.item.add() 2475 item.type_id = 98418605 2476 message1 = message_set_extensions_pb2.TestMessageSetExtension2() 2477 message1.str = 'foo' 2478 item.message = message1.SerializeToString() 2479 2480 serialized = raw.SerializeToString() 2481 2482 # Parse message using the message set wire format. 2483 proto = message_set_extensions_pb2.TestMessageSet() 2484 self.assertEqual( 2485 len(serialized), 2486 proto.MergeFromString(serialized)) 2487 2488 # Check that the message parsed well. 2489 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2490 extension1 = extension_message1.message_set_extension 2491 self.assertEqual(12345, proto.Extensions[extension1].i) 2492 2493 def testUnknownFields(self): 2494 proto = unittest_pb2.TestAllTypes() 2495 test_util.SetAllFields(proto) 2496 2497 serialized = proto.SerializeToString() 2498 2499 # The empty message should be parsable with all of the fields 2500 # unknown. 2501 proto2 = unittest_pb2.TestEmptyMessage() 2502 2503 # Parsing this message should succeed. 2504 self.assertEqual( 2505 len(serialized), 2506 proto2.MergeFromString(serialized)) 2507 2508 # Now test with a int64 field set. 2509 proto = unittest_pb2.TestAllTypes() 2510 proto.optional_int64 = 0x0fffffffffffffff 2511 serialized = proto.SerializeToString() 2512 # The empty message should be parsable with all of the fields 2513 # unknown. 2514 proto2 = unittest_pb2.TestEmptyMessage() 2515 # Parsing this message should succeed. 2516 self.assertEqual( 2517 len(serialized), 2518 proto2.MergeFromString(serialized)) 2519 2520 def _CheckRaises(self, exc_class, callable_obj, exception): 2521 """This method checks if the excpetion type and message are as expected.""" 2522 try: 2523 callable_obj() 2524 except exc_class as ex: 2525 # Check if the exception message is the right one. 2526 self.assertEqual(exception, str(ex)) 2527 return 2528 else: 2529 raise self.failureException('%s not raised' % str(exc_class)) 2530 2531 def testSerializeUninitialized(self): 2532 proto = unittest_pb2.TestRequired() 2533 self._CheckRaises( 2534 message.EncodeError, 2535 proto.SerializeToString, 2536 'Message protobuf_unittest.TestRequired is missing required fields: ' 2537 'a,b,c') 2538 # Shouldn't raise exceptions. 2539 partial = proto.SerializePartialToString() 2540 2541 proto2 = unittest_pb2.TestRequired() 2542 self.assertFalse(proto2.HasField('a')) 2543 # proto2 ParseFromString does not check that required fields are set. 2544 proto2.ParseFromString(partial) 2545 self.assertFalse(proto2.HasField('a')) 2546 2547 proto.a = 1 2548 self._CheckRaises( 2549 message.EncodeError, 2550 proto.SerializeToString, 2551 'Message protobuf_unittest.TestRequired is missing required fields: b,c') 2552 # Shouldn't raise exceptions. 2553 partial = proto.SerializePartialToString() 2554 2555 proto.b = 2 2556 self._CheckRaises( 2557 message.EncodeError, 2558 proto.SerializeToString, 2559 'Message protobuf_unittest.TestRequired is missing required fields: c') 2560 # Shouldn't raise exceptions. 2561 partial = proto.SerializePartialToString() 2562 2563 proto.c = 3 2564 serialized = proto.SerializeToString() 2565 # Shouldn't raise exceptions. 2566 partial = proto.SerializePartialToString() 2567 2568 proto2 = unittest_pb2.TestRequired() 2569 self.assertEqual( 2570 len(serialized), 2571 proto2.MergeFromString(serialized)) 2572 self.assertEqual(1, proto2.a) 2573 self.assertEqual(2, proto2.b) 2574 self.assertEqual(3, proto2.c) 2575 self.assertEqual( 2576 len(partial), 2577 proto2.MergeFromString(partial)) 2578 self.assertEqual(1, proto2.a) 2579 self.assertEqual(2, proto2.b) 2580 self.assertEqual(3, proto2.c) 2581 2582 def testSerializeUninitializedSubMessage(self): 2583 proto = unittest_pb2.TestRequiredForeign() 2584 2585 # Sub-message doesn't exist yet, so this succeeds. 2586 proto.SerializeToString() 2587 2588 proto.optional_message.a = 1 2589 self._CheckRaises( 2590 message.EncodeError, 2591 proto.SerializeToString, 2592 'Message protobuf_unittest.TestRequiredForeign ' 2593 'is missing required fields: ' 2594 'optional_message.b,optional_message.c') 2595 2596 proto.optional_message.b = 2 2597 proto.optional_message.c = 3 2598 proto.SerializeToString() 2599 2600 proto.repeated_message.add().a = 1 2601 proto.repeated_message.add().b = 2 2602 self._CheckRaises( 2603 message.EncodeError, 2604 proto.SerializeToString, 2605 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 2606 'repeated_message[0].b,repeated_message[0].c,' 2607 'repeated_message[1].a,repeated_message[1].c') 2608 2609 proto.repeated_message[0].b = 2 2610 proto.repeated_message[0].c = 3 2611 proto.repeated_message[1].a = 1 2612 proto.repeated_message[1].c = 3 2613 proto.SerializeToString() 2614 2615 def testSerializeAllPackedFields(self): 2616 first_proto = unittest_pb2.TestPackedTypes() 2617 second_proto = unittest_pb2.TestPackedTypes() 2618 test_util.SetAllPackedFields(first_proto) 2619 serialized = first_proto.SerializeToString() 2620 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2621 bytes_read = second_proto.MergeFromString(serialized) 2622 self.assertEqual(second_proto.ByteSize(), bytes_read) 2623 self.assertEqual(first_proto, second_proto) 2624 2625 def testSerializeAllPackedExtensions(self): 2626 first_proto = unittest_pb2.TestPackedExtensions() 2627 second_proto = unittest_pb2.TestPackedExtensions() 2628 test_util.SetAllPackedExtensions(first_proto) 2629 serialized = first_proto.SerializeToString() 2630 bytes_read = second_proto.MergeFromString(serialized) 2631 self.assertEqual(second_proto.ByteSize(), bytes_read) 2632 self.assertEqual(first_proto, second_proto) 2633 2634 def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): 2635 first_proto = unittest_pb2.TestPackedTypes() 2636 first_proto.packed_int32.extend([1, 2]) 2637 first_proto.packed_double.append(3.0) 2638 serialized = first_proto.SerializeToString() 2639 2640 second_proto = unittest_pb2.TestPackedTypes() 2641 second_proto.packed_int32.append(3) 2642 second_proto.packed_double.extend([1.0, 2.0]) 2643 second_proto.packed_sint32.append(4) 2644 2645 self.assertEqual( 2646 len(serialized), 2647 second_proto.MergeFromString(serialized)) 2648 self.assertEqual([3, 1, 2], second_proto.packed_int32) 2649 self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) 2650 self.assertEqual([4], second_proto.packed_sint32) 2651 2652 def testPackedFieldsWireFormat(self): 2653 proto = unittest_pb2.TestPackedTypes() 2654 proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes 2655 proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes 2656 proto.packed_float.append(2.0) # 4 bytes, will be before double 2657 serialized = proto.SerializeToString() 2658 self.assertEqual(proto.ByteSize(), len(serialized)) 2659 d = _MiniDecoder(serialized) 2660 ReadTag = d.ReadFieldNumberAndWireType 2661 self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2662 self.assertEqual(1+1+1+2, d.ReadInt32()) 2663 self.assertEqual(1, d.ReadInt32()) 2664 self.assertEqual(2, d.ReadInt32()) 2665 self.assertEqual(150, d.ReadInt32()) 2666 self.assertEqual(3, d.ReadInt32()) 2667 self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2668 self.assertEqual(4, d.ReadInt32()) 2669 self.assertEqual(2.0, d.ReadFloat()) 2670 self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 2671 self.assertEqual(8+8, d.ReadInt32()) 2672 self.assertEqual(1.0, d.ReadDouble()) 2673 self.assertEqual(1000.0, d.ReadDouble()) 2674 self.assertTrue(d.EndOfStream()) 2675 2676 def testParsePackedFromUnpacked(self): 2677 unpacked = unittest_pb2.TestUnpackedTypes() 2678 test_util.SetAllUnpackedFields(unpacked) 2679 packed = unittest_pb2.TestPackedTypes() 2680 serialized = unpacked.SerializeToString() 2681 self.assertEqual( 2682 len(serialized), 2683 packed.MergeFromString(serialized)) 2684 expected = unittest_pb2.TestPackedTypes() 2685 test_util.SetAllPackedFields(expected) 2686 self.assertEqual(expected, packed) 2687 2688 def testParseUnpackedFromPacked(self): 2689 packed = unittest_pb2.TestPackedTypes() 2690 test_util.SetAllPackedFields(packed) 2691 unpacked = unittest_pb2.TestUnpackedTypes() 2692 serialized = packed.SerializeToString() 2693 self.assertEqual( 2694 len(serialized), 2695 unpacked.MergeFromString(serialized)) 2696 expected = unittest_pb2.TestUnpackedTypes() 2697 test_util.SetAllUnpackedFields(expected) 2698 self.assertEqual(expected, unpacked) 2699 2700 def testFieldNumbers(self): 2701 proto = unittest_pb2.TestAllTypes() 2702 self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) 2703 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) 2704 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) 2705 self.assertEqual( 2706 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) 2707 self.assertEqual( 2708 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) 2709 self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) 2710 self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) 2711 self.assertEqual( 2712 unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) 2713 self.assertEqual( 2714 unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) 2715 2716 def testExtensionFieldNumbers(self): 2717 self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) 2718 self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) 2719 self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) 2720 self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) 2721 self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) 2722 self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) 2723 self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) 2724 self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) 2725 self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) 2726 self.assertEqual( 2727 unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) 2728 self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) 2729 self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2730 21) 2731 self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) 2732 self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) 2733 self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) 2734 self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) 2735 self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) 2736 self.assertEqual( 2737 unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) 2738 self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) 2739 self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 2740 51) 2741 2742 def testInitKwargs(self): 2743 proto = unittest_pb2.TestAllTypes( 2744 optional_int32=1, 2745 optional_string='foo', 2746 optional_bool=True, 2747 optional_bytes=b'bar', 2748 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 2749 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 2750 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 2751 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 2752 repeated_int32=[1, 2, 3]) 2753 self.assertTrue(proto.IsInitialized()) 2754 self.assertTrue(proto.HasField('optional_int32')) 2755 self.assertTrue(proto.HasField('optional_string')) 2756 self.assertTrue(proto.HasField('optional_bool')) 2757 self.assertTrue(proto.HasField('optional_bytes')) 2758 self.assertTrue(proto.HasField('optional_nested_message')) 2759 self.assertTrue(proto.HasField('optional_foreign_message')) 2760 self.assertTrue(proto.HasField('optional_nested_enum')) 2761 self.assertTrue(proto.HasField('optional_foreign_enum')) 2762 self.assertEqual(1, proto.optional_int32) 2763 self.assertEqual('foo', proto.optional_string) 2764 self.assertEqual(True, proto.optional_bool) 2765 self.assertEqual(b'bar', proto.optional_bytes) 2766 self.assertEqual(1, proto.optional_nested_message.bb) 2767 self.assertEqual(1, proto.optional_foreign_message.c) 2768 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 2769 proto.optional_nested_enum) 2770 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 2771 self.assertEqual([1, 2, 3], proto.repeated_int32) 2772 2773 def testInitArgsUnknownFieldName(self): 2774 def InitalizeEmptyMessageWithExtraKeywordArg(): 2775 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 2776 self._CheckRaises( 2777 ValueError, 2778 InitalizeEmptyMessageWithExtraKeywordArg, 2779 'Protocol message TestEmptyMessage has no "unknown" field.') 2780 2781 def testInitRequiredKwargs(self): 2782 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 2783 self.assertTrue(proto.IsInitialized()) 2784 self.assertTrue(proto.HasField('a')) 2785 self.assertTrue(proto.HasField('b')) 2786 self.assertTrue(proto.HasField('c')) 2787 self.assertTrue(not proto.HasField('dummy2')) 2788 self.assertEqual(1, proto.a) 2789 self.assertEqual(1, proto.b) 2790 self.assertEqual(1, proto.c) 2791 2792 def testInitRequiredForeignKwargs(self): 2793 proto = unittest_pb2.TestRequiredForeign( 2794 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 2795 self.assertTrue(proto.IsInitialized()) 2796 self.assertTrue(proto.HasField('optional_message')) 2797 self.assertTrue(proto.optional_message.IsInitialized()) 2798 self.assertTrue(proto.optional_message.HasField('a')) 2799 self.assertTrue(proto.optional_message.HasField('b')) 2800 self.assertTrue(proto.optional_message.HasField('c')) 2801 self.assertTrue(not proto.optional_message.HasField('dummy2')) 2802 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 2803 proto.optional_message) 2804 self.assertEqual(1, proto.optional_message.a) 2805 self.assertEqual(1, proto.optional_message.b) 2806 self.assertEqual(1, proto.optional_message.c) 2807 2808 def testInitRepeatedKwargs(self): 2809 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 2810 self.assertTrue(proto.IsInitialized()) 2811 self.assertEqual(1, proto.repeated_int32[0]) 2812 self.assertEqual(2, proto.repeated_int32[1]) 2813 self.assertEqual(3, proto.repeated_int32[2]) 2814 2815 2816class OptionsTest(unittest.TestCase): 2817 2818 def testMessageOptions(self): 2819 proto = message_set_extensions_pb2.TestMessageSet() 2820 self.assertEqual(True, 2821 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2822 proto = unittest_pb2.TestAllTypes() 2823 self.assertEqual(False, 2824 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 2825 2826 def testPackedOptions(self): 2827 proto = unittest_pb2.TestAllTypes() 2828 proto.optional_int32 = 1 2829 proto.optional_double = 3.0 2830 for field_descriptor, _ in proto.ListFields(): 2831 self.assertEqual(False, field_descriptor.GetOptions().packed) 2832 2833 proto = unittest_pb2.TestPackedTypes() 2834 proto.packed_int32.append(1) 2835 proto.packed_double.append(3.0) 2836 for field_descriptor, _ in proto.ListFields(): 2837 self.assertEqual(True, field_descriptor.GetOptions().packed) 2838 self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, 2839 field_descriptor.label) 2840 2841 2842 2843class ClassAPITest(unittest.TestCase): 2844 2845 @unittest.skipIf( 2846 api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 2847 'C++ implementation requires a call to MakeDescriptor()') 2848 def testMakeClassWithNestedDescriptor(self): 2849 leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', 2850 containing_type=None, fields=[], 2851 nested_types=[], enum_types=[], 2852 extensions=[]) 2853 child_desc = descriptor.Descriptor('child', 'package.parent.child', '', 2854 containing_type=None, fields=[], 2855 nested_types=[leaf_desc], enum_types=[], 2856 extensions=[]) 2857 sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', 2858 '', containing_type=None, fields=[], 2859 nested_types=[], enum_types=[], 2860 extensions=[]) 2861 parent_desc = descriptor.Descriptor('parent', 'package.parent', '', 2862 containing_type=None, fields=[], 2863 nested_types=[child_desc, sibling_desc], 2864 enum_types=[], extensions=[]) 2865 message_class = reflection.MakeClass(parent_desc) 2866 self.assertIn('child', message_class.__dict__) 2867 self.assertIn('sibling', message_class.__dict__) 2868 self.assertIn('leaf', message_class.child.__dict__) 2869 2870 def _GetSerializedFileDescriptor(self, name): 2871 """Get a serialized representation of a test FileDescriptorProto. 2872 2873 Args: 2874 name: All calls to this must use a unique message name, to avoid 2875 collisions in the cpp descriptor pool. 2876 Returns: 2877 A string containing the serialized form of a test FileDescriptorProto. 2878 """ 2879 file_descriptor_str = ( 2880 'message_type {' 2881 ' name: "' + name + '"' 2882 ' field {' 2883 ' name: "flat"' 2884 ' number: 1' 2885 ' label: LABEL_REPEATED' 2886 ' type: TYPE_UINT32' 2887 ' }' 2888 ' field {' 2889 ' name: "bar"' 2890 ' number: 2' 2891 ' label: LABEL_OPTIONAL' 2892 ' type: TYPE_MESSAGE' 2893 ' type_name: "Bar"' 2894 ' }' 2895 ' nested_type {' 2896 ' name: "Bar"' 2897 ' field {' 2898 ' name: "baz"' 2899 ' number: 3' 2900 ' label: LABEL_OPTIONAL' 2901 ' type: TYPE_MESSAGE' 2902 ' type_name: "Baz"' 2903 ' }' 2904 ' nested_type {' 2905 ' name: "Baz"' 2906 ' enum_type {' 2907 ' name: "deep_enum"' 2908 ' value {' 2909 ' name: "VALUE_A"' 2910 ' number: 0' 2911 ' }' 2912 ' }' 2913 ' field {' 2914 ' name: "deep"' 2915 ' number: 4' 2916 ' label: LABEL_OPTIONAL' 2917 ' type: TYPE_UINT32' 2918 ' }' 2919 ' }' 2920 ' }' 2921 '}') 2922 file_descriptor = descriptor_pb2.FileDescriptorProto() 2923 text_format.Merge(file_descriptor_str, file_descriptor) 2924 return file_descriptor.SerializeToString() 2925 2926 def testParsingFlatClassWithExplicitClassDeclaration(self): 2927 """Test that the generated class can parse a flat message.""" 2928 # TODO(xiaofeng): This test fails with cpp implemetnation in the call 2929 # of six.with_metaclass(). The other two callsites of with_metaclass 2930 # in this file are both excluded from cpp test, so it might be expected 2931 # to fail. Need someone more familiar with the python code to take a 2932 # look at this. 2933 if api_implementation.Type() != 'python': 2934 return 2935 file_descriptor = descriptor_pb2.FileDescriptorProto() 2936 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 2937 msg_descriptor = descriptor.MakeDescriptor( 2938 file_descriptor.message_type[0]) 2939 2940 class MessageClass(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): 2941 DESCRIPTOR = msg_descriptor 2942 msg = MessageClass() 2943 msg_str = ( 2944 'flat: 0 ' 2945 'flat: 1 ' 2946 'flat: 2 ') 2947 text_format.Merge(msg_str, msg) 2948 self.assertEqual(msg.flat, [0, 1, 2]) 2949 2950 def testParsingFlatClass(self): 2951 """Test that the generated class can parse a flat message.""" 2952 file_descriptor = descriptor_pb2.FileDescriptorProto() 2953 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 2954 msg_descriptor = descriptor.MakeDescriptor( 2955 file_descriptor.message_type[0]) 2956 msg_class = reflection.MakeClass(msg_descriptor) 2957 msg = msg_class() 2958 msg_str = ( 2959 'flat: 0 ' 2960 'flat: 1 ' 2961 'flat: 2 ') 2962 text_format.Merge(msg_str, msg) 2963 self.assertEqual(msg.flat, [0, 1, 2]) 2964 2965 def testParsingNestedClass(self): 2966 """Test that the generated class can parse a nested message.""" 2967 file_descriptor = descriptor_pb2.FileDescriptorProto() 2968 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 2969 msg_descriptor = descriptor.MakeDescriptor( 2970 file_descriptor.message_type[0]) 2971 msg_class = reflection.MakeClass(msg_descriptor) 2972 msg = msg_class() 2973 msg_str = ( 2974 'bar {' 2975 ' baz {' 2976 ' deep: 4' 2977 ' }' 2978 '}') 2979 text_format.Merge(msg_str, msg) 2980 self.assertEqual(msg.bar.baz.deep, 4) 2981 2982if __name__ == '__main__': 2983 unittest.main() 2984