1# -*- coding: utf-8 -*- 2# Protocol Buffers - Google's data interchange format 3# Copyright 2008 Google Inc. All rights reserved. 4# 5# Use of this source code is governed by a BSD-style 6# license that can be found in the LICENSE file or at 7# https://developers.google.com/open-source/licenses/bsd 8 9"""Unittest for reflection.py, which also indirectly tests the output of the 10pure-Python protocol compiler. 11""" 12 13import copy 14import gc 15import operator 16import struct 17import sys 18import unittest 19import warnings 20 21from google.protobuf import descriptor 22from google.protobuf import descriptor_pb2 23from google.protobuf import message 24from google.protobuf import message_factory 25from google.protobuf import reflection 26from google.protobuf import text_format 27from google.protobuf.internal import api_implementation 28from google.protobuf.internal import decoder 29from google.protobuf.internal import message_set_extensions_pb2 30from google.protobuf.internal import more_extensions_pb2 31from google.protobuf.internal import more_messages_pb2 32from google.protobuf.internal import test_util 33from google.protobuf.internal import testing_refleaks 34from google.protobuf.internal import wire_format 35from google.protobuf.internal import _parameterized 36from google.protobuf import unittest_import_pb2 37from google.protobuf import unittest_mset_pb2 38from google.protobuf import unittest_pb2 39from google.protobuf import unittest_proto3_arena_pb2 40 41 42warnings.simplefilter('error', DeprecationWarning) 43 44 45class _MiniDecoder(object): 46 """Decodes a stream of values from a string. 47 48 Once upon a time we actually had a class called decoder.Decoder. Then we 49 got rid of it during a redesign that made decoding much, much faster overall. 50 But a couple tests in this file used it to check that the serialized form of 51 a message was correct. 
So, this class implements just the methods that were 52 used by said tests, so that we don't have to rewrite the tests. 53 """ 54 55 def __init__(self, bytes): 56 self._bytes = bytes 57 self._pos = 0 58 59 def ReadVarint(self): 60 result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) 61 return result 62 63 ReadInt32 = ReadVarint 64 ReadInt64 = ReadVarint 65 ReadUInt32 = ReadVarint 66 ReadUInt64 = ReadVarint 67 68 def ReadSInt64(self): 69 return wire_format.ZigZagDecode(self.ReadVarint()) 70 71 ReadSInt32 = ReadSInt64 72 73 def ReadFieldNumberAndWireType(self): 74 return wire_format.UnpackTag(self.ReadVarint()) 75 76 def ReadFloat(self): 77 result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0] 78 self._pos += 4 79 return result 80 81 def ReadDouble(self): 82 result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0] 83 self._pos += 8 84 return result 85 86 def EndOfStream(self): 87 return self._pos == len(self._bytes) 88 89 90@_parameterized.named_parameters( 91 ('_proto2', unittest_pb2), 92 ('_proto3', unittest_proto3_arena_pb2)) 93@testing_refleaks.TestCase 94class ReflectionTest(unittest.TestCase): 95 96 def assertListsEqual(self, values, others): 97 self.assertEqual(len(values), len(others)) 98 for i in range(len(values)): 99 self.assertEqual(values[i], others[i]) 100 101 def testScalarConstructor(self, message_module): 102 # Constructor with only scalar types should succeed. 103 proto = message_module.TestAllTypes( 104 optional_int32=24, 105 optional_double=54.321, 106 optional_string='optional_string', 107 optional_float=None) 108 109 self.assertEqual(24, proto.optional_int32) 110 self.assertEqual(54.321, proto.optional_double) 111 self.assertEqual('optional_string', proto.optional_string) 112 if message_module is unittest_pb2: 113 self.assertFalse(proto.HasField("optional_float")) 114 115 def testRepeatedScalarConstructor(self, message_module): 116 # Constructor with only repeated scalar types should succeed. 
117 proto = message_module.TestAllTypes( 118 repeated_int32=[1, 2, 3, 4], 119 repeated_double=[1.23, 54.321], 120 repeated_bool=[True, False, False], 121 repeated_string=["optional_string"], 122 repeated_float=None) 123 124 self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) 125 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 126 self.assertEqual([True, False, False], list(proto.repeated_bool)) 127 self.assertEqual(["optional_string"], list(proto.repeated_string)) 128 self.assertEqual([], list(proto.repeated_float)) 129 130 def testMixedConstructor(self, message_module): 131 # Constructor with only mixed types should succeed. 132 proto = message_module.TestAllTypes( 133 optional_int32=24, 134 optional_string='optional_string', 135 repeated_double=[1.23, 54.321], 136 repeated_bool=[True, False, False], 137 repeated_nested_message=[ 138 message_module.TestAllTypes.NestedMessage( 139 bb=message_module.TestAllTypes.FOO), 140 message_module.TestAllTypes.NestedMessage( 141 bb=message_module.TestAllTypes.BAR)], 142 repeated_foreign_message=[ 143 message_module.ForeignMessage(c=-43), 144 message_module.ForeignMessage(c=45324), 145 message_module.ForeignMessage(c=12)], 146 optional_nested_message=None) 147 148 self.assertEqual(24, proto.optional_int32) 149 self.assertEqual('optional_string', proto.optional_string) 150 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 151 self.assertEqual([True, False, False], list(proto.repeated_bool)) 152 self.assertEqual( 153 [message_module.TestAllTypes.NestedMessage( 154 bb=message_module.TestAllTypes.FOO), 155 message_module.TestAllTypes.NestedMessage( 156 bb=message_module.TestAllTypes.BAR)], 157 list(proto.repeated_nested_message)) 158 self.assertEqual( 159 [message_module.ForeignMessage(c=-43), 160 message_module.ForeignMessage(c=45324), 161 message_module.ForeignMessage(c=12)], 162 list(proto.repeated_foreign_message)) 163 self.assertFalse(proto.HasField("optional_nested_message")) 164 165 def 
testConstructorTypeError(self, message_module): 166 self.assertRaises( 167 TypeError, message_module.TestAllTypes, optional_int32='foo') 168 self.assertRaises( 169 TypeError, message_module.TestAllTypes, optional_string=1234) 170 self.assertRaises( 171 TypeError, message_module.TestAllTypes, optional_nested_message=1234) 172 self.assertRaises( 173 TypeError, message_module.TestAllTypes, repeated_int32=1234) 174 self.assertRaises( 175 TypeError, message_module.TestAllTypes, repeated_int32=['foo']) 176 self.assertRaises( 177 TypeError, message_module.TestAllTypes, repeated_string=1234) 178 self.assertRaises( 179 TypeError, message_module.TestAllTypes, repeated_string=[1234]) 180 self.assertRaises( 181 TypeError, message_module.TestAllTypes, repeated_nested_message=1234) 182 self.assertRaises( 183 TypeError, message_module.TestAllTypes, repeated_nested_message=[1234]) 184 185 def testConstructorInvalidatesCachedByteSize(self, message_module): 186 message = message_module.TestAllTypes(optional_int32=12) 187 self.assertEqual(2, message.ByteSize()) 188 189 message = message_module.TestAllTypes( 190 optional_nested_message=message_module.TestAllTypes.NestedMessage()) 191 self.assertEqual(3, message.ByteSize()) 192 193 message = message_module.TestAllTypes(repeated_int32=[12]) 194 # TODO: Add this test back for proto3 195 if message_module is unittest_pb2: 196 self.assertEqual(3, message.ByteSize()) 197 198 message = message_module.TestAllTypes( 199 repeated_nested_message=[message_module.TestAllTypes.NestedMessage()]) 200 self.assertEqual(3, message.ByteSize()) 201 202 def testReferencesToNestedMessage(self, message_module): 203 proto = message_module.TestAllTypes() 204 nested = proto.optional_nested_message 205 del proto 206 # A previous version had a bug where this would raise an exception when 207 # hitting a now-dead weak reference. 
208 nested.bb = 23 209 210 def testOneOf(self, message_module): 211 proto = message_module.TestAllTypes() 212 proto.oneof_uint32 = 10 213 proto.oneof_nested_message.bb = 11 214 self.assertEqual(11, proto.oneof_nested_message.bb) 215 self.assertFalse(proto.HasField('oneof_uint32')) 216 nested = proto.oneof_nested_message 217 proto.oneof_string = 'abc' 218 self.assertEqual('abc', proto.oneof_string) 219 self.assertEqual(11, nested.bb) 220 self.assertFalse(proto.HasField('oneof_nested_message')) 221 222 def testGetDefaultMessageAfterDisconnectingDefaultMessage( 223 self, message_module): 224 proto = message_module.TestAllTypes() 225 nested = proto.optional_nested_message 226 proto.ClearField('optional_nested_message') 227 del proto 228 del nested 229 # Force a garbage collect so that the underlying CMessages are freed along 230 # with the Messages they point to. This is to make sure we're not deleting 231 # default message instances. 232 gc.collect() 233 proto = message_module.TestAllTypes() 234 nested = proto.optional_nested_message 235 236 def testDisconnectingNestedMessageAfterSettingField(self, message_module): 237 proto = message_module.TestAllTypes() 238 nested = proto.optional_nested_message 239 nested.bb = 5 240 self.assertTrue(proto.HasField('optional_nested_message')) 241 proto.ClearField('optional_nested_message') # Should disconnect from parent 242 self.assertEqual(5, nested.bb) 243 self.assertEqual(0, proto.optional_nested_message.bb) 244 self.assertIsNot(nested, proto.optional_nested_message) 245 nested.bb = 23 246 self.assertFalse(proto.HasField('optional_nested_message')) 247 self.assertEqual(0, proto.optional_nested_message.bb) 248 249 def testDisconnectingNestedMessageBeforeGettingField(self, message_module): 250 proto = message_module.TestAllTypes() 251 self.assertFalse(proto.HasField('optional_nested_message')) 252 proto.ClearField('optional_nested_message') 253 self.assertFalse(proto.HasField('optional_nested_message')) 254 255 def 
testDisconnectingNestedMessageAfterMerge(self, message_module): 256 # This test exercises the code path that does not use ReleaseMessage(). 257 # The underlying fear is that if we use ReleaseMessage() incorrectly, 258 # we will have memory leaks. It's hard to check that that doesn't happen, 259 # but at least we can exercise that code path to make sure it works. 260 proto1 = message_module.TestAllTypes() 261 proto2 = message_module.TestAllTypes() 262 proto2.optional_nested_message.bb = 5 263 proto1.MergeFrom(proto2) 264 self.assertTrue(proto1.HasField('optional_nested_message')) 265 proto1.ClearField('optional_nested_message') 266 self.assertFalse(proto1.HasField('optional_nested_message')) 267 268 def testDisconnectingLazyNestedMessage(self, message_module): 269 # This test exercises releasing a nested message that is lazy. This test 270 # only exercises real code in the C++ implementation as Python does not 271 # support lazy parsing, but the current C++ implementation results in 272 # memory corruption and a crash. 273 if api_implementation.Type() != 'python': 274 return 275 proto = message_module.TestAllTypes() 276 proto.optional_lazy_message.bb = 5 277 proto.ClearField('optional_lazy_message') 278 del proto 279 gc.collect() 280 281 def testSingularListFields(self, message_module): 282 proto = message_module.TestAllTypes() 283 proto.optional_fixed32 = 1 284 proto.optional_int32 = 5 285 proto.optional_string = 'foo' 286 # Access sub-message but don't set it yet. 
287 nested_message = proto.optional_nested_message 288 self.assertEqual( 289 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 290 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 291 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], 292 proto.ListFields()) 293 294 proto.optional_nested_message.bb = 123 295 self.assertEqual( 296 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 297 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 298 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), 299 (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], 300 nested_message) ], 301 proto.ListFields()) 302 303 def testRepeatedListFields(self, message_module): 304 proto = message_module.TestAllTypes() 305 proto.repeated_fixed32.append(1) 306 proto.repeated_int32.append(5) 307 proto.repeated_int32.append(11) 308 proto.repeated_string.extend(['foo', 'bar']) 309 proto.repeated_string.extend([]) 310 proto.repeated_string.append('baz') 311 proto.repeated_string.extend(str(x) for x in range(2)) 312 proto.optional_int32 = 21 313 proto.repeated_bool # Access but don't set anything; should not be listed. 314 self.assertEqual( 315 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), 316 (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), 317 (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), 318 (proto.DESCRIPTOR.fields_by_name['repeated_string' ], 319 ['foo', 'bar', 'baz', '0', '1']) ], 320 proto.ListFields()) 321 322 def testClearFieldWithUnknownFieldName(self, message_module): 323 proto = message_module.TestAllTypes() 324 self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') 325 self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field') 326 327 def testDisallowedAssignments(self, message_module): 328 # It's illegal to assign values directly to repeated fields 329 # or to nonrepeated composite fields. Ensure that this fails. 
330 proto = message_module.TestAllTypes() 331 # Repeated fields. 332 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) 333 # Lists shouldn't work, either. 334 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) 335 # Composite fields. 336 self.assertRaises(AttributeError, setattr, proto, 337 'optional_nested_message', 23) 338 # Assignment to a repeated nested message field without specifying 339 # the index in the array of nested messages. 340 self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, 341 'bb', 34) 342 # Assignment to an attribute of a repeated field. 343 self.assertRaises(AttributeError, setattr, proto.repeated_float, 344 'some_attribute', 34) 345 # proto.nonexistent_field = 23 should fail as well. 346 self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) 347 348 def testSingleScalarTypeSafety(self, message_module): 349 proto = message_module.TestAllTypes() 350 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) 351 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') 352 self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) 353 self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) 354 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo') 355 self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo') 356 self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo') 357 # TODO: Fix type checking difference for python and c extension 358 if (api_implementation.Type() == 'python' or 359 (sys.version_info.major, sys.version_info.minor) >= (3, 10)): 360 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1) 361 else: 362 proto.optional_bool = 1.1 363 364 def assertIntegerTypes(self, integer_fn, message_module): 365 """Verifies setting of scalar integers. 366 367 Args: 368 integer_fn: A function to wrap the integers that will be assigned. 
369 message_module: unittest_pb2 or unittest_proto3_arena_pb2 370 """ 371 def TestGetAndDeserialize(field_name, value, expected_type): 372 proto = message_module.TestAllTypes() 373 value = integer_fn(value) 374 setattr(proto, field_name, value) 375 self.assertIsInstance(getattr(proto, field_name), expected_type) 376 proto2 = message_module.TestAllTypes() 377 proto2.ParseFromString(proto.SerializeToString()) 378 self.assertIsInstance(getattr(proto2, field_name), expected_type) 379 380 TestGetAndDeserialize('optional_int32', 1, int) 381 TestGetAndDeserialize('optional_int32', 1 << 30, int) 382 TestGetAndDeserialize('optional_uint32', 1 << 30, int) 383 integer_64 = int 384 if struct.calcsize('L') == 4: 385 # Python only has signed ints, so 32-bit python can't fit an uint32 386 # in an int. 387 TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64) 388 else: 389 # 64-bit python can fit uint32 inside an int 390 TestGetAndDeserialize('optional_uint32', 1 << 31, int) 391 TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) 392 TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) 393 TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) 394 TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) 395 396 def testIntegerTypes(self, message_module): 397 self.assertIntegerTypes(lambda x: x, message_module) 398 399 def testNonStandardIntegerTypes(self, message_module): 400 self.assertIntegerTypes(test_util.NonStandardInteger, message_module) 401 402 def testIllegalValuesForIntegers(self, message_module): 403 pb = message_module.TestAllTypes() 404 405 # Strings are illegal, even when the represent an integer. 406 with self.assertRaises(TypeError): 407 pb.optional_uint64 = '2' 408 409 # The exact error should propagate with a poorly written custom integer. 
410 with self.assertRaisesRegex(RuntimeError, 'my_error'): 411 pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error') 412 413 def assetIntegerBoundsChecking(self, integer_fn, message_module): 414 """Verifies bounds checking for scalar integer fields. 415 416 Args: 417 integer_fn: A function to wrap the integers that will be assigned. 418 message_module: unittest_pb2 or unittest_proto3_arena_pb2 419 """ 420 def TestMinAndMaxIntegers(field_name, expected_min, expected_max): 421 pb = message_module.TestAllTypes() 422 expected_min = integer_fn(expected_min) 423 expected_max = integer_fn(expected_max) 424 setattr(pb, field_name, expected_min) 425 self.assertEqual(expected_min, getattr(pb, field_name)) 426 setattr(pb, field_name, expected_max) 427 self.assertEqual(expected_max, getattr(pb, field_name)) 428 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 429 expected_min - 1) 430 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 431 expected_max + 1) 432 433 TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) 434 TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) 435 TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) 436 TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) 437 # A bit of white-box testing since -1 is an int and not a long in C++ and 438 # so goes down a different path. 
439 pb = message_module.TestAllTypes() 440 with self.assertRaises((ValueError, TypeError)): 441 pb.optional_uint64 = integer_fn(-(1 << 63)) 442 443 pb = message_module.TestAllTypes() 444 pb.optional_nested_enum = integer_fn(1) 445 self.assertEqual(1, pb.optional_nested_enum) 446 447 def testSingleScalarBoundsChecking(self, message_module): 448 self.assetIntegerBoundsChecking(lambda x: x, message_module) 449 450 def testNonStandardSingleScalarBoundsChecking(self, message_module): 451 self.assetIntegerBoundsChecking( 452 test_util.NonStandardInteger, message_module) 453 454 def testRepeatedScalarTypeSafety(self, message_module): 455 proto = message_module.TestAllTypes() 456 self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) 457 self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') 458 self.assertRaises(TypeError, proto.repeated_string, 10) 459 self.assertRaises(TypeError, proto.repeated_bytes, 10) 460 461 proto.repeated_int32.append(10) 462 proto.repeated_int32[0] = 23 463 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) 464 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') 465 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, []) 466 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 467 'index', 23) 468 469 proto.repeated_string.append('2') 470 self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10) 471 472 # Repeated enums tests. 
473 # proto.repeated_nested_enum.append(0) 474 475 def testSingleScalarGettersAndSetters(self, message_module): 476 proto = message_module.TestAllTypes() 477 self.assertEqual(0, proto.optional_int32) 478 proto.optional_int32 = 1 479 self.assertEqual(1, proto.optional_int32) 480 481 proto.optional_uint64 = 0xffffffffffff 482 self.assertEqual(0xffffffffffff, proto.optional_uint64) 483 proto.optional_uint64 = 0xffffffffffffffff 484 self.assertEqual(0xffffffffffffffff, proto.optional_uint64) 485 # TODO: Test all other scalar field types. 486 487 def testEnums(self, message_module): 488 proto = message_module.TestAllTypes() 489 self.assertEqual(1, proto.FOO) 490 self.assertEqual(1, message_module.TestAllTypes.FOO) 491 self.assertEqual(2, proto.BAR) 492 self.assertEqual(2, message_module.TestAllTypes.BAR) 493 self.assertEqual(3, proto.BAZ) 494 self.assertEqual(3, message_module.TestAllTypes.BAZ) 495 496 def testEnum_Name(self, message_module): 497 self.assertEqual( 498 'FOREIGN_FOO', 499 message_module.ForeignEnum.Name(message_module.FOREIGN_FOO)) 500 self.assertEqual( 501 'FOREIGN_BAR', 502 message_module.ForeignEnum.Name(message_module.FOREIGN_BAR)) 503 self.assertEqual( 504 'FOREIGN_BAZ', 505 message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ)) 506 self.assertRaises(ValueError, 507 message_module.ForeignEnum.Name, 11312) 508 509 proto = message_module.TestAllTypes() 510 self.assertEqual('FOO', 511 proto.NestedEnum.Name(proto.FOO)) 512 self.assertEqual('FOO', 513 message_module.TestAllTypes.NestedEnum.Name(proto.FOO)) 514 self.assertEqual('BAR', 515 proto.NestedEnum.Name(proto.BAR)) 516 self.assertEqual('BAR', 517 message_module.TestAllTypes.NestedEnum.Name(proto.BAR)) 518 self.assertEqual('BAZ', 519 proto.NestedEnum.Name(proto.BAZ)) 520 self.assertEqual('BAZ', 521 message_module.TestAllTypes.NestedEnum.Name(proto.BAZ)) 522 self.assertRaises(ValueError, 523 proto.NestedEnum.Name, 11312) 524 self.assertRaises(ValueError, 525 
message_module.TestAllTypes.NestedEnum.Name, 11312) 526 527 # Check some coercion cases. 528 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 529 11312.0) 530 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 531 None) 532 self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True)) 533 534 def testEnum_Value(self, message_module): 535 self.assertEqual(message_module.FOREIGN_FOO, 536 message_module.ForeignEnum.Value('FOREIGN_FOO')) 537 self.assertEqual(message_module.FOREIGN_FOO, 538 message_module.ForeignEnum.FOREIGN_FOO) 539 540 self.assertEqual(message_module.FOREIGN_BAR, 541 message_module.ForeignEnum.Value('FOREIGN_BAR')) 542 self.assertEqual(message_module.FOREIGN_BAR, 543 message_module.ForeignEnum.FOREIGN_BAR) 544 545 self.assertEqual(message_module.FOREIGN_BAZ, 546 message_module.ForeignEnum.Value('FOREIGN_BAZ')) 547 self.assertEqual(message_module.FOREIGN_BAZ, 548 message_module.ForeignEnum.FOREIGN_BAZ) 549 550 self.assertRaises(ValueError, 551 message_module.ForeignEnum.Value, 'FO') 552 with self.assertRaises(AttributeError): 553 message_module.ForeignEnum.FO 554 555 proto = message_module.TestAllTypes() 556 self.assertEqual(proto.FOO, 557 proto.NestedEnum.Value('FOO')) 558 self.assertEqual(proto.FOO, 559 proto.NestedEnum.FOO) 560 561 self.assertEqual(proto.FOO, 562 message_module.TestAllTypes.NestedEnum.Value('FOO')) 563 self.assertEqual(proto.FOO, 564 message_module.TestAllTypes.NestedEnum.FOO) 565 566 self.assertEqual(proto.BAR, 567 proto.NestedEnum.Value('BAR')) 568 self.assertEqual(proto.BAR, 569 proto.NestedEnum.BAR) 570 571 self.assertEqual(proto.BAR, 572 message_module.TestAllTypes.NestedEnum.Value('BAR')) 573 self.assertEqual(proto.BAR, 574 message_module.TestAllTypes.NestedEnum.BAR) 575 576 self.assertEqual(proto.BAZ, 577 proto.NestedEnum.Value('BAZ')) 578 self.assertEqual(proto.BAZ, 579 proto.NestedEnum.BAZ) 580 581 self.assertEqual(proto.BAZ, 582 
message_module.TestAllTypes.NestedEnum.Value('BAZ')) 583 self.assertEqual(proto.BAZ, 584 message_module.TestAllTypes.NestedEnum.BAZ) 585 586 self.assertRaises(ValueError, 587 proto.NestedEnum.Value, 'Foo') 588 with self.assertRaises(AttributeError): 589 proto.NestedEnum.Value.Foo 590 591 self.assertRaises(ValueError, 592 message_module.TestAllTypes.NestedEnum.Value, 'Foo') 593 with self.assertRaises(AttributeError): 594 message_module.TestAllTypes.NestedEnum.Value.Foo 595 596 def testEnum_KeysAndValues(self, message_module): 597 if message_module == unittest_pb2: 598 keys = [ 599 'FOREIGN_FOO', 600 'FOREIGN_BAR', 601 'FOREIGN_BAZ', 602 'FOREIGN_BAX', 603 'FOREIGN_LARGE', 604 ] 605 values = [4, 5, 6, 32, 123456] 606 items = [ 607 ('FOREIGN_FOO', 4), 608 ('FOREIGN_BAR', 5), 609 ('FOREIGN_BAZ', 6), 610 ('FOREIGN_BAX', 32), 611 ('FOREIGN_LARGE', 123456), 612 ] 613 else: 614 keys = [ 615 'FOREIGN_ZERO', 616 'FOREIGN_FOO', 617 'FOREIGN_BAR', 618 'FOREIGN_BAZ', 619 'FOREIGN_LARGE', 620 ] 621 values = [0, 4, 5, 6, 123456] 622 items = [ 623 ('FOREIGN_ZERO', 0), 624 ('FOREIGN_FOO', 4), 625 ('FOREIGN_BAR', 5), 626 ('FOREIGN_BAZ', 6), 627 ('FOREIGN_LARGE', 123456), 628 ] 629 self.assertEqual(keys, 630 list(message_module.ForeignEnum.keys())) 631 self.assertEqual(values, 632 list(message_module.ForeignEnum.values())) 633 self.assertEqual(items, 634 list(message_module.ForeignEnum.items())) 635 636 proto = message_module.TestAllTypes() 637 if message_module == unittest_pb2: 638 keys = ['FOO', 'BAR', 'BAZ', 'NEG'] 639 values = [1, 2, 3, -1] 640 items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 641 else: 642 keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG'] 643 values = [0, 1, 2, 3, -1] 644 items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 645 self.assertEqual(keys, list(proto.NestedEnum.keys())) 646 self.assertEqual(values, list(proto.NestedEnum.values())) 647 self.assertEqual(items, 648 list(proto.NestedEnum.items())) 649 650 def testStaticParseFrom(self, 
message_module): 651 proto1 = message_module.TestAllTypes() 652 test_util.SetAllFields(proto1) 653 654 string1 = proto1.SerializeToString() 655 proto2 = message_module.TestAllTypes.FromString(string1) 656 657 # Messages should be equal. 658 self.assertEqual(proto2, proto1) 659 660 def testMergeFromSingularField(self, message_module): 661 # Test merge with just a singular field. 662 proto1 = message_module.TestAllTypes() 663 proto1.optional_int32 = 1 664 665 proto2 = message_module.TestAllTypes() 666 # This shouldn't get overwritten. 667 proto2.optional_string = 'value' 668 669 proto2.MergeFrom(proto1) 670 self.assertEqual(1, proto2.optional_int32) 671 self.assertEqual('value', proto2.optional_string) 672 673 def testMergeFromRepeatedField(self, message_module): 674 # Test merge with just a repeated field. 675 proto1 = message_module.TestAllTypes() 676 proto1.repeated_int32.append(1) 677 proto1.repeated_int32.append(2) 678 679 proto2 = message_module.TestAllTypes() 680 proto2.repeated_int32.append(0) 681 proto2.MergeFrom(proto1) 682 683 self.assertEqual(0, proto2.repeated_int32[0]) 684 self.assertEqual(1, proto2.repeated_int32[1]) 685 self.assertEqual(2, proto2.repeated_int32[2]) 686 687 def testMergeFromRepeatedNestedMessage(self, message_module): 688 # Test merge with a repeated nested message. 
689 proto1 = message_module.TestAllTypes() 690 m = proto1.repeated_nested_message.add() 691 m.bb = 123 692 m = proto1.repeated_nested_message.add() 693 m.bb = 321 694 695 proto2 = message_module.TestAllTypes() 696 m = proto2.repeated_nested_message.add() 697 m.bb = 999 698 proto2.MergeFrom(proto1) 699 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 700 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 701 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 702 703 proto3 = message_module.TestAllTypes() 704 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 705 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 706 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 707 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 708 709 def testMergeFromAllFields(self, message_module): 710 # With all fields set. 711 proto1 = message_module.TestAllTypes() 712 test_util.SetAllFields(proto1) 713 proto2 = message_module.TestAllTypes() 714 proto2.MergeFrom(proto1) 715 716 # Messages should be equal. 717 self.assertEqual(proto2, proto1) 718 719 # Serialized string should be equal too. 720 string1 = proto1.SerializeToString() 721 string2 = proto2.SerializeToString() 722 self.assertEqual(string1, string2) 723 724 def testMergeFromBug(self, message_module): 725 message1 = message_module.TestAllTypes() 726 message2 = message_module.TestAllTypes() 727 728 # Cause optional_nested_message to be instantiated within message1, even 729 # though it is not considered to be "present". 730 message1.optional_nested_message 731 self.assertFalse(message1.HasField('optional_nested_message')) 732 733 # Merge into message2. This should not instantiate the field is message2. 734 message2.MergeFrom(message1) 735 self.assertFalse(message2.HasField('optional_nested_message')) 736 737 def testCopyFromSingularField(self, message_module): 738 # Test copy with just a singular field. 
739 proto1 = message_module.TestAllTypes() 740 proto1.optional_int32 = 1 741 proto1.optional_string = 'important-text' 742 743 proto2 = message_module.TestAllTypes() 744 proto2.optional_string = 'value' 745 746 proto2.CopyFrom(proto1) 747 self.assertEqual(1, proto2.optional_int32) 748 self.assertEqual('important-text', proto2.optional_string) 749 750 def testCopyFromRepeatedField(self, message_module): 751 # Test copy with a repeated field. 752 proto1 = message_module.TestAllTypes() 753 proto1.repeated_int32.append(1) 754 proto1.repeated_int32.append(2) 755 756 proto2 = message_module.TestAllTypes() 757 proto2.repeated_int32.append(0) 758 proto2.CopyFrom(proto1) 759 760 self.assertEqual(1, proto2.repeated_int32[0]) 761 self.assertEqual(2, proto2.repeated_int32[1]) 762 763 def testCopyFromAllFields(self, message_module): 764 # With all fields set. 765 proto1 = message_module.TestAllTypes() 766 test_util.SetAllFields(proto1) 767 proto2 = message_module.TestAllTypes() 768 proto2.CopyFrom(proto1) 769 770 # Messages should be equal. 771 self.assertEqual(proto2, proto1) 772 773 # Serialized string should be equal too. 
774 string1 = proto1.SerializeToString() 775 string2 = proto2.SerializeToString() 776 self.assertEqual(string1, string2) 777 778 def testCopyFromSelf(self, message_module): 779 proto1 = message_module.TestAllTypes() 780 proto1.repeated_int32.append(1) 781 proto1.optional_int32 = 2 782 proto1.optional_string = 'important-text' 783 784 proto1.CopyFrom(proto1) 785 self.assertEqual(1, proto1.repeated_int32[0]) 786 self.assertEqual(2, proto1.optional_int32) 787 self.assertEqual('important-text', proto1.optional_string) 788 789 def testDeepCopy(self, message_module): 790 proto1 = message_module.TestAllTypes() 791 proto1.optional_int32 = 1 792 proto2 = copy.deepcopy(proto1) 793 self.assertEqual(1, proto2.optional_int32) 794 795 proto1.repeated_int32.append(2) 796 proto1.repeated_int32.append(3) 797 container = copy.deepcopy(proto1.repeated_int32) 798 self.assertEqual([2, 3], container) 799 container.remove(container[0]) 800 self.assertEqual([3], container) 801 802 message1 = proto1.repeated_nested_message.add() 803 message1.bb = 1 804 messages = copy.deepcopy(proto1.repeated_nested_message) 805 self.assertEqual(proto1.repeated_nested_message, messages) 806 message1.bb = 2 807 self.assertNotEqual(proto1.repeated_nested_message, messages) 808 messages.remove(messages[0]) 809 self.assertEqual(len(messages), 0) 810 811 def testEmptyDeepCopy(self, message_module): 812 proto1 = message_module.TestAllTypes() 813 nested2 = copy.deepcopy(proto1.optional_nested_message) 814 self.assertEqual(0, nested2.bb) 815 816 # TODO: Implement deepcopy for extension dict 817 818 def testDisconnectingBeforeClear(self, message_module): 819 proto = message_module.TestAllTypes() 820 nested = proto.optional_nested_message 821 proto.Clear() 822 self.assertIsNot(nested, proto.optional_nested_message) 823 nested.bb = 23 824 self.assertFalse(proto.HasField('optional_nested_message')) 825 self.assertEqual(0, proto.optional_nested_message.bb) 826 827 proto = message_module.TestAllTypes() 828 nested = 
proto.optional_nested_message 829 nested.bb = 5 830 foreign = proto.optional_foreign_message 831 foreign.c = 6 832 proto.Clear() 833 self.assertIsNot(nested, proto.optional_nested_message) 834 self.assertIsNot(foreign, proto.optional_foreign_message) 835 self.assertEqual(5, nested.bb) 836 self.assertEqual(6, foreign.c) 837 nested.bb = 15 838 foreign.c = 16 839 self.assertFalse(proto.HasField('optional_nested_message')) 840 self.assertEqual(0, proto.optional_nested_message.bb) 841 self.assertFalse(proto.HasField('optional_foreign_message')) 842 self.assertEqual(0, proto.optional_foreign_message.c) 843 844 def testStringUTF8Encoding(self, message_module): 845 proto = message_module.TestAllTypes() 846 847 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 848 self.assertRaises(TypeError, 849 setattr, proto, 'optional_bytes', u'unicode object') 850 851 # Check that the default value is of python's 'unicode' type. 852 self.assertEqual(type(proto.optional_string), str) 853 854 proto.optional_string = str('Testing') 855 self.assertEqual(proto.optional_string, str('Testing')) 856 857 # Assign a value of type 'str' which can be encoded in UTF-8. 858 proto.optional_string = str('Testing') 859 self.assertEqual(proto.optional_string, str('Testing')) 860 861 # Try to assign a 'bytes' object which contains non-UTF-8. 862 self.assertRaises(ValueError, 863 setattr, proto, 'optional_string', b'a\x80a') 864 # No exception: Assign already encoded UTF-8 bytes to a string field. 865 utf8_bytes = u'Тест'.encode('utf-8') 866 proto.optional_string = utf8_bytes 867 # No exception: Assign the a non-ascii unicode object. 868 proto.optional_string = u'Тест' 869 # No exception thrown (normal str assignment containing ASCII). 
870 proto.optional_string = 'abc' 871 872 def testBytesInTextFormat(self, message_module): 873 proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 874 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', str(proto)) 875 876 def testEmptyNestedMessage(self, message_module): 877 proto = message_module.TestAllTypes() 878 proto.optional_nested_message.MergeFrom( 879 message_module.TestAllTypes.NestedMessage()) 880 self.assertTrue(proto.HasField('optional_nested_message')) 881 882 proto = message_module.TestAllTypes() 883 proto.optional_nested_message.CopyFrom( 884 message_module.TestAllTypes.NestedMessage()) 885 self.assertTrue(proto.HasField('optional_nested_message')) 886 887 proto = message_module.TestAllTypes() 888 bytes_read = proto.optional_nested_message.MergeFromString(b'') 889 self.assertEqual(0, bytes_read) 890 self.assertTrue(proto.HasField('optional_nested_message')) 891 892 proto = message_module.TestAllTypes() 893 proto.optional_nested_message.ParseFromString(b'') 894 self.assertTrue(proto.HasField('optional_nested_message')) 895 896 serialized = proto.SerializeToString() 897 proto2 = message_module.TestAllTypes() 898 self.assertEqual( 899 len(serialized), 900 proto2.MergeFromString(serialized)) 901 self.assertTrue(proto2.HasField('optional_nested_message')) 902 903 904# Class to test proto2-only features (required, extensions, etc.) 905@testing_refleaks.TestCase 906class Proto2ReflectionTest(unittest.TestCase): 907 908 def testRepeatedCompositeConstructor(self): 909 # Constructor with only repeated composite types should succeed. 
910 proto = unittest_pb2.TestAllTypes( 911 repeated_nested_message=[ 912 unittest_pb2.TestAllTypes.NestedMessage( 913 bb=unittest_pb2.TestAllTypes.FOO), 914 unittest_pb2.TestAllTypes.NestedMessage( 915 bb=unittest_pb2.TestAllTypes.BAR)], 916 repeated_foreign_message=[ 917 unittest_pb2.ForeignMessage(c=-43), 918 unittest_pb2.ForeignMessage(c=45324), 919 unittest_pb2.ForeignMessage(c=12)], 920 repeatedgroup=[ 921 unittest_pb2.TestAllTypes.RepeatedGroup(), 922 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 923 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) 924 925 self.assertEqual( 926 [unittest_pb2.TestAllTypes.NestedMessage( 927 bb=unittest_pb2.TestAllTypes.FOO), 928 unittest_pb2.TestAllTypes.NestedMessage( 929 bb=unittest_pb2.TestAllTypes.BAR)], 930 list(proto.repeated_nested_message)) 931 self.assertEqual( 932 [unittest_pb2.ForeignMessage(c=-43), 933 unittest_pb2.ForeignMessage(c=45324), 934 unittest_pb2.ForeignMessage(c=12)], 935 list(proto.repeated_foreign_message)) 936 self.assertEqual( 937 [unittest_pb2.TestAllTypes.RepeatedGroup(), 938 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 939 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], 940 list(proto.repeatedgroup)) 941 942 def assertListsEqual(self, values, others): 943 self.assertEqual(len(values), len(others)) 944 for i in range(len(values)): 945 self.assertEqual(values[i], others[i]) 946 947 def testSimpleHasBits(self): 948 # Test a scalar. 949 proto = unittest_pb2.TestAllTypes() 950 self.assertFalse(proto.HasField('optional_int32')) 951 self.assertEqual(0, proto.optional_int32) 952 # HasField() shouldn't be true if all we've done is 953 # read the default value. 954 self.assertFalse(proto.HasField('optional_int32')) 955 proto.optional_int32 = 1 956 # Setting a value however *should* set the "has" bit. 957 self.assertTrue(proto.HasField('optional_int32')) 958 proto.ClearField('optional_int32') 959 # And clearing that value should unset the "has" bit. 
960 self.assertFalse(proto.HasField('optional_int32')) 961 962 def testHasBitsWithSinglyNestedScalar(self): 963 # Helper used to test foreign messages and groups. 964 # 965 # composite_field_name should be the name of a non-repeated 966 # composite (i.e., foreign or group) field in TestAllTypes, 967 # and scalar_field_name should be the name of an integer-valued 968 # scalar field within that composite. 969 # 970 # I never thought I'd miss C++ macros and templates so much. :( 971 # This helper is semantically just: 972 # 973 # assert proto.composite_field.scalar_field == 0 974 # assert not proto.composite_field.HasField('scalar_field') 975 # assert not proto.HasField('composite_field') 976 # 977 # proto.composite_field.scalar_field = 10 978 # old_composite_field = proto.composite_field 979 # 980 # assert proto.composite_field.scalar_field == 10 981 # assert proto.composite_field.HasField('scalar_field') 982 # assert proto.HasField('composite_field') 983 # 984 # proto.ClearField('composite_field') 985 # 986 # assert not proto.composite_field.HasField('scalar_field') 987 # assert not proto.HasField('composite_field') 988 # assert proto.composite_field.scalar_field == 0 989 # 990 # # Now ensure that ClearField('composite_field') disconnected 991 # # the old field object from the object tree... 992 # assert old_composite_field is not proto.composite_field 993 # old_composite_field.scalar_field = 20 994 # assert not proto.composite_field.HasField('scalar_field') 995 # assert not proto.HasField('composite_field') 996 def TestCompositeHasBits(composite_field_name, scalar_field_name): 997 proto = unittest_pb2.TestAllTypes() 998 # First, check that we can get the scalar value, and see that it's the 999 # default (0), but that proto.HasField('omposite') and 1000 # proto.composite.HasField('scalar') will still return False. 
1001 composite_field = getattr(proto, composite_field_name) 1002 original_scalar_value = getattr(composite_field, scalar_field_name) 1003 self.assertEqual(0, original_scalar_value) 1004 # Assert that the composite object does not "have" the scalar. 1005 self.assertFalse(composite_field.HasField(scalar_field_name)) 1006 # Assert that proto does not "have" the composite field. 1007 self.assertFalse(proto.HasField(composite_field_name)) 1008 1009 # Now set the scalar within the composite field. Ensure that the setting 1010 # is reflected, and that proto.HasField('composite') and 1011 # proto.composite.HasField('scalar') now both return True. 1012 new_val = 20 1013 setattr(composite_field, scalar_field_name, new_val) 1014 self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) 1015 # Hold on to a reference to the current composite_field object. 1016 old_composite_field = composite_field 1017 # Assert that the has methods now return true. 1018 self.assertTrue(composite_field.HasField(scalar_field_name)) 1019 self.assertTrue(proto.HasField(composite_field_name)) 1020 1021 # Now call the clear method... 1022 proto.ClearField(composite_field_name) 1023 1024 # ...and ensure that the "has" bits are all back to False... 1025 composite_field = getattr(proto, composite_field_name) 1026 self.assertFalse(composite_field.HasField(scalar_field_name)) 1027 self.assertFalse(proto.HasField(composite_field_name)) 1028 # ...and ensure that the scalar field has returned to its default. 1029 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1030 1031 self.assertIsNot(old_composite_field, composite_field) 1032 setattr(old_composite_field, scalar_field_name, new_val) 1033 self.assertFalse(composite_field.HasField(scalar_field_name)) 1034 self.assertFalse(proto.HasField(composite_field_name)) 1035 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1036 1037 # Test simple, single-level nesting when we set a scalar. 
1038 TestCompositeHasBits('optionalgroup', 'a') 1039 TestCompositeHasBits('optional_nested_message', 'bb') 1040 TestCompositeHasBits('optional_foreign_message', 'c') 1041 TestCompositeHasBits('optional_import_message', 'd') 1042 1043 def testHasBitsWhenModifyingRepeatedFields(self): 1044 # Test nesting when we add an element to a repeated field in a submessage. 1045 proto = unittest_pb2.TestNestedMessageHasBits() 1046 proto.optional_nested_message.nestedmessage_repeated_int32.append(5) 1047 self.assertEqual( 1048 [5], proto.optional_nested_message.nestedmessage_repeated_int32) 1049 self.assertTrue(proto.HasField('optional_nested_message')) 1050 1051 # Do the same test, but with a repeated composite field within the 1052 # submessage. 1053 proto.ClearField('optional_nested_message') 1054 self.assertFalse(proto.HasField('optional_nested_message')) 1055 proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() 1056 self.assertTrue(proto.HasField('optional_nested_message')) 1057 1058 def testHasBitsForManyLevelsOfNesting(self): 1059 # Test nesting many levels deep. 
1060 recursive_proto = unittest_pb2.TestMutualRecursionA() 1061 self.assertFalse(recursive_proto.HasField('bb')) 1062 self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) 1063 self.assertFalse(recursive_proto.HasField('bb')) 1064 recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 1065 self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) 1066 self.assertTrue(recursive_proto.HasField('bb')) 1067 self.assertTrue(recursive_proto.bb.HasField('a')) 1068 self.assertTrue(recursive_proto.bb.a.HasField('bb')) 1069 self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) 1070 self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) 1071 self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a')) 1072 self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) 1073 1074 def testSingularListExtensions(self): 1075 proto = unittest_pb2.TestAllExtensions() 1076 proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 1077 proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 1078 proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' 1079 self.assertEqual( 1080 [ (unittest_pb2.optional_int32_extension , 5), 1081 (unittest_pb2.optional_fixed32_extension, 1), 1082 (unittest_pb2.optional_string_extension , 'foo') ], 1083 proto.ListFields()) 1084 del proto.Extensions[unittest_pb2.optional_fixed32_extension] 1085 self.assertEqual( 1086 [(unittest_pb2.optional_int32_extension, 5), 1087 (unittest_pb2.optional_string_extension, 'foo')], 1088 proto.ListFields()) 1089 1090 def testRepeatedListExtensions(self): 1091 proto = unittest_pb2.TestAllExtensions() 1092 proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) 1093 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) 1094 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) 1095 proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') 1096 proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') 1097 
proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') 1098 proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 1099 self.assertEqual( 1100 [ (unittest_pb2.optional_int32_extension , 21), 1101 (unittest_pb2.repeated_int32_extension , [5, 11]), 1102 (unittest_pb2.repeated_fixed32_extension, [1]), 1103 (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], 1104 proto.ListFields()) 1105 del proto.Extensions[unittest_pb2.repeated_int32_extension] 1106 del proto.Extensions[unittest_pb2.repeated_string_extension] 1107 self.assertEqual( 1108 [(unittest_pb2.optional_int32_extension, 21), 1109 (unittest_pb2.repeated_fixed32_extension, [1])], 1110 proto.ListFields()) 1111 1112 def testListFieldsAndExtensions(self): 1113 proto = unittest_pb2.TestFieldOrderings() 1114 test_util.SetAllFieldsAndExtensions(proto) 1115 unittest_pb2.my_extension_int 1116 self.assertEqual( 1117 [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), 1118 (unittest_pb2.my_extension_int , 23), 1119 (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), 1120 (unittest_pb2.my_extension_string , 'bar'), 1121 (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], 1122 proto.ListFields()) 1123 1124 def testDefaultValues(self): 1125 proto = unittest_pb2.TestAllTypes() 1126 self.assertEqual(0, proto.optional_int32) 1127 self.assertEqual(0, proto.optional_int64) 1128 self.assertEqual(0, proto.optional_uint32) 1129 self.assertEqual(0, proto.optional_uint64) 1130 self.assertEqual(0, proto.optional_sint32) 1131 self.assertEqual(0, proto.optional_sint64) 1132 self.assertEqual(0, proto.optional_fixed32) 1133 self.assertEqual(0, proto.optional_fixed64) 1134 self.assertEqual(0, proto.optional_sfixed32) 1135 self.assertEqual(0, proto.optional_sfixed64) 1136 self.assertEqual(0.0, proto.optional_float) 1137 self.assertEqual(0.0, proto.optional_double) 1138 self.assertEqual(False, proto.optional_bool) 1139 self.assertEqual('', proto.optional_string) 1140 self.assertEqual(b'', 
proto.optional_bytes) 1141 1142 self.assertEqual(41, proto.default_int32) 1143 self.assertEqual(42, proto.default_int64) 1144 self.assertEqual(43, proto.default_uint32) 1145 self.assertEqual(44, proto.default_uint64) 1146 self.assertEqual(-45, proto.default_sint32) 1147 self.assertEqual(46, proto.default_sint64) 1148 self.assertEqual(47, proto.default_fixed32) 1149 self.assertEqual(48, proto.default_fixed64) 1150 self.assertEqual(49, proto.default_sfixed32) 1151 self.assertEqual(-50, proto.default_sfixed64) 1152 self.assertEqual(51.5, proto.default_float) 1153 self.assertEqual(52e3, proto.default_double) 1154 self.assertEqual(True, proto.default_bool) 1155 self.assertEqual('hello', proto.default_string) 1156 self.assertEqual(b'world', proto.default_bytes) 1157 self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) 1158 self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) 1159 self.assertEqual(unittest_import_pb2.IMPORT_BAR, 1160 proto.default_import_enum) 1161 1162 proto = unittest_pb2.TestExtremeDefaultValues() 1163 self.assertEqual(u'\u1234', proto.utf8_string) 1164 1165 def testHasFieldWithUnknownFieldName(self): 1166 proto = unittest_pb2.TestAllTypes() 1167 self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') 1168 1169 def testClearRemovesChildren(self): 1170 # Make sure there aren't any implementation bugs that are only partially 1171 # clearing the message (which can happen in the more complex C++ 1172 # implementation which has parallel message lists). 1173 proto = unittest_pb2.TestRequiredForeign() 1174 for i in range(10): 1175 proto.repeated_message.add() 1176 proto2 = unittest_pb2.TestRequiredForeign() 1177 proto.CopyFrom(proto2) 1178 self.assertRaises(IndexError, lambda: proto.repeated_message[5]) 1179 1180 def testSingleScalarClearField(self): 1181 proto = unittest_pb2.TestAllTypes() 1182 # Should be allowed to clear something that's not there (a no-op). 
1183 proto.ClearField('optional_int32') 1184 proto.optional_int32 = 1 1185 self.assertTrue(proto.HasField('optional_int32')) 1186 proto.ClearField('optional_int32') 1187 self.assertEqual(0, proto.optional_int32) 1188 self.assertFalse(proto.HasField('optional_int32')) 1189 # TODO: Test all other scalar field types. 1190 1191 def testRepeatedScalars(self): 1192 proto = unittest_pb2.TestAllTypes() 1193 1194 self.assertFalse(proto.repeated_int32) 1195 self.assertEqual(0, len(proto.repeated_int32)) 1196 proto.repeated_int32.append(5) 1197 proto.repeated_int32.append(10) 1198 proto.repeated_int32.append(15) 1199 self.assertTrue(proto.repeated_int32) 1200 self.assertEqual(3, len(proto.repeated_int32)) 1201 1202 self.assertEqual([5, 10, 15], proto.repeated_int32) 1203 1204 # Test single retrieval. 1205 self.assertEqual(5, proto.repeated_int32[0]) 1206 self.assertEqual(15, proto.repeated_int32[-1]) 1207 # Test out-of-bounds indices. 1208 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 1209 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 1210 # Test incorrect types passed to __getitem__. 1211 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 1212 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 1213 1214 # Test single assignment. 1215 proto.repeated_int32[1] = 20 1216 self.assertEqual([5, 20, 15], proto.repeated_int32) 1217 1218 # Test insertion. 1219 proto.repeated_int32.insert(1, 25) 1220 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 1221 1222 # Test slice retrieval. 1223 proto.repeated_int32.append(30) 1224 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 1225 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 1226 1227 # Test slice assignment with an iterator 1228 proto.repeated_int32[1:4] = (i for i in range(3)) 1229 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 1230 1231 # Test slice assignment. 
1232 proto.repeated_int32[1:4] = [35, 40, 45] 1233 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 1234 1235 # Test that we can use the field as an iterator. 1236 result = [] 1237 for i in proto.repeated_int32: 1238 result.append(i) 1239 self.assertEqual([5, 35, 40, 45, 30], result) 1240 1241 # Test single deletion. 1242 del proto.repeated_int32[2] 1243 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 1244 1245 # Test slice deletion. 1246 del proto.repeated_int32[2:] 1247 self.assertEqual([5, 35], proto.repeated_int32) 1248 1249 # Test extending. 1250 proto.repeated_int32.extend([3, 13]) 1251 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 1252 1253 # Test clearing. 1254 proto.ClearField('repeated_int32') 1255 self.assertFalse(proto.repeated_int32) 1256 self.assertEqual(0, len(proto.repeated_int32)) 1257 1258 proto.repeated_int32.append(1) 1259 self.assertEqual(1, proto.repeated_int32[-1]) 1260 # Test assignment to a negative index. 1261 proto.repeated_int32[-1] = 2 1262 self.assertEqual(2, proto.repeated_int32[-1]) 1263 1264 # Test deletion at negative indices. 
1265 proto.repeated_int32[:] = [0, 1, 2, 3] 1266 del proto.repeated_int32[-1] 1267 self.assertEqual([0, 1, 2], proto.repeated_int32) 1268 1269 del proto.repeated_int32[-2] 1270 self.assertEqual([0, 2], proto.repeated_int32) 1271 1272 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 1273 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 1274 1275 del proto.repeated_int32[-2:-1] 1276 self.assertEqual([2], proto.repeated_int32) 1277 1278 del proto.repeated_int32[100:10000] 1279 self.assertEqual([2], proto.repeated_int32) 1280 1281 def testRepeatedScalarsRemove(self): 1282 proto = unittest_pb2.TestAllTypes() 1283 1284 self.assertFalse(proto.repeated_int32) 1285 self.assertEqual(0, len(proto.repeated_int32)) 1286 proto.repeated_int32.append(5) 1287 proto.repeated_int32.append(10) 1288 proto.repeated_int32.append(5) 1289 proto.repeated_int32.append(5) 1290 1291 self.assertEqual(4, len(proto.repeated_int32)) 1292 proto.repeated_int32.remove(5) 1293 self.assertEqual(3, len(proto.repeated_int32)) 1294 self.assertEqual(10, proto.repeated_int32[0]) 1295 self.assertEqual(5, proto.repeated_int32[1]) 1296 self.assertEqual(5, proto.repeated_int32[2]) 1297 1298 proto.repeated_int32.remove(5) 1299 self.assertEqual(2, len(proto.repeated_int32)) 1300 self.assertEqual(10, proto.repeated_int32[0]) 1301 self.assertEqual(5, proto.repeated_int32[1]) 1302 1303 proto.repeated_int32.remove(10) 1304 self.assertEqual(1, len(proto.repeated_int32)) 1305 self.assertEqual(5, proto.repeated_int32[0]) 1306 1307 # Remove a non-existent element. 
1308 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 1309 1310 def testRepeatedScalarsReverse_Empty(self): 1311 proto = unittest_pb2.TestAllTypes() 1312 1313 self.assertFalse(proto.repeated_int32) 1314 self.assertEqual(0, len(proto.repeated_int32)) 1315 1316 self.assertIsNone(proto.repeated_int32.reverse()) 1317 1318 self.assertFalse(proto.repeated_int32) 1319 self.assertEqual(0, len(proto.repeated_int32)) 1320 1321 def testRepeatedScalarsReverse_NonEmpty(self): 1322 proto = unittest_pb2.TestAllTypes() 1323 1324 self.assertFalse(proto.repeated_int32) 1325 self.assertEqual(0, len(proto.repeated_int32)) 1326 1327 proto.repeated_int32.append(1) 1328 proto.repeated_int32.append(2) 1329 proto.repeated_int32.append(3) 1330 proto.repeated_int32.append(4) 1331 1332 self.assertEqual(4, len(proto.repeated_int32)) 1333 1334 self.assertIsNone(proto.repeated_int32.reverse()) 1335 1336 self.assertEqual(4, len(proto.repeated_int32)) 1337 self.assertEqual(4, proto.repeated_int32[0]) 1338 self.assertEqual(3, proto.repeated_int32[1]) 1339 self.assertEqual(2, proto.repeated_int32[2]) 1340 self.assertEqual(1, proto.repeated_int32[3]) 1341 1342 def testRepeatedComposites(self): 1343 proto = unittest_pb2.TestAllTypes() 1344 self.assertFalse(proto.repeated_nested_message) 1345 self.assertEqual(0, len(proto.repeated_nested_message)) 1346 m0 = proto.repeated_nested_message.add() 1347 m1 = proto.repeated_nested_message.add() 1348 self.assertTrue(proto.repeated_nested_message) 1349 self.assertEqual(2, len(proto.repeated_nested_message)) 1350 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1351 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 1352 1353 # Test out-of-bounds indices. 1354 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1355 1234) 1356 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1357 -1234) 1358 1359 # Test incorrect types passed to __getitem__. 
1360 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1361 'foo') 1362 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1363 None) 1364 1365 # Test slice retrieval. 1366 m2 = proto.repeated_nested_message.add() 1367 m3 = proto.repeated_nested_message.add() 1368 m4 = proto.repeated_nested_message.add() 1369 self.assertListsEqual( 1370 [m1, m2, m3], proto.repeated_nested_message[1:4]) 1371 self.assertListsEqual( 1372 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 1373 self.assertListsEqual( 1374 [m0, m1], proto.repeated_nested_message[:2]) 1375 self.assertListsEqual( 1376 [m2, m3, m4], proto.repeated_nested_message[2:]) 1377 self.assertEqual( 1378 m0, proto.repeated_nested_message[0]) 1379 self.assertListsEqual( 1380 [m0], proto.repeated_nested_message[:1]) 1381 1382 # Test that we can use the field as an iterator. 1383 result = [] 1384 for i in proto.repeated_nested_message: 1385 result.append(i) 1386 self.assertListsEqual([m0, m1, m2, m3, m4], result) 1387 1388 # Test single deletion. 1389 del proto.repeated_nested_message[2] 1390 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 1391 1392 # Test slice deletion. 1393 del proto.repeated_nested_message[2:] 1394 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1395 1396 # Test extending. 
1397 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 1398 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 1399 proto.repeated_nested_message.extend([n1,n2]) 1400 self.assertEqual(4, len(proto.repeated_nested_message)) 1401 self.assertEqual(n1, proto.repeated_nested_message[2]) 1402 self.assertEqual(n2, proto.repeated_nested_message[3]) 1403 self.assertRaises(TypeError, 1404 proto.repeated_nested_message.extend, n1) 1405 self.assertRaises(TypeError, 1406 proto.repeated_nested_message.extend, [0]) 1407 wrong_message_type = unittest_pb2.TestAllTypes() 1408 self.assertRaises(TypeError, 1409 proto.repeated_nested_message.extend, 1410 [wrong_message_type]) 1411 1412 # Test clearing. 1413 proto.ClearField('repeated_nested_message') 1414 self.assertFalse(proto.repeated_nested_message) 1415 self.assertEqual(0, len(proto.repeated_nested_message)) 1416 1417 # Test constructing an element while adding it. 1418 proto.repeated_nested_message.add(bb=23) 1419 self.assertEqual(1, len(proto.repeated_nested_message)) 1420 self.assertEqual(23, proto.repeated_nested_message[0].bb) 1421 self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) 1422 with self.assertRaises(Exception): 1423 proto.repeated_nested_message[0] = 23 1424 1425 def testRepeatedCompositeRemove(self): 1426 proto = unittest_pb2.TestAllTypes() 1427 1428 self.assertEqual(0, len(proto.repeated_nested_message)) 1429 m0 = proto.repeated_nested_message.add() 1430 # Need to set some differentiating variable so m0 != m1 != m2: 1431 m0.bb = len(proto.repeated_nested_message) 1432 m1 = proto.repeated_nested_message.add() 1433 m1.bb = len(proto.repeated_nested_message) 1434 self.assertTrue(m0 != m1) 1435 m2 = proto.repeated_nested_message.add() 1436 m2.bb = len(proto.repeated_nested_message) 1437 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1438 1439 self.assertEqual(3, len(proto.repeated_nested_message)) 1440 proto.repeated_nested_message.remove(m0) 1441 self.assertEqual(2, 
len(proto.repeated_nested_message)) 1442 self.assertEqual(m1, proto.repeated_nested_message[0]) 1443 self.assertEqual(m2, proto.repeated_nested_message[1]) 1444 1445 # Removing m0 again or removing None should raise error 1446 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 1447 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 1448 self.assertEqual(2, len(proto.repeated_nested_message)) 1449 1450 proto.repeated_nested_message.remove(m2) 1451 self.assertEqual(1, len(proto.repeated_nested_message)) 1452 self.assertEqual(m1, proto.repeated_nested_message[0]) 1453 1454 def testRepeatedCompositeReverse_Empty(self): 1455 proto = unittest_pb2.TestAllTypes() 1456 1457 self.assertFalse(proto.repeated_nested_message) 1458 self.assertEqual(0, len(proto.repeated_nested_message)) 1459 1460 self.assertIsNone(proto.repeated_nested_message.reverse()) 1461 1462 self.assertFalse(proto.repeated_nested_message) 1463 self.assertEqual(0, len(proto.repeated_nested_message)) 1464 1465 def testRepeatedCompositeReverse_NonEmpty(self): 1466 proto = unittest_pb2.TestAllTypes() 1467 1468 self.assertFalse(proto.repeated_nested_message) 1469 self.assertEqual(0, len(proto.repeated_nested_message)) 1470 1471 m0 = proto.repeated_nested_message.add() 1472 m0.bb = len(proto.repeated_nested_message) 1473 m1 = proto.repeated_nested_message.add() 1474 m1.bb = len(proto.repeated_nested_message) 1475 m2 = proto.repeated_nested_message.add() 1476 m2.bb = len(proto.repeated_nested_message) 1477 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1478 1479 self.assertIsNone(proto.repeated_nested_message.reverse()) 1480 1481 self.assertListsEqual([m2, m1, m0], proto.repeated_nested_message) 1482 1483 def testHandWrittenReflection(self): 1484 # Hand written extensions are only supported by the pure-Python 1485 # implementation of the API. 
1486 if api_implementation.Type() != 'python': 1487 return 1488 1489 file = descriptor.FileDescriptor(name='foo.proto', package='') 1490 FieldDescriptor = descriptor.FieldDescriptor 1491 foo_field_descriptor = FieldDescriptor( 1492 name='foo_field', full_name='MyProto.foo_field', 1493 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1494 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1495 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1496 containing_type=None, message_type=None, enum_type=None, 1497 is_extension=False, extension_scope=None, 1498 options=descriptor_pb2.FieldOptions(), file=file, 1499 # pylint: disable=protected-access 1500 create_key=descriptor._internal_create_key) 1501 mydescriptor = descriptor.Descriptor( 1502 name='MyProto', full_name='MyProto', filename='ignored', 1503 containing_type=None, nested_types=[], enum_types=[], 1504 fields=[foo_field_descriptor], extensions=[], 1505 options=descriptor_pb2.MessageOptions(), 1506 file=file, 1507 # pylint: disable=protected-access 1508 create_key=descriptor._internal_create_key) 1509 1510 class MyProtoClass( 1511 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 1512 DESCRIPTOR = mydescriptor 1513 myproto_instance = MyProtoClass() 1514 self.assertEqual(0, myproto_instance.foo_field) 1515 self.assertFalse(myproto_instance.HasField('foo_field')) 1516 myproto_instance.foo_field = 23 1517 self.assertEqual(23, myproto_instance.foo_field) 1518 self.assertTrue(myproto_instance.HasField('foo_field')) 1519 1520 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 1521 def testDescriptorProtoSupport(self): 1522 # Hand written descriptors/reflection are only supported by the pure-Python 1523 # implementation of the API. 
1524 if api_implementation.Type() != 'python': 1525 return 1526 1527 def AddDescriptorField(proto, field_name, field_type): 1528 AddDescriptorField.field_index += 1 1529 new_field = proto.field.add() 1530 new_field.name = field_name 1531 new_field.type = field_type 1532 new_field.number = AddDescriptorField.field_index 1533 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1534 1535 AddDescriptorField.field_index = 0 1536 1537 desc_proto = descriptor_pb2.DescriptorProto() 1538 desc_proto.name = 'Car' 1539 fdp = descriptor_pb2.FieldDescriptorProto 1540 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1541 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1542 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1543 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1544 # Add a repeated field 1545 AddDescriptorField.field_index += 1 1546 new_field = desc_proto.field.add() 1547 new_field.name = 'owners' 1548 new_field.type = fdp.TYPE_STRING 1549 new_field.number = AddDescriptorField.field_index 1550 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1551 1552 desc = descriptor.MakeDescriptor(desc_proto) 1553 self.assertTrue('name' in desc.fields_by_name) 1554 self.assertTrue('year' in desc.fields_by_name) 1555 self.assertTrue('automatic' in desc.fields_by_name) 1556 self.assertTrue('price' in desc.fields_by_name) 1557 self.assertTrue('owners' in desc.fields_by_name) 1558 1559 class CarMessage( 1560 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 1561 DESCRIPTOR = desc 1562 1563 prius = CarMessage() 1564 prius.name = 'prius' 1565 prius.year = 2010 1566 prius.automatic = True 1567 prius.price = 25134.75 1568 prius.owners.extend(['bob', 'susan']) 1569 1570 serialized_prius = prius.SerializeToString() 1571 new_prius = message_factory.GetMessageClass(desc)() 1572 new_prius.ParseFromString(serialized_prius) 1573 self.assertIsNot(new_prius, prius) 1574 self.assertEqual(prius, new_prius) 1575 
1576 # these are unnecessary assuming message equality works as advertised but 1577 # explicitly check to be safe since we're mucking about in metaclass foo 1578 self.assertEqual(prius.name, new_prius.name) 1579 self.assertEqual(prius.year, new_prius.year) 1580 self.assertEqual(prius.automatic, new_prius.automatic) 1581 self.assertEqual(prius.price, new_prius.price) 1582 self.assertEqual(prius.owners, new_prius.owners) 1583 1584 def testExtensionDelete(self): 1585 extendee_proto = more_extensions_pb2.ExtendedMessage() 1586 1587 extension_int32 = more_extensions_pb2.optional_int_extension 1588 extendee_proto.Extensions[extension_int32] = 23 1589 1590 extension_repeated = more_extensions_pb2.repeated_int_extension 1591 extendee_proto.Extensions[extension_repeated].append(11) 1592 1593 extension_msg = more_extensions_pb2.optional_message_extension 1594 extendee_proto.Extensions[extension_msg].foreign_message_int = 56 1595 1596 self.assertEqual(len(extendee_proto.Extensions), 3) 1597 del extendee_proto.Extensions[extension_msg] 1598 self.assertEqual(len(extendee_proto.Extensions), 2) 1599 del extendee_proto.Extensions[extension_repeated] 1600 self.assertEqual(len(extendee_proto.Extensions), 1) 1601 # Delete a none exist extension. It is OK to "del m.Extensions[ext]" 1602 # even if the extension is not present in the message; we don't 1603 # raise KeyError. This is consistent with "m.Extensions[ext]" 1604 # returning a default value even if we did not set anything. 
1605 del extendee_proto.Extensions[extension_repeated] 1606 self.assertEqual(len(extendee_proto.Extensions), 1) 1607 del extendee_proto.Extensions[extension_int32] 1608 self.assertEqual(len(extendee_proto.Extensions), 0) 1609 1610 def testExtensionIter(self): 1611 extendee_proto = more_extensions_pb2.ExtendedMessage() 1612 1613 extension_int32 = more_extensions_pb2.optional_int_extension 1614 extendee_proto.Extensions[extension_int32] = 23 1615 1616 extension_repeated = more_extensions_pb2.repeated_int_extension 1617 extendee_proto.Extensions[extension_repeated].append(11) 1618 1619 extension_msg = more_extensions_pb2.optional_message_extension 1620 extendee_proto.Extensions[extension_msg].foreign_message_int = 56 1621 1622 # Set some normal fields. 1623 extendee_proto.optional_int32 = 1 1624 extendee_proto.repeated_string.append('hi') 1625 1626 expected = (extension_int32, extension_msg, extension_repeated) 1627 count = 0 1628 for item in extendee_proto.Extensions: 1629 self.assertEqual(item.name, expected[count].name) 1630 self.assertIn(item, extendee_proto.Extensions) 1631 count += 1 1632 self.assertEqual(count, 3) 1633 1634 def testExtensionContainsError(self): 1635 extendee_proto = more_extensions_pb2.ExtendedMessage() 1636 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0) 1637 1638 field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[ 1639 'optional_int32'] 1640 self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field) 1641 1642 def testTopLevelExtensionsForOptionalScalar(self): 1643 extendee_proto = unittest_pb2.TestAllExtensions() 1644 extension = unittest_pb2.optional_int32_extension 1645 self.assertFalse(extendee_proto.HasExtension(extension)) 1646 self.assertNotIn(extension, extendee_proto.Extensions) 1647 self.assertEqual(0, extendee_proto.Extensions[extension]) 1648 # As with normal scalar fields, just doing a read doesn't actually set the 1649 # "has" bit. 
1650 self.assertFalse(extendee_proto.HasExtension(extension)) 1651 self.assertNotIn(extension, extendee_proto.Extensions) 1652 # Actually set the thing. 1653 extendee_proto.Extensions[extension] = 23 1654 self.assertEqual(23, extendee_proto.Extensions[extension]) 1655 self.assertTrue(extendee_proto.HasExtension(extension)) 1656 self.assertIn(extension, extendee_proto.Extensions) 1657 # Ensure that clearing works as well. 1658 extendee_proto.ClearExtension(extension) 1659 self.assertEqual(0, extendee_proto.Extensions[extension]) 1660 self.assertFalse(extendee_proto.HasExtension(extension)) 1661 self.assertNotIn(extension, extendee_proto.Extensions) 1662 1663 def testTopLevelExtensionsForRepeatedScalar(self): 1664 extendee_proto = unittest_pb2.TestAllExtensions() 1665 extension = unittest_pb2.repeated_string_extension 1666 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1667 self.assertNotIn(extension, extendee_proto.Extensions) 1668 extendee_proto.Extensions[extension].append('foo') 1669 self.assertEqual(['foo'], extendee_proto.Extensions[extension]) 1670 self.assertIn(extension, extendee_proto.Extensions) 1671 string_list = extendee_proto.Extensions[extension] 1672 extendee_proto.ClearExtension(extension) 1673 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1674 self.assertNotIn(extension, extendee_proto.Extensions) 1675 self.assertIsNot(string_list, extendee_proto.Extensions[extension]) 1676 # Shouldn't be allowed to do Extensions[extension] = 'a' 1677 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1678 extension, 'a') 1679 1680 def testTopLevelExtensionsForOptionalMessage(self): 1681 extendee_proto = unittest_pb2.TestAllExtensions() 1682 extension = unittest_pb2.optional_foreign_message_extension 1683 self.assertFalse(extendee_proto.HasExtension(extension)) 1684 self.assertNotIn(extension, extendee_proto.Extensions) 1685 self.assertEqual(0, extendee_proto.Extensions[extension].c) 1686 # As with normal 
(non-extension) fields, merely reading from the 1687 # thing shouldn't set the "has" bit. 1688 self.assertFalse(extendee_proto.HasExtension(extension)) 1689 self.assertNotIn(extension, extendee_proto.Extensions) 1690 extendee_proto.Extensions[extension].c = 23 1691 self.assertEqual(23, extendee_proto.Extensions[extension].c) 1692 self.assertTrue(extendee_proto.HasExtension(extension)) 1693 self.assertIn(extension, extendee_proto.Extensions) 1694 # Save a reference here. 1695 foreign_message = extendee_proto.Extensions[extension] 1696 extendee_proto.ClearExtension(extension) 1697 self.assertIsNot(foreign_message, extendee_proto.Extensions[extension]) 1698 # Setting a field on foreign_message now shouldn't set 1699 # any "has" bits on extendee_proto. 1700 foreign_message.c = 42 1701 self.assertEqual(42, foreign_message.c) 1702 self.assertTrue(foreign_message.HasField('c')) 1703 self.assertFalse(extendee_proto.HasExtension(extension)) 1704 self.assertNotIn(extension, extendee_proto.Extensions) 1705 # Shouldn't be allowed to do Extensions[extension] = 'a' 1706 self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, 1707 extension, 'a') 1708 1709 def testTopLevelExtensionsForRepeatedMessage(self): 1710 extendee_proto = unittest_pb2.TestAllExtensions() 1711 extension = unittest_pb2.repeatedgroup_extension 1712 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1713 group = extendee_proto.Extensions[extension].add() 1714 group.a = 23 1715 self.assertEqual(23, extendee_proto.Extensions[extension][0].a) 1716 group.a = 42 1717 self.assertEqual(42, extendee_proto.Extensions[extension][0].a) 1718 group_list = extendee_proto.Extensions[extension] 1719 extendee_proto.ClearExtension(extension) 1720 self.assertEqual(0, len(extendee_proto.Extensions[extension])) 1721 self.assertIsNot(group_list, extendee_proto.Extensions[extension]) 1722 # Shouldn't be allowed to do Extensions[extension] = 'a' 1723 self.assertRaises(TypeError, operator.setitem, 
extendee_proto.Extensions, 1724 extension, 'a') 1725 1726 def testNestedExtensions(self): 1727 extendee_proto = unittest_pb2.TestAllExtensions() 1728 extension = unittest_pb2.TestRequired.single 1729 1730 # We just test the non-repeated case. 1731 self.assertFalse(extendee_proto.HasExtension(extension)) 1732 self.assertNotIn(extension, extendee_proto.Extensions) 1733 required = extendee_proto.Extensions[extension] 1734 self.assertEqual(0, required.a) 1735 self.assertFalse(extendee_proto.HasExtension(extension)) 1736 self.assertNotIn(extension, extendee_proto.Extensions) 1737 required.a = 23 1738 self.assertEqual(23, extendee_proto.Extensions[extension].a) 1739 self.assertTrue(extendee_proto.HasExtension(extension)) 1740 self.assertIn(extension, extendee_proto.Extensions) 1741 extendee_proto.ClearExtension(extension) 1742 self.assertIsNot(required, extendee_proto.Extensions[extension]) 1743 self.assertFalse(extendee_proto.HasExtension(extension)) 1744 self.assertNotIn(extension, extendee_proto.Extensions) 1745 1746 def testRegisteredExtensions(self): 1747 pool = unittest_pb2.DESCRIPTOR.pool 1748 self.assertTrue( 1749 pool.FindExtensionByNumber( 1750 unittest_pb2.TestAllExtensions.DESCRIPTOR, 1)) 1751 self.assertIs( 1752 pool.FindExtensionByName( 1753 'protobuf_unittest.optional_int32_extension').containing_type, 1754 unittest_pb2.TestAllExtensions.DESCRIPTOR) 1755 # Make sure extensions haven't been registered into types that shouldn't 1756 # have any. 1757 self.assertEqual(0, len( 1758 pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR))) 1759 1760 # If message A directly contains message B, and 1761 # a.HasField('b') is currently False, then mutating any 1762 # extension in B should change a.HasField('b') to True 1763 # (and so on up the object tree). 1764 def testHasBitsForAncestorsOfExtendedMessage(self): 1765 # Optional scalar extension. 
1766 toplevel = more_extensions_pb2.TopLevelMessage() 1767 self.assertFalse(toplevel.HasField('submessage')) 1768 self.assertEqual(0, toplevel.submessage.Extensions[ 1769 more_extensions_pb2.optional_int_extension]) 1770 self.assertFalse(toplevel.HasField('submessage')) 1771 toplevel.submessage.Extensions[ 1772 more_extensions_pb2.optional_int_extension] = 23 1773 self.assertEqual(23, toplevel.submessage.Extensions[ 1774 more_extensions_pb2.optional_int_extension]) 1775 self.assertTrue(toplevel.HasField('submessage')) 1776 1777 # Repeated scalar extension. 1778 toplevel = more_extensions_pb2.TopLevelMessage() 1779 self.assertFalse(toplevel.HasField('submessage')) 1780 self.assertEqual([], toplevel.submessage.Extensions[ 1781 more_extensions_pb2.repeated_int_extension]) 1782 self.assertFalse(toplevel.HasField('submessage')) 1783 toplevel.submessage.Extensions[ 1784 more_extensions_pb2.repeated_int_extension].append(23) 1785 self.assertEqual([23], toplevel.submessage.Extensions[ 1786 more_extensions_pb2.repeated_int_extension]) 1787 self.assertTrue(toplevel.HasField('submessage')) 1788 1789 # Optional message extension. 1790 toplevel = more_extensions_pb2.TopLevelMessage() 1791 self.assertFalse(toplevel.HasField('submessage')) 1792 self.assertEqual(0, toplevel.submessage.Extensions[ 1793 more_extensions_pb2.optional_message_extension].foreign_message_int) 1794 self.assertFalse(toplevel.HasField('submessage')) 1795 toplevel.submessage.Extensions[ 1796 more_extensions_pb2.optional_message_extension].foreign_message_int = 23 1797 self.assertEqual(23, toplevel.submessage.Extensions[ 1798 more_extensions_pb2.optional_message_extension].foreign_message_int) 1799 self.assertTrue(toplevel.HasField('submessage')) 1800 1801 # Repeated message extension. 
1802 toplevel = more_extensions_pb2.TopLevelMessage() 1803 self.assertFalse(toplevel.HasField('submessage')) 1804 self.assertEqual(0, len(toplevel.submessage.Extensions[ 1805 more_extensions_pb2.repeated_message_extension])) 1806 self.assertFalse(toplevel.HasField('submessage')) 1807 foreign = toplevel.submessage.Extensions[ 1808 more_extensions_pb2.repeated_message_extension].add() 1809 self.assertEqual(foreign, toplevel.submessage.Extensions[ 1810 more_extensions_pb2.repeated_message_extension][0]) 1811 self.assertTrue(toplevel.HasField('submessage')) 1812 1813 def testDisconnectionAfterClearingEmptyMessage(self): 1814 toplevel = more_extensions_pb2.TopLevelMessage() 1815 extendee_proto = toplevel.submessage 1816 extension = more_extensions_pb2.optional_message_extension 1817 extension_proto = extendee_proto.Extensions[extension] 1818 extendee_proto.ClearExtension(extension) 1819 extension_proto.foreign_message_int = 23 1820 1821 self.assertIsNot(extension_proto, extendee_proto.Extensions[extension]) 1822 1823 def testExtensionFailureModes(self): 1824 extendee_proto = unittest_pb2.TestAllExtensions() 1825 1826 # Try non-extension-handle arguments to HasExtension, 1827 # ClearExtension(), and Extensions[]... 1828 self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) 1829 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) 1830 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) 1831 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) 1832 1833 # Try something that *is* an extension handle, just not for 1834 # this message... 
1835 for unknown_handle in (more_extensions_pb2.optional_int_extension, 1836 more_extensions_pb2.optional_message_extension, 1837 more_extensions_pb2.repeated_int_extension, 1838 more_extensions_pb2.repeated_message_extension): 1839 self.assertRaises(KeyError, extendee_proto.HasExtension, 1840 unknown_handle) 1841 self.assertRaises(KeyError, extendee_proto.ClearExtension, 1842 unknown_handle) 1843 self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1844 unknown_handle) 1845 self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1846 unknown_handle, 5) 1847 1848 # Try call HasExtension() with a valid handle, but for a 1849 # *repeated* field. (Just as with non-extension repeated 1850 # fields, Has*() isn't supported for extension repeated fields). 1851 self.assertRaises(KeyError, extendee_proto.HasExtension, 1852 unittest_pb2.repeated_string_extension) 1853 1854 def testMergeFromOptionalGroup(self): 1855 # Test merge with an optional group. 1856 proto1 = unittest_pb2.TestAllTypes() 1857 proto1.optionalgroup.a = 12 1858 proto2 = unittest_pb2.TestAllTypes() 1859 proto2.MergeFrom(proto1) 1860 self.assertEqual(12, proto2.optionalgroup.a) 1861 1862 def testMergeFromExtensionsSingular(self): 1863 proto1 = unittest_pb2.TestAllExtensions() 1864 proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 1865 1866 proto2 = unittest_pb2.TestAllExtensions() 1867 proto2.MergeFrom(proto1) 1868 self.assertEqual( 1869 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) 1870 1871 def testMergeFromExtensionsRepeated(self): 1872 proto1 = unittest_pb2.TestAllExtensions() 1873 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) 1874 proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) 1875 1876 proto2 = unittest_pb2.TestAllExtensions() 1877 proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) 1878 proto2.MergeFrom(proto1) 1879 self.assertEqual( 1880 3, 
len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) 1881 self.assertEqual( 1882 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) 1883 self.assertEqual( 1884 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) 1885 self.assertEqual( 1886 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) 1887 1888 def testMergeFromExtensionsNestedMessage(self): 1889 proto1 = unittest_pb2.TestAllExtensions() 1890 ext1 = proto1.Extensions[ 1891 unittest_pb2.repeated_nested_message_extension] 1892 m = ext1.add() 1893 m.bb = 222 1894 m = ext1.add() 1895 m.bb = 333 1896 1897 proto2 = unittest_pb2.TestAllExtensions() 1898 ext2 = proto2.Extensions[ 1899 unittest_pb2.repeated_nested_message_extension] 1900 m = ext2.add() 1901 m.bb = 111 1902 1903 proto2.MergeFrom(proto1) 1904 ext2 = proto2.Extensions[ 1905 unittest_pb2.repeated_nested_message_extension] 1906 self.assertEqual(3, len(ext2)) 1907 self.assertEqual(111, ext2[0].bb) 1908 self.assertEqual(222, ext2[1].bb) 1909 self.assertEqual(333, ext2[2].bb) 1910 1911 def testCopyFromBadType(self): 1912 # The python implementation doesn't raise an exception in this 1913 # case. In theory it should. 1914 if api_implementation.Type() == 'python': 1915 return 1916 proto1 = unittest_pb2.TestAllTypes() 1917 proto2 = unittest_pb2.TestAllExtensions() 1918 self.assertRaises(TypeError, proto1.CopyFrom, proto2) 1919 1920 def testClear(self): 1921 proto = unittest_pb2.TestAllTypes() 1922 # C++ implementation does not support lazy fields right now so leave it 1923 # out for now. 1924 if api_implementation.Type() == 'python': 1925 test_util.SetAllFields(proto) 1926 else: 1927 test_util.SetAllNonLazyFields(proto) 1928 # Clear the message. 1929 proto.Clear() 1930 self.assertEqual(proto.ByteSize(), 0) 1931 empty_proto = unittest_pb2.TestAllTypes() 1932 self.assertEqual(proto, empty_proto) 1933 1934 # Test if extensions which were set are cleared. 
1935 proto = unittest_pb2.TestAllExtensions() 1936 test_util.SetAllExtensions(proto) 1937 # Clear the message. 1938 proto.Clear() 1939 self.assertEqual(proto.ByteSize(), 0) 1940 empty_proto = unittest_pb2.TestAllExtensions() 1941 self.assertEqual(proto, empty_proto) 1942 1943 def testDisconnectingInOneof(self): 1944 m = unittest_pb2.TestOneof2() # This message has two messages in a oneof. 1945 m.foo_message.moo_int = 5 1946 sub_message = m.foo_message 1947 # Accessing another message's field does not clear the first one 1948 self.assertEqual(m.foo_lazy_message.moo_int, 0) 1949 self.assertEqual(m.foo_message.moo_int, 5) 1950 # But mutating another message in the oneof detaches the first one. 1951 m.foo_lazy_message.moo_int = 6 1952 self.assertEqual(m.foo_message.moo_int, 0) 1953 # The reference we got above was detached and is still valid. 1954 self.assertEqual(sub_message.moo_int, 5) 1955 sub_message.moo_int = 7 1956 1957 def assertInitialized(self, proto): 1958 self.assertTrue(proto.IsInitialized()) 1959 # Neither method should raise an exception. 1960 proto.SerializeToString() 1961 proto.SerializePartialToString() 1962 1963 def assertNotInitialized(self, proto, error_size=None): 1964 errors = [] 1965 self.assertFalse(proto.IsInitialized()) 1966 self.assertFalse(proto.IsInitialized(errors)) 1967 self.assertEqual(error_size, len(errors)) 1968 self.assertRaises(message.EncodeError, proto.SerializeToString) 1969 # "Partial" serialization doesn't care if message is uninitialized. 1970 proto.SerializePartialToString() 1971 1972 def testIsInitialized(self): 1973 # Trivial cases - all optional fields and extensions. 1974 proto = unittest_pb2.TestAllTypes() 1975 self.assertInitialized(proto) 1976 proto = unittest_pb2.TestAllExtensions() 1977 self.assertInitialized(proto) 1978 1979 # The case of uninitialized required fields. 
1980 proto = unittest_pb2.TestRequired() 1981 self.assertNotInitialized(proto, 3) 1982 proto.a = proto.b = proto.c = 2 1983 self.assertInitialized(proto) 1984 1985 # The case of uninitialized submessage. 1986 proto = unittest_pb2.TestRequiredForeign() 1987 self.assertInitialized(proto) 1988 proto.optional_message.a = 1 1989 self.assertNotInitialized(proto, 2) 1990 proto.optional_message.b = 0 1991 proto.optional_message.c = 0 1992 self.assertInitialized(proto) 1993 1994 # Uninitialized repeated submessage. 1995 message1 = proto.repeated_message.add() 1996 self.assertNotInitialized(proto, 3) 1997 message1.a = message1.b = message1.c = 0 1998 self.assertInitialized(proto) 1999 2000 # Uninitialized repeated group in an extension. 2001 proto = unittest_pb2.TestAllExtensions() 2002 extension = unittest_pb2.TestRequired.multi 2003 message1 = proto.Extensions[extension].add() 2004 message2 = proto.Extensions[extension].add() 2005 self.assertNotInitialized(proto, 6) 2006 message1.a = 1 2007 message1.b = 1 2008 message1.c = 1 2009 self.assertNotInitialized(proto, 3) 2010 message2.a = 2 2011 message2.b = 2 2012 message2.c = 2 2013 self.assertInitialized(proto) 2014 2015 # Uninitialized nonrepeated message in an extension. 2016 proto = unittest_pb2.TestAllExtensions() 2017 extension = unittest_pb2.TestRequired.single 2018 proto.Extensions[extension].a = 1 2019 self.assertNotInitialized(proto, 2) 2020 proto.Extensions[extension].b = 2 2021 proto.Extensions[extension].c = 3 2022 self.assertInitialized(proto) 2023 2024 # Try passing an errors list. 
2025 errors = [] 2026 proto = unittest_pb2.TestRequired() 2027 self.assertFalse(proto.IsInitialized(errors)) 2028 self.assertEqual(errors, ['a', 'b', 'c']) 2029 self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3) 2030 2031 @unittest.skipIf( 2032 api_implementation.Type() == 'python', 2033 'Errors are only available from the most recent C++ implementation.') 2034 def testFileDescriptorErrors(self): 2035 file_name = 'test_file_descriptor_errors.proto' 2036 package_name = 'test_file_descriptor_errors.proto' 2037 file_descriptor_proto = descriptor_pb2.FileDescriptorProto() 2038 file_descriptor_proto.name = file_name 2039 file_descriptor_proto.package = package_name 2040 m1 = file_descriptor_proto.message_type.add() 2041 m1.name = 'msg1' 2042 # Compiles the proto into the C++ descriptor pool 2043 descriptor.FileDescriptor( 2044 file_name, 2045 package_name, 2046 serialized_pb=file_descriptor_proto.SerializeToString()) 2047 # Add a FileDescriptorProto that has duplicate symbols 2048 another_file_name = 'another_test_file_descriptor_errors.proto' 2049 file_descriptor_proto.name = another_file_name 2050 m2 = file_descriptor_proto.message_type.add() 2051 m2.name = 'msg2' 2052 with self.assertRaises(TypeError) as cm: 2053 descriptor.FileDescriptor( 2054 another_file_name, 2055 package_name, 2056 serialized_pb=file_descriptor_proto.SerializeToString()) 2057 self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % 2058 getattr(cm.expected, '__name__', cm.expected)) 2059 self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) 2060 # Error message will say something about this definition being a 2061 # duplicate, though we don't check the message exactly to avoid a 2062 # dependency on the C++ logging code. 
2063 self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) 2064 2065 def testDescriptorProtoHasFileOptions(self): 2066 self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) 2067 self.assertEqual( 2068 descriptor_pb2.DESCRIPTOR.GetOptions().java_package, 2069 'com.google.protobuf', 2070 ) 2071 2072 def testDescriptorProtoHasFieldOptions(self): 2073 self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) 2074 self.assertEqual( 2075 descriptor_pb2.DESCRIPTOR.GetOptions().java_package, 2076 'com.google.protobuf', 2077 ) 2078 packed_desc = ( 2079 descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get( 2080 'Location' 2081 ).fields_by_name.get('path') 2082 ) 2083 self.assertTrue(packed_desc.has_options) 2084 self.assertTrue(packed_desc.GetOptions().packed) 2085 2086 def testDescriptorProtoHasFeatureOptions(self): 2087 self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options) 2088 self.assertEqual( 2089 descriptor_pb2.DESCRIPTOR.GetOptions().java_package, 2090 'com.google.protobuf', 2091 ) 2092 presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get( 2093 'field_presence' 2094 ) 2095 self.assertTrue(presence_desc.has_options) 2096 self.assertEqual( 2097 presence_desc.GetOptions().retention, 2098 descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME, 2099 ) 2100 self.assertListsEqual( 2101 presence_desc.GetOptions().targets, 2102 [ 2103 descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD, 2104 descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE, 2105 ], 2106 ) 2107 2108 def testStringUTF8Serialization(self): 2109 proto = message_set_extensions_pb2.TestMessageSet() 2110 extension_message = message_set_extensions_pb2.TestMessageSetExtension2 2111 extension = extension_message.message_set_extension 2112 2113 test_utf8 = u'Тест' 2114 test_utf8_bytes = test_utf8.encode('utf-8') 2115 2116 # 'Test' in another language, using UTF-8 charset. 
2117 proto.Extensions[extension].str = test_utf8 2118 2119 # Serialize using the MessageSet wire format (this is specified in the 2120 # .proto file). 2121 serialized = proto.SerializeToString() 2122 2123 # Check byte size. 2124 self.assertEqual(proto.ByteSize(), len(serialized)) 2125 2126 raw = unittest_mset_pb2.RawMessageSet() 2127 bytes_read = raw.MergeFromString(serialized) 2128 self.assertEqual(len(serialized), bytes_read) 2129 2130 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 2131 2132 self.assertEqual(1, len(raw.item)) 2133 # Check that the type_id is the same as the tag ID in the .proto file. 2134 self.assertEqual(raw.item[0].type_id, 98418634) 2135 2136 # Check the actual bytes on the wire. 2137 self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) 2138 bytes_read = message2.MergeFromString(raw.item[0].message) 2139 self.assertEqual(len(raw.item[0].message), bytes_read) 2140 2141 self.assertEqual(type(message2.str), str) 2142 self.assertEqual(message2.str, test_utf8) 2143 2144 # The pure Python API throws an exception on MergeFromString(), 2145 # if any of the string fields of the message can't be UTF-8 decoded. 2146 # The C++ implementation of the API has no way to check that on 2147 # MergeFromString and thus has no way to throw the exception. 2148 # 2149 # The pure Python API always returns objects of type 'unicode' (UTF-8 2150 # encoded), or 'bytes' (in 7 bit ASCII). 
2151 badbytes = raw.item[0].message.replace( 2152 test_utf8_bytes, len(test_utf8_bytes) * b'\xff') 2153 2154 unicode_decode_failed = False 2155 try: 2156 message2.MergeFromString(badbytes) 2157 except UnicodeDecodeError: 2158 unicode_decode_failed = True 2159 string_field = message2.str 2160 self.assertTrue(unicode_decode_failed or type(string_field) is bytes) 2161 2162 def testSetInParent(self): 2163 proto = unittest_pb2.TestAllTypes() 2164 self.assertFalse(proto.HasField('optionalgroup')) 2165 proto.optionalgroup.SetInParent() 2166 self.assertTrue(proto.HasField('optionalgroup')) 2167 2168 def testPackageInitializationImport(self): 2169 """Test that we can import nested messages from their __init__.py. 2170 2171 Such setup is not trivial since at the time of processing of __init__.py one 2172 can't refer to its submodules by name in code, so expressions like 2173 google.protobuf.internal.import_test_package.inner_pb2 2174 don't work. They do work in imports, so we have assign an alias at import 2175 and then use that alias in generated code. 2176 """ 2177 # We import here since it's the import that used to fail, and we want 2178 # the failure to have the right context. 2179 # pylint: disable=g-import-not-at-top 2180 from google.protobuf.internal import import_test_package 2181 # pylint: enable=g-import-not-at-top 2182 msg = import_test_package.myproto.Outer() 2183 # Just check the default value. 2184 self.assertEqual(57, msg.inner.value) 2185 2186# Since we had so many tests for protocol buffer equality, we broke these out 2187# into separate TestCase classes. 


@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing tests starting from two empty TestAllTypes."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testNotHashable(self):
    self.assertRaises(TypeError, hash, self.first_proto)

  def testSelfEquality(self):
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):
  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality tests for messages that carry extensions."""

  def testExtensionEquality(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality on mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    first_proto = unittest_pb2.TestMutualRecursionA()
    second_proto = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(first_proto, second_proto)
    first_proto.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first_proto, second_proto)
    second_proto.bb.a.bb.optional_int32 = 23
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """ByteSize() checks against hand-computed wire-format sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for optional_int64's single-byte tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value that fits in a k-byte varint.
    for num_bytes, i in enumerate(range(7, 63, 7), start=1):
      Test((1 << i) - 1, num_bytes)
    # Negative values take the full 10-byte varint encoding.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    foreign_message_0 = self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2424 self.assertEqual(3 + 1 + 1 + 2, self.Size()) 2425 2426 def testGroups(self): 2427 # 4 bytes. 2428 self.proto.optionalgroup.a = (1 << 21) 2429 # Plus two bytes for |a| tag. 2430 # Plus 2 * two bytes for START_GROUP and END_GROUP tags. 2431 self.assertEqual(4 + 2 + 2*2, self.Size()) 2432 2433 def testRepeatedScalars(self): 2434 self.proto.repeated_int32.append(10) # 1 byte. 2435 self.proto.repeated_int32.append(128) # 2 bytes. 2436 # Also need 2 bytes for each entry for tag. 2437 self.assertEqual(1 + 2 + 2*2, self.Size()) 2438 2439 def testRepeatedScalarsExtend(self): 2440 self.proto.repeated_int32.extend([10, 128]) # 3 bytes. 2441 # Also need 2 bytes for each entry for tag. 2442 self.assertEqual(1 + 2 + 2*2, self.Size()) 2443 2444 def testRepeatedScalarsRemove(self): 2445 self.proto.repeated_int32.append(10) # 1 byte. 2446 self.proto.repeated_int32.append(128) # 2 bytes. 2447 # Also need 2 bytes for each entry for tag. 2448 self.assertEqual(1 + 2 + 2*2, self.Size()) 2449 self.proto.repeated_int32.remove(128) 2450 self.assertEqual(1 + 2, self.Size()) 2451 2452 def testRepeatedComposites(self): 2453 # Empty message. 2 bytes tag plus 1 byte length. 2454 foreign_message_0 = self.proto.repeated_nested_message.add() 2455 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2456 foreign_message_1 = self.proto.repeated_nested_message.add() 2457 foreign_message_1.bb = 7 2458 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2459 2460 def testRepeatedCompositesDelete(self): 2461 # Empty message. 2 bytes tag plus 1 byte length. 2462 foreign_message_0 = self.proto.repeated_nested_message.add() 2463 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 
2464 foreign_message_1 = self.proto.repeated_nested_message.add() 2465 foreign_message_1.bb = 9 2466 self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) 2467 repeated_nested_message = copy.deepcopy( 2468 self.proto.repeated_nested_message) 2469 2470 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2471 del self.proto.repeated_nested_message[0] 2472 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2473 2474 # Now add a new message. 2475 foreign_message_2 = self.proto.repeated_nested_message.add() 2476 foreign_message_2.bb = 12 2477 2478 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2479 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2480 self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size()) 2481 2482 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 2483 del self.proto.repeated_nested_message[1] 2484 self.assertEqual(2 + 1 + 1 + 1, self.Size()) 2485 2486 del self.proto.repeated_nested_message[0] 2487 self.assertEqual(0, self.Size()) 2488 2489 self.assertEqual(2, len(repeated_nested_message)) 2490 del repeated_nested_message[0:1] 2491 # TODO: Fix cpp extension bug when delete repeated message. 2492 if api_implementation.Type() == 'python': 2493 self.assertEqual(1, len(repeated_nested_message)) 2494 del repeated_nested_message[-1] 2495 # TODO: Fix cpp extension bug when delete repeated message. 2496 if api_implementation.Type() == 'python': 2497 self.assertEqual(0, len(repeated_nested_message)) 2498 2499 def testRepeatedGroups(self): 2500 # 2-byte START_GROUP plus 2-byte END_GROUP. 2501 group_0 = self.proto.repeatedgroup.add() 2502 # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| 2503 # plus 2-byte END_GROUP. 
2504 group_1 = self.proto.repeatedgroup.add() 2505 group_1.a = 7 2506 self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) 2507 2508 def testExtensions(self): 2509 proto = unittest_pb2.TestAllExtensions() 2510 self.assertEqual(0, proto.ByteSize()) 2511 extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. 2512 proto.Extensions[extension] = 23 2513 # 1 byte for tag, 1 byte for value. 2514 self.assertEqual(2, proto.ByteSize()) 2515 field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ 2516 'optional_int32'] 2517 with self.assertRaises(KeyError): 2518 proto.Extensions[field] = 23 2519 2520 def testCacheInvalidationForNonrepeatedScalar(self): 2521 # Test non-extension. 2522 self.proto.optional_int32 = 1 2523 self.assertEqual(2, self.proto.ByteSize()) 2524 self.proto.optional_int32 = 128 2525 self.assertEqual(3, self.proto.ByteSize()) 2526 self.proto.ClearField('optional_int32') 2527 self.assertEqual(0, self.proto.ByteSize()) 2528 2529 # Test within extension. 2530 extension = more_extensions_pb2.optional_int_extension 2531 self.extended_proto.Extensions[extension] = 1 2532 self.assertEqual(2, self.extended_proto.ByteSize()) 2533 self.extended_proto.Extensions[extension] = 128 2534 self.assertEqual(3, self.extended_proto.ByteSize()) 2535 self.extended_proto.ClearExtension(extension) 2536 self.assertEqual(0, self.extended_proto.ByteSize()) 2537 2538 def testCacheInvalidationForRepeatedScalar(self): 2539 # Test non-extension. 2540 self.proto.repeated_int32.append(1) 2541 self.assertEqual(3, self.proto.ByteSize()) 2542 self.proto.repeated_int32.append(1) 2543 self.assertEqual(6, self.proto.ByteSize()) 2544 self.proto.repeated_int32[1] = 128 2545 self.assertEqual(7, self.proto.ByteSize()) 2546 self.proto.ClearField('repeated_int32') 2547 self.assertEqual(0, self.proto.ByteSize()) 2548 2549 # Test within extension. 
2550 extension = more_extensions_pb2.repeated_int_extension 2551 repeated = self.extended_proto.Extensions[extension] 2552 repeated.append(1) 2553 self.assertEqual(2, self.extended_proto.ByteSize()) 2554 repeated.append(1) 2555 self.assertEqual(4, self.extended_proto.ByteSize()) 2556 repeated[1] = 128 2557 self.assertEqual(5, self.extended_proto.ByteSize()) 2558 self.extended_proto.ClearExtension(extension) 2559 self.assertEqual(0, self.extended_proto.ByteSize()) 2560 2561 def testCacheInvalidationForNonrepeatedMessage(self): 2562 # Test non-extension. 2563 self.proto.optional_foreign_message.c = 1 2564 self.assertEqual(5, self.proto.ByteSize()) 2565 self.proto.optional_foreign_message.c = 128 2566 self.assertEqual(6, self.proto.ByteSize()) 2567 self.proto.optional_foreign_message.ClearField('c') 2568 self.assertEqual(3, self.proto.ByteSize()) 2569 self.proto.ClearField('optional_foreign_message') 2570 self.assertEqual(0, self.proto.ByteSize()) 2571 2572 if api_implementation.Type() == 'python': 2573 # This is only possible in pure-Python implementation of the API. 2574 child = self.proto.optional_foreign_message 2575 self.proto.ClearField('optional_foreign_message') 2576 child.c = 128 2577 self.assertEqual(0, self.proto.ByteSize()) 2578 2579 # Test within extension. 2580 extension = more_extensions_pb2.optional_message_extension 2581 child = self.extended_proto.Extensions[extension] 2582 self.assertEqual(0, self.extended_proto.ByteSize()) 2583 child.foreign_message_int = 1 2584 self.assertEqual(4, self.extended_proto.ByteSize()) 2585 child.foreign_message_int = 128 2586 self.assertEqual(5, self.extended_proto.ByteSize()) 2587 self.extended_proto.ClearExtension(extension) 2588 self.assertEqual(0, self.extended_proto.ByteSize()) 2589 2590 def testCacheInvalidationForRepeatedMessage(self): 2591 # Test non-extension. 
2592 child0 = self.proto.repeated_foreign_message.add() 2593 self.assertEqual(3, self.proto.ByteSize()) 2594 self.proto.repeated_foreign_message.add() 2595 self.assertEqual(6, self.proto.ByteSize()) 2596 child0.c = 1 2597 self.assertEqual(8, self.proto.ByteSize()) 2598 self.proto.ClearField('repeated_foreign_message') 2599 self.assertEqual(0, self.proto.ByteSize()) 2600 2601 # Test within extension. 2602 extension = more_extensions_pb2.repeated_message_extension 2603 child_list = self.extended_proto.Extensions[extension] 2604 child0 = child_list.add() 2605 self.assertEqual(2, self.extended_proto.ByteSize()) 2606 child_list.add() 2607 self.assertEqual(4, self.extended_proto.ByteSize()) 2608 child0.foreign_message_int = 1 2609 self.assertEqual(6, self.extended_proto.ByteSize()) 2610 child0.ClearField('foreign_message_int') 2611 self.assertEqual(4, self.extended_proto.ByteSize()) 2612 self.extended_proto.ClearExtension(extension) 2613 self.assertEqual(0, self.extended_proto.ByteSize()) 2614 2615 def testPackedRepeatedScalars(self): 2616 self.assertEqual(0, self.packed_proto.ByteSize()) 2617 2618 self.packed_proto.packed_int32.append(10) # 1 byte. 2619 self.packed_proto.packed_int32.append(128) # 2 bytes. 2620 # The tag is 2 bytes (the field number is 90), and the varint 2621 # storing the length is 1 byte. 2622 int_size = 1 + 2 + 3 2623 self.assertEqual(int_size, self.packed_proto.ByteSize()) 2624 2625 self.packed_proto.packed_double.append(4.2) # 8 bytes 2626 self.packed_proto.packed_double.append(3.25) # 8 bytes 2627 # 2 more tag bytes, 1 more length byte. 
2628 double_size = 8 + 8 + 3 2629 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 2630 2631 self.packed_proto.ClearField('packed_int32') 2632 self.assertEqual(double_size, self.packed_proto.ByteSize()) 2633 2634 def testPackedExtensions(self): 2635 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 2636 extension = self.packed_extended_proto.Extensions[ 2637 unittest_pb2.packed_fixed32_extension] 2638 extension.extend([1, 2, 3, 4]) # 16 bytes 2639 # Tag is 3 bytes. 2640 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 2641 2642 2643# Issues to be sure to cover include: 2644# * Handling of unrecognized tags ("uninterpreted_bytes"). 2645# * Handling of MessageSets. 2646# * Consistent ordering of tags in the wire format, 2647# including ordering between extensions and non-extension 2648# fields. 2649# * Consistent serialization of negative numbers, especially 2650# negative int32s. 2651# * Handling of empty submessages (with and without "has" 2652# bits set). 

@testing_refleaks.TestCase
class SerializationTest(unittest.TestCase):
  """Round-trip, wire-format and initialization tests for serialization."""

  def testSerializeEmptyMessage(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeWithOptionalGroup(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    first_proto.optionalgroup.a = 242
    serialized = first_proto.SerializeToString()
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual(first_proto, second_proto)

  def testSerializeNegativeValues(self):
    first_proto = unittest_pb2.TestAllTypes()

    first_proto.optional_int32 = -1
    first_proto.optional_int64 = -(2 << 40)
    first_proto.optional_sint32 = -3
    first_proto.optional_sint64 = -(4 << 40)
    first_proto.optional_sfixed32 = -5
    first_proto.optional_sfixed64 = -(6 << 40)

    second_proto = unittest_pb2.TestAllTypes.FromString(
        first_proto.SerializeToString())

    self.assertEqual(first_proto, second_proto)

  def testParseTruncated(self):
    # This test is only applicable for the Python implementation of the API.
    if api_implementation.Type() != 'python':
      return

    first_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = memoryview(first_proto.SerializeToString())

    # Try every prefix of the encoding, from empty up to the full message.
    for truncation_point in range(len(serialized) + 1):
      try:
        second_proto = unittest_pb2.TestAllTypes()
        unknown_fields = unittest_pb2.TestEmptyMessage()
        pos = second_proto._InternalParse(serialized, 0, truncation_point)
        # If we didn't raise an error then we read exactly the amount expected.
        self.assertEqual(truncation_point, pos)

        # Parsing to unknown fields should not throw if parsing to known fields
        # did not.
        try:
          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
          self.assertEqual(truncation_point, pos2)
        except message.DecodeError:
          self.fail('Parsing unknown fields failed when parsing known fields '
                    'did not.')
      except message.DecodeError:
        # Parsing unknown fields should also fail.
        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
                          serialized, 0, truncation_point)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers.  Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    bytes_parsed = second_proto.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_parsed)

    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    extension3 = message_set_extensions_pb2.message_set_extension3
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'
    proto.Extensions[extension3].text = 'bar'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # RawMessageSet does not use the MessageSet wire format, so it exposes
    # each item as a plain repeated entry.
    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    self.assertEqual(
        len(serialized),
        raw.MergeFromString(serialized))
    self.assertEqual(3, len(raw.item))

    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    self.assertEqual(
        len(raw.item[0].message),
        message1.MergeFromString(raw.item[0].message))
    self.assertEqual(123, message1.i)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
    self.assertEqual(
        len(raw.item[1].message),
        message2.MergeFromString(raw.item[1].message))
    self.assertEqual('foo', message2.str)

    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
    self.assertEqual(
        len(raw.item[2].message),
        message3.MergeFromString(raw.item[2].message))
    self.assertEqual('bar', message3.text)

    # Deserialize using the MessageSet wire format.
    proto2 = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)
    self.assertEqual('bar', proto2.Extensions[extension3].text)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an item for the known extension; its value is read back below.
    item = raw.item.add()
    item.type_id = 98418603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 98418604
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12346
    item.message = message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 98418605
    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
    message1.str = 'foo'
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    self.assertEqual(
        len(serialized),
        proto.MergeFromString(serialized))

    # Check that the message parsed well.
    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

    # Now test with a int64 field set.
    proto = unittest_pb2.TestAllTypes()
    proto.optional_int64 = 0x0fffffffffffffff
    serialized = proto.SerializeToString()
    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()
    # Parsing this message should succeed.
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))

  def _CheckRaises(self, exc_class, callable_obj, exception):
    """Checks that callable_obj() raises exc_class with an exact message.

    Args:
      exc_class: The exception type that callable_obj() must raise.
      callable_obj: A zero-argument callable to invoke.
      exception: The exact expected str() of the raised exception.

    Raises:
      self.failureException: if nothing is raised or the message differs.
    """
    try:
      callable_obj()
    except exc_class as ex:
      # Check if the exception message is the right one.
      self.assertEqual(exception, str(ex))
      return
    else:
      raise self.failureException('%s not raised' % str(exc_class))

  def testSerializeUninitialized(self):
    proto = unittest_pb2.TestRequired()
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: '
        'a,b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertFalse(proto2.HasField('a'))
    # proto2 ParseFromString does not check that required fields are set.
    proto2.ParseFromString(partial)
    self.assertFalse(proto2.HasField('a'))

    proto.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequired is missing required fields: c')
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto.c = 3
    serialized = proto.SerializeToString()
    # Shouldn't raise exceptions.
    partial = proto.SerializePartialToString()

    proto2 = unittest_pb2.TestRequired()
    self.assertEqual(
        len(serialized),
        proto2.MergeFromString(serialized))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)
    self.assertEqual(
        len(partial),
        proto2.MergeFromString(partial))
    self.assertEqual(1, proto2.a)
    self.assertEqual(2, proto2.b)
    self.assertEqual(3, proto2.c)

  def testSerializeUninitializedSubMessage(self):
    proto = unittest_pb2.TestRequiredForeign()

    # Sub-message doesn't exist yet, so this succeeds.
    proto.SerializeToString()

    proto.optional_message.a = 1
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign '
        'is missing required fields: '
        'optional_message.b,optional_message.c')

    proto.optional_message.b = 2
    proto.optional_message.c = 3
    proto.SerializeToString()

    proto.repeated_message.add().a = 1
    proto.repeated_message.add().b = 2
    self._CheckRaises(
        message.EncodeError,
        proto.SerializeToString,
        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
        'repeated_message[0].b,repeated_message[0].c,'
        'repeated_message[1].a,repeated_message[1].c')

    proto.repeated_message[0].b = 2
    proto.repeated_message[0].c = 3
    proto.repeated_message[1].a = 1
    proto.repeated_message[1].c = 3
    proto.SerializeToString()

  def testSerializeAllPackedFields(self):
    first_proto = unittest_pb2.TestPackedTypes()
    second_proto = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllPackedExtensions(self):
    first_proto = unittest_pb2.TestPackedExtensions()
    second_proto = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    bytes_read = second_proto.MergeFromString(serialized)
    self.assertEqual(second_proto.ByteSize(), bytes_read)
    self.assertEqual(first_proto, second_proto)

  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestPackedTypes()
    first_proto.packed_int32.extend([1, 2])
    first_proto.packed_double.append(3.0)
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestPackedTypes()
    second_proto.packed_int32.append(3)
    second_proto.packed_double.extend([1.0, 2.0])
    second_proto.packed_sint32.append(4)

    # Merging appends to existing repeated values and leaves untouched
    # fields (packed_sint32) as they were.
    self.assertEqual(
        len(serialized),
        second_proto.MergeFromString(serialized))
    self.assertEqual([3, 1, 2], second_proto.packed_int32)
    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
    self.assertEqual([4], second_proto.packed_sint32)

  def testPackedFieldsWireFormat(self):
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)  # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    self.assertTrue(d.EndOfStream())

  def testParsePackedFromUnpacked(self):
    unpacked = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(unpacked)
    packed = unittest_pb2.TestPackedTypes()
    serialized = unpacked.SerializeToString()
    self.assertEqual(
        len(serialized),
        packed.MergeFromString(serialized))
    expected = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(expected)
    self.assertEqual(expected, packed)

  def testParseUnpackedFromPacked(self):
    packed = unittest_pb2.TestPackedTypes()
    test_util.SetAllPackedFields(packed)
    unpacked = unittest_pb2.TestUnpackedTypes()
    serialized = packed.SerializeToString()
    self.assertEqual(
        len(serialized),
        unpacked.MergeFromString(serialized))
    expected = unittest_pb2.TestUnpackedTypes()
    test_util.SetAllUnpackedFields(expected)
    self.assertEqual(expected, unpacked)

  def testFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
    self.assertEqual(
        unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
    self.assertEqual(
        unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)

  def testExtensionFieldNumbers(self):
    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
    self.assertEqual(
        unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     21)
    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
    self.assertEqual(
        unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
                     51)

  def testFieldProperties(self):
    # Class-level field attributes expose the same descriptor objects that
    # DESCRIPTOR.fields_by_name returns.
    cls = unittest_pb2.TestAllTypes
    self.assertIs(cls.optional_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
                     cls.optional_int32.DESCRIPTOR.number)
    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
                     cls.optional_nested_message.DESCRIPTOR.number)
    self.assertIs(cls.repeated_int32.DESCRIPTOR,
                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
                     cls.repeated_int32.DESCRIPTOR.number)

  def testFieldDataDescriptor(self):
    msg = unittest_pb2.TestAllTypes()
    msg.optional_int32 = 42
    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
    self.assertEqual(msg.optional_int32, 25)
    with self.assertRaises(AttributeError):
      del msg.optional_int32
    try:
      unittest_pb2.ForeignMessage.c.__get__(msg)
    except TypeError:
      pass  # The cpp implementation cannot mix fields from other messages.
            # This test exercises a specific check that avoids a crash.
    else:
      pass  # The python implementation allows fields from other messages.
            # This is useless, but works.
3173 3174 def testInitKwargs(self): 3175 proto = unittest_pb2.TestAllTypes( 3176 optional_int32=1, 3177 optional_string='foo', 3178 optional_bool=True, 3179 optional_bytes=b'bar', 3180 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 3181 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 3182 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 3183 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 3184 repeated_int32=[1, 2, 3]) 3185 self.assertTrue(proto.IsInitialized()) 3186 self.assertTrue(proto.HasField('optional_int32')) 3187 self.assertTrue(proto.HasField('optional_string')) 3188 self.assertTrue(proto.HasField('optional_bool')) 3189 self.assertTrue(proto.HasField('optional_bytes')) 3190 self.assertTrue(proto.HasField('optional_nested_message')) 3191 self.assertTrue(proto.HasField('optional_foreign_message')) 3192 self.assertTrue(proto.HasField('optional_nested_enum')) 3193 self.assertTrue(proto.HasField('optional_foreign_enum')) 3194 self.assertEqual(1, proto.optional_int32) 3195 self.assertEqual('foo', proto.optional_string) 3196 self.assertEqual(True, proto.optional_bool) 3197 self.assertEqual(b'bar', proto.optional_bytes) 3198 self.assertEqual(1, proto.optional_nested_message.bb) 3199 self.assertEqual(1, proto.optional_foreign_message.c) 3200 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 3201 proto.optional_nested_enum) 3202 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 3203 self.assertEqual([1, 2, 3], proto.repeated_int32) 3204 3205 def testInitArgsUnknownFieldName(self): 3206 def InitializeEmptyMessageWithExtraKeywordArg(): 3207 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 3208 self._CheckRaises( 3209 ValueError, 3210 InitializeEmptyMessageWithExtraKeywordArg, 3211 'Protocol message TestEmptyMessage has no "unknown" field.') 3212 3213 def testInitRequiredKwargs(self): 3214 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 3215 self.assertTrue(proto.IsInitialized()) 3216 
self.assertTrue(proto.HasField('a')) 3217 self.assertTrue(proto.HasField('b')) 3218 self.assertTrue(proto.HasField('c')) 3219 self.assertFalse(proto.HasField('dummy2')) 3220 self.assertEqual(1, proto.a) 3221 self.assertEqual(1, proto.b) 3222 self.assertEqual(1, proto.c) 3223 3224 def testInitRequiredForeignKwargs(self): 3225 proto = unittest_pb2.TestRequiredForeign( 3226 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 3227 self.assertTrue(proto.IsInitialized()) 3228 self.assertTrue(proto.HasField('optional_message')) 3229 self.assertTrue(proto.optional_message.IsInitialized()) 3230 self.assertTrue(proto.optional_message.HasField('a')) 3231 self.assertTrue(proto.optional_message.HasField('b')) 3232 self.assertTrue(proto.optional_message.HasField('c')) 3233 self.assertFalse(proto.optional_message.HasField('dummy2')) 3234 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 3235 proto.optional_message) 3236 self.assertEqual(1, proto.optional_message.a) 3237 self.assertEqual(1, proto.optional_message.b) 3238 self.assertEqual(1, proto.optional_message.c) 3239 3240 def testInitRepeatedKwargs(self): 3241 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 3242 self.assertTrue(proto.IsInitialized()) 3243 self.assertEqual(1, proto.repeated_int32[0]) 3244 self.assertEqual(2, proto.repeated_int32[1]) 3245 self.assertEqual(3, proto.repeated_int32[2]) 3246 3247 3248@testing_refleaks.TestCase 3249class OptionsTest(unittest.TestCase): 3250 3251 def testMessageOptions(self): 3252 proto = message_set_extensions_pb2.TestMessageSet() 3253 self.assertEqual(True, 3254 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3255 proto = unittest_pb2.TestAllTypes() 3256 self.assertEqual(False, 3257 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3258 3259 def testPackedOptions(self): 3260 proto = unittest_pb2.TestAllTypes() 3261 proto.optional_int32 = 1 3262 proto.optional_double = 3.0 3263 for field_descriptor, _ in proto.ListFields(): 3264 
self.assertEqual(False, field_descriptor.is_packed) 3265 3266 proto = unittest_pb2.TestPackedTypes() 3267 proto.packed_int32.append(1) 3268 proto.packed_double.append(3.0) 3269 for field_descriptor, _ in proto.ListFields(): 3270 self.assertEqual(True, field_descriptor.is_packed) 3271 self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, 3272 field_descriptor.label) 3273 3274 3275@testing_refleaks.TestCase 3276class ClassAPITest(unittest.TestCase): 3277 3278 @unittest.skipIf( 3279 api_implementation.Type() != 'python', 3280 'C++ implementation requires a call to MakeDescriptor()') 3281 @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable') 3282 def testMakeClassWithNestedDescriptor(self): 3283 leaf_desc = descriptor.Descriptor( 3284 'leaf', 'package.parent.child.leaf', '', 3285 containing_type=None, fields=[], 3286 nested_types=[], enum_types=[], 3287 extensions=[], 3288 # pylint: disable=protected-access 3289 create_key=descriptor._internal_create_key) 3290 child_desc = descriptor.Descriptor( 3291 'child', 'package.parent.child', '', 3292 containing_type=None, fields=[], 3293 nested_types=[leaf_desc], enum_types=[], 3294 extensions=[], 3295 # pylint: disable=protected-access 3296 create_key=descriptor._internal_create_key) 3297 sibling_desc = descriptor.Descriptor( 3298 'sibling', 'package.parent.sibling', 3299 '', containing_type=None, fields=[], 3300 nested_types=[], enum_types=[], 3301 extensions=[], 3302 # pylint: disable=protected-access 3303 create_key=descriptor._internal_create_key) 3304 parent_desc = descriptor.Descriptor( 3305 'parent', 'package.parent', '', 3306 containing_type=None, fields=[], 3307 nested_types=[child_desc, sibling_desc], 3308 enum_types=[], extensions=[], 3309 # pylint: disable=protected-access 3310 create_key=descriptor._internal_create_key) 3311 message_factory.GetMessageClass(parent_desc) 3312 3313 def _GetSerializedFileDescriptor(self, name): 3314 """Get a serialized representation of a test 
FileDescriptorProto. 3315 3316 Args: 3317 name: All calls to this must use a unique message name, to avoid 3318 collisions in the cpp descriptor pool. 3319 Returns: 3320 A string containing the serialized form of a test FileDescriptorProto. 3321 """ 3322 file_descriptor_str = ( 3323 'message_type {' 3324 ' name: "' + name + '"' 3325 ' field {' 3326 ' name: "flat"' 3327 ' number: 1' 3328 ' label: LABEL_REPEATED' 3329 ' type: TYPE_UINT32' 3330 ' }' 3331 ' field {' 3332 ' name: "bar"' 3333 ' number: 2' 3334 ' label: LABEL_OPTIONAL' 3335 ' type: TYPE_MESSAGE' 3336 ' type_name: "Bar"' 3337 ' }' 3338 ' nested_type {' 3339 ' name: "Bar"' 3340 ' field {' 3341 ' name: "baz"' 3342 ' number: 3' 3343 ' label: LABEL_OPTIONAL' 3344 ' type: TYPE_MESSAGE' 3345 ' type_name: "Baz"' 3346 ' }' 3347 ' nested_type {' 3348 ' name: "Baz"' 3349 ' enum_type {' 3350 ' name: "deep_enum"' 3351 ' value {' 3352 ' name: "VALUE_A"' 3353 ' number: 0' 3354 ' }' 3355 ' }' 3356 ' field {' 3357 ' name: "deep"' 3358 ' number: 4' 3359 ' label: LABEL_OPTIONAL' 3360 ' type: TYPE_UINT32' 3361 ' }' 3362 ' }' 3363 ' }' 3364 '}') 3365 file_descriptor = descriptor_pb2.FileDescriptorProto() 3366 text_format.Merge(file_descriptor_str, file_descriptor) 3367 return file_descriptor.SerializeToString() 3368 3369 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3370 # This test can only run once; the second time, it raises errors about 3371 # conflicting message descriptors. 3372 def testParsingFlatClassWithExplicitClassDeclaration(self): 3373 """Test that the generated class can parse a flat message.""" 3374 # TODO: This test fails with cpp implementation in the call 3375 # of six.with_metaclass(). The other two callsites of with_metaclass 3376 # in this file are both excluded from cpp test, so it might be expected 3377 # to fail. Need someone more familiar with the python code to take a 3378 # look at this. 
3379 if api_implementation.Type() != 'python': 3380 return 3381 file_descriptor = descriptor_pb2.FileDescriptorProto() 3382 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 3383 msg_descriptor = descriptor.MakeDescriptor( 3384 file_descriptor.message_type[0]) 3385 3386 class MessageClass( 3387 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 3388 DESCRIPTOR = msg_descriptor 3389 msg = MessageClass() 3390 msg_str = ( 3391 'flat: 0 ' 3392 'flat: 1 ' 3393 'flat: 2 ') 3394 text_format.Merge(msg_str, msg) 3395 self.assertEqual(msg.flat, [0, 1, 2]) 3396 3397 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3398 def testParsingFlatClass(self): 3399 """Test that the generated class can parse a flat message.""" 3400 file_descriptor = descriptor_pb2.FileDescriptorProto() 3401 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 3402 msg_descriptor = descriptor.MakeDescriptor( 3403 file_descriptor.message_type[0]) 3404 msg_class = message_factory.GetMessageClass(msg_descriptor) 3405 msg = msg_class() 3406 msg_str = ( 3407 'flat: 0 ' 3408 'flat: 1 ' 3409 'flat: 2 ') 3410 text_format.Merge(msg_str, msg) 3411 self.assertEqual(msg.flat, [0, 1, 2]) 3412 3413 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3414 def testParsingNestedClass(self): 3415 """Test that the generated class can parse a nested message.""" 3416 file_descriptor = descriptor_pb2.FileDescriptorProto() 3417 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 3418 msg_descriptor = descriptor.MakeDescriptor( 3419 file_descriptor.message_type[0]) 3420 msg_class = message_factory.GetMessageClass(msg_descriptor) 3421 msg = msg_class() 3422 msg_str = ( 3423 'bar {' 3424 ' baz {' 3425 ' deep: 4' 3426 ' }' 3427 '}') 3428 text_format.Merge(msg_str, msg) 3429 self.assertEqual(msg.bar.baz.deep, 4) 3430 3431if __name__ == '__main__': 3432 unittest.main() 3433