1# -*- coding: utf-8 -*- 2# Protocol Buffers - Google's data interchange format 3# Copyright 2008 Google Inc. All rights reserved. 4# 5# Use of this source code is governed by a BSD-style 6# license that can be found in the LICENSE file or at 7# https://developers.google.com/open-source/licenses/bsd 8 9"""Tests python protocol buffers against the golden message. 10 11Note that the golden messages exercise every known field type, thus this 12test ends up exercising and verifying nearly all of the parsing and 13serialization code in the whole library. 14""" 15 16__author__ = 'gps@google.com (Gregory P. Smith)' 17 18import collections 19import copy 20import math 21import operator 22import pickle 23import pydoc 24import sys 25import types 26import unittest 27from unittest import mock 28import warnings 29 30cmp = lambda x, y: (x > y) - (x < y) 31 32from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top 33from google.protobuf.internal import encoder 34from google.protobuf.internal import enum_type_wrapper 35from google.protobuf.internal import more_extensions_pb2 36from google.protobuf.internal import more_messages_pb2 37from google.protobuf.internal import packed_field_test_pb2 38from google.protobuf.internal import self_recursive_pb2 39from google.protobuf.internal import test_proto3_optional_pb2 40from google.protobuf.internal import test_util 41from google.protobuf.internal import testing_refleaks 42from google.protobuf import descriptor 43from google.protobuf import message 44from google.protobuf.internal import _parameterized 45from google.protobuf import map_proto2_unittest_pb2 46from google.protobuf import map_unittest_pb2 47from google.protobuf import unittest_pb2 48from google.protobuf import unittest_proto3_arena_pb2 49 50UCS2_MAXUNICODE = 65535 51 52warnings.simplefilter('error', DeprecationWarning) 53 54@_parameterized.named_parameters(('_proto2', unittest_pb2), 55 ('_proto3', unittest_proto3_arena_pb2)) 56@testing_refleaks.TestCase 57class MessageTest(unittest.TestCase): 58 59 def testBadUtf8String(self, message_module): 60 if api_implementation.Type() != 'python': 61 self.skipTest('Skipping testBadUtf8String, currently only the python ' 62 'api implementation raises UnicodeDecodeError when a ' 63 'string field contains bad utf-8.') 64 bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') 65 with self.assertRaises(UnicodeDecodeError) as context: 66 message_module.TestAllTypes.FromString(bad_utf8_data) 67 self.assertIn('TestAllTypes.optional_string', str(context.exception)) 68 69 def testParseErrors(self, message_module): 70 msg = message_module.TestAllTypes() 71 self.assertRaises(TypeError, msg.FromString, 0) 72 self.assertRaises(Exception, msg.FromString, '0') 73 # TODO: Fix cpp extension to raise error instead of warning. 74 # b/27494216 75 end_tag = encoder.TagBytes(1, 4) 76 if (api_implementation.Type() == 'python' or 77 api_implementation.Type() == 'upb'): 78 with self.assertRaises(message.DecodeError) as context: 79 msg.FromString(end_tag) 80 if api_implementation.Type() == 'python': 81 # Only pure-Python has an error message this specific. 82 self.assertEqual('Unexpected end-group tag.', str(context.exception)) 83 84 # Field number 0 is illegal. 85 self.assertRaises(message.DecodeError, msg.FromString, b'\3\4') 86 87 def testDeterminismParameters(self, message_module): 88 # This message is always deterministically serialized, even if determinism 89 # is disabled, so we can use it to verify that all the determinism 90 # parameters work correctly. 91 golden_data = (b'\xe2\x02\nOne string' 92 b'\xe2\x02\nTwo string' 93 b'\xe2\x02\nRed string' 94 b'\xe2\x02\x0bBlue string') 95 golden_message = message_module.TestAllTypes() 96 golden_message.repeated_string.extend([ 97 'One string', 98 'Two string', 99 'Red string', 100 'Blue string', 101 ]) 102 self.assertEqual(golden_data, 103 golden_message.SerializeToString(deterministic=None)) 104 self.assertEqual(golden_data, 105 golden_message.SerializeToString(deterministic=False)) 106 self.assertEqual(golden_data, 107 golden_message.SerializeToString(deterministic=True)) 108 109 class BadArgError(Exception): 110 pass 111 112 class BadArg(object): 113 114 def __nonzero__(self): 115 raise BadArgError() 116 117 def __bool__(self): 118 raise BadArgError() 119 120 with self.assertRaises(BadArgError): 121 golden_message.SerializeToString(deterministic=BadArg()) 122 123 def testPickleSupport(self, message_module): 124 golden_message = message_module.TestAllTypes() 125 test_util.SetAllFields(golden_message) 126 golden_data = golden_message.SerializeToString() 127 golden_message = message_module.TestAllTypes() 128 golden_message.ParseFromString(golden_data) 129 pickled_message = pickle.dumps(golden_message) 130 131 unpickled_message = pickle.loads(pickled_message) 132 self.assertEqual(unpickled_message, golden_message) 133 134 def testPickleNestedMessage(self, message_module): 135 golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1) 136 pickled_message = pickle.dumps(golden_message) 137 unpickled_message = pickle.loads(pickled_message) 138 self.assertEqual(unpickled_message, golden_message) 139 140 def testPickleNestedNestedMessage(self, message_module): 141 cls = message_module.TestPickleNestedMessage.NestedMessage 142 golden_message = cls.NestedNestedMessage(cc=1) 143 pickled_message = pickle.dumps(golden_message) 144 unpickled_message = pickle.loads(pickled_message) 145 self.assertEqual(unpickled_message, golden_message) 146 147 def testPositiveInfinity(self, message_module): 148 if message_module is unittest_pb2: 149 golden_data = (b'\x5D\x00\x00\x80\x7F' 150 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 151 b'\xCD\x02\x00\x00\x80\x7F' 152 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') 153 else: 154 golden_data = (b'\x5D\x00\x00\x80\x7F' 155 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 156 b'\xCA\x02\x04\x00\x00\x80\x7F' 157 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 158 159 golden_message = message_module.TestAllTypes() 160 golden_message.ParseFromString(golden_data) 161 self.assertEqual(golden_message.optional_float, math.inf) 162 self.assertEqual(golden_message.optional_double, math.inf) 163 self.assertEqual(golden_message.repeated_float[0], math.inf) 164 self.assertEqual(golden_message.repeated_double[0], math.inf) 165 self.assertEqual(golden_data, golden_message.SerializeToString()) 166 167 def testNegativeInfinity(self, message_module): 168 if message_module is unittest_pb2: 169 golden_data = (b'\x5D\x00\x00\x80\xFF' 170 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 171 b'\xCD\x02\x00\x00\x80\xFF' 172 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') 173 else: 174 golden_data = (b'\x5D\x00\x00\x80\xFF' 175 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 176 b'\xCA\x02\x04\x00\x00\x80\xFF' 177 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 178 179 golden_message = message_module.TestAllTypes() 180 golden_message.ParseFromString(golden_data) 181 self.assertEqual(golden_message.optional_float, -math.inf) 182 self.assertEqual(golden_message.optional_double, -math.inf) 183 self.assertEqual(golden_message.repeated_float[0], -math.inf) 184 self.assertEqual(golden_message.repeated_double[0], -math.inf) 185 self.assertEqual(golden_data, golden_message.SerializeToString()) 186 187 def testNotANumber(self, message_module): 188 golden_data = (b'\x5D\x00\x00\xC0\x7F' 189 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' 190 b'\xCD\x02\x00\x00\xC0\x7F' 191 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') 192 golden_message = message_module.TestAllTypes() 193 golden_message.ParseFromString(golden_data) 194 self.assertTrue(math.isnan(golden_message.optional_float)) 195 self.assertTrue(math.isnan(golden_message.optional_double)) 196 self.assertTrue(math.isnan(golden_message.repeated_float[0])) 197 self.assertTrue(math.isnan(golden_message.repeated_double[0])) 198 199 # The protocol buffer may serialize to any one of multiple different 200 # representations of a NaN. Rather than verify a specific representation, 201 # verify the serialized string can be converted into a correctly 202 # behaving protocol buffer. 203 serialized = golden_message.SerializeToString() 204 message = message_module.TestAllTypes() 205 message.ParseFromString(serialized) 206 self.assertTrue(math.isnan(message.optional_float)) 207 self.assertTrue(math.isnan(message.optional_double)) 208 self.assertTrue(math.isnan(message.repeated_float[0])) 209 self.assertTrue(math.isnan(message.repeated_double[0])) 210 211 def testPositiveInfinityPacked(self, message_module): 212 golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' 213 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 214 golden_message = message_module.TestPackedTypes() 215 golden_message.ParseFromString(golden_data) 216 self.assertEqual(golden_message.packed_float[0], math.inf) 217 self.assertEqual(golden_message.packed_double[0], math.inf) 218 self.assertEqual(golden_data, golden_message.SerializeToString()) 219 220 def testNegativeInfinityPacked(self, message_module): 221 golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' 222 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 223 golden_message = message_module.TestPackedTypes() 224 golden_message.ParseFromString(golden_data) 225 self.assertEqual(golden_message.packed_float[0], -math.inf) 226 self.assertEqual(golden_message.packed_double[0], -math.inf) 227 self.assertEqual(golden_data, golden_message.SerializeToString()) 228 229 def testNotANumberPacked(self, message_module): 230 golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' 231 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') 232 golden_message = message_module.TestPackedTypes() 233 golden_message.ParseFromString(golden_data) 234 self.assertTrue(math.isnan(golden_message.packed_float[0])) 235 self.assertTrue(math.isnan(golden_message.packed_double[0])) 236 237 serialized = golden_message.SerializeToString() 238 message = message_module.TestPackedTypes() 239 message.ParseFromString(serialized) 240 self.assertTrue(math.isnan(message.packed_float[0])) 241 self.assertTrue(math.isnan(message.packed_double[0])) 242 243 def testExtremeFloatValues(self, message_module): 244 message = message_module.TestAllTypes() 245 246 # Most positive exponent, no significand bits set. 247 kMostPosExponentNoSigBits = math.pow(2, 127) 248 message.optional_float = kMostPosExponentNoSigBits 249 message.ParseFromString(message.SerializeToString()) 250 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) 251 252 # Most positive exponent, one significand bit set. 253 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) 254 message.optional_float = kMostPosExponentOneSigBit 255 message.ParseFromString(message.SerializeToString()) 256 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) 257 258 # Repeat last two cases with values of same magnitude, but negative. 259 message.optional_float = -kMostPosExponentNoSigBits 260 message.ParseFromString(message.SerializeToString()) 261 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) 262 263 message.optional_float = -kMostPosExponentOneSigBit 264 message.ParseFromString(message.SerializeToString()) 265 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) 266 267 # Most negative exponent, no significand bits set. 268 kMostNegExponentNoSigBits = math.pow(2, -127) 269 message.optional_float = kMostNegExponentNoSigBits 270 message.ParseFromString(message.SerializeToString()) 271 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) 272 273 # Most negative exponent, one significand bit set. 274 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) 275 message.optional_float = kMostNegExponentOneSigBit 276 message.ParseFromString(message.SerializeToString()) 277 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) 278 279 # Repeat last two cases with values of the same magnitude, but negative. 280 message.optional_float = -kMostNegExponentNoSigBits 281 message.ParseFromString(message.SerializeToString()) 282 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) 283 284 message.optional_float = -kMostNegExponentOneSigBit 285 message.ParseFromString(message.SerializeToString()) 286 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) 287 288 # Max 4 bytes float value 289 max_float = float.fromhex('0x1.fffffep+127') 290 message.optional_float = max_float 291 self.assertAlmostEqual(message.optional_float, max_float) 292 serialized_data = message.SerializeToString() 293 message.ParseFromString(serialized_data) 294 self.assertAlmostEqual(message.optional_float, max_float) 295 296 # Test set double to float field. 297 message.optional_float = 3.4028235e+39 298 self.assertEqual(message.optional_float, float('inf')) 299 serialized_data = message.SerializeToString() 300 message.ParseFromString(serialized_data) 301 self.assertEqual(message.optional_float, float('inf')) 302 303 message.optional_float = -3.4028235e+39 304 self.assertEqual(message.optional_float, float('-inf')) 305 306 message.optional_float = 1.4028235e-39 307 self.assertAlmostEqual(message.optional_float, 1.4028235e-39) 308 309 def testExtremeDoubleValues(self, message_module): 310 message = message_module.TestAllTypes() 311 312 # Most positive exponent, no significand bits set. 313 kMostPosExponentNoSigBits = math.pow(2, 1023) 314 message.optional_double = kMostPosExponentNoSigBits 315 message.ParseFromString(message.SerializeToString()) 316 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) 317 318 # Most positive exponent, one significand bit set. 319 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) 320 message.optional_double = kMostPosExponentOneSigBit 321 message.ParseFromString(message.SerializeToString()) 322 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) 323 324 # Repeat last two cases with values of same magnitude, but negative. 325 message.optional_double = -kMostPosExponentNoSigBits 326 message.ParseFromString(message.SerializeToString()) 327 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) 328 329 message.optional_double = -kMostPosExponentOneSigBit 330 message.ParseFromString(message.SerializeToString()) 331 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) 332 333 # Most negative exponent, no significand bits set. 334 kMostNegExponentNoSigBits = math.pow(2, -1023) 335 message.optional_double = kMostNegExponentNoSigBits 336 message.ParseFromString(message.SerializeToString()) 337 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) 338 339 # Most negative exponent, one significand bit set. 340 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) 341 message.optional_double = kMostNegExponentOneSigBit 342 message.ParseFromString(message.SerializeToString()) 343 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) 344 345 # Repeat last two cases with values of the same magnitude, but negative. 346 message.optional_double = -kMostNegExponentNoSigBits 347 message.ParseFromString(message.SerializeToString()) 348 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) 349 350 message.optional_double = -kMostNegExponentOneSigBit 351 message.ParseFromString(message.SerializeToString()) 352 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) 353 354 def testFloatPrinting(self, message_module): 355 message = message_module.TestAllTypes() 356 message.optional_float = 2.0 357 self.assertEqual(str(message), 'optional_float: 2.0\n') 358 359 def testFloatNanPrinting(self, message_module): 360 message = message_module.TestAllTypes() 361 message.optional_float = float('nan') 362 self.assertEqual(str(message), 'optional_float: nan\n') 363 364 def testHighPrecisionFloatPrinting(self, message_module): 365 msg = message_module.TestAllTypes() 366 msg.optional_float = 0.12345678912345678 367 old_float = msg.optional_float 368 msg.ParseFromString(msg.SerializeToString()) 369 self.assertEqual(old_float, msg.optional_float) 370 371 def testDoubleNanPrinting(self, message_module): 372 message = message_module.TestAllTypes() 373 message.optional_double = float('nan') 374 self.assertEqual(str(message), 'optional_double: nan\n') 375 376 def testHighPrecisionDoublePrinting(self, message_module): 377 msg = message_module.TestAllTypes() 378 msg.optional_double = 0.12345678912345678 379 self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n') 380 381 def testUnknownFieldPrinting(self, message_module): 382 populated = message_module.TestAllTypes() 383 test_util.SetAllNonLazyFields(populated) 384 empty = message_module.TestEmptyMessage() 385 empty.ParseFromString(populated.SerializeToString()) 386 self.assertEqual(str(empty), '') 387 388 def testCopyFromEmpty(self, message_module): 389 msg = message_module.NestedTestAllTypes() 390 test_msg = message_module.NestedTestAllTypes() 391 test_util.SetAllFields(test_msg.payload) 392 self.assertTrue(test_msg.HasField('payload')) 393 # Copy from empty message 394 test_msg.CopyFrom(msg) 395 self.assertEqual(0, len(test_msg.ListFields())) 396 397 test_util.SetAllFields(test_msg.payload) 398 self.assertTrue(test_msg.HasField('payload')) 399 # Copy from a non exist message 400 test_msg.CopyFrom(msg.child) 401 self.assertFalse(test_msg.HasField('payload')) 402 self.assertEqual(0, len(test_msg.ListFields())) 403 404 def testAppendRepeatedCompositeField(self, message_module): 405 msg = message_module.TestAllTypes() 406 msg.repeated_nested_message.append( 407 message_module.TestAllTypes.NestedMessage(bb=1)) 408 nested = message_module.TestAllTypes.NestedMessage(bb=2) 409 msg.repeated_nested_message.append(nested) 410 try: 411 msg.repeated_nested_message.append(1) 412 except TypeError: 413 pass 414 self.assertEqual(2, len(msg.repeated_nested_message)) 415 self.assertEqual([1, 2], [m.bb for m in msg.repeated_nested_message]) 416 417 def testInsertRepeatedCompositeField(self, message_module): 418 msg = message_module.TestAllTypes() 419 msg.repeated_nested_message.insert( 420 -1, message_module.TestAllTypes.NestedMessage(bb=1)) 421 sub_msg = msg.repeated_nested_message[0] 422 msg.repeated_nested_message.insert( 423 0, message_module.TestAllTypes.NestedMessage(bb=2)) 424 msg.repeated_nested_message.insert( 425 99, message_module.TestAllTypes.NestedMessage(bb=3)) 426 msg.repeated_nested_message.insert( 427 -2, message_module.TestAllTypes.NestedMessage(bb=-1)) 428 msg.repeated_nested_message.insert( 429 -1000, message_module.TestAllTypes.NestedMessage(bb=-1000)) 430 try: 431 msg.repeated_nested_message.insert(1, 999) 432 except TypeError: 433 pass 434 self.assertEqual(5, len(msg.repeated_nested_message)) 435 self.assertEqual([-1000, 2, -1, 1, 3], 436 [m.bb for m in msg.repeated_nested_message]) 437 self.assertEqual( 438 str(msg), 'repeated_nested_message {\n' 439 ' bb: -1000\n' 440 '}\n' 441 'repeated_nested_message {\n' 442 ' bb: 2\n' 443 '}\n' 444 'repeated_nested_message {\n' 445 ' bb: -1\n' 446 '}\n' 447 'repeated_nested_message {\n' 448 ' bb: 1\n' 449 '}\n' 450 'repeated_nested_message {\n' 451 ' bb: 3\n' 452 '}\n') 453 self.assertEqual(sub_msg.bb, 1) 454 455 def testAssignRepeatedField(self, message_module): 456 msg = message_module.NestedTestAllTypes() 457 msg.payload.repeated_int32[:] = [1, 2, 3, 4] 458 self.assertEqual(4, len(msg.payload.repeated_int32)) 459 self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32) 460 461 def testMergeFromRepeatedField(self, message_module): 462 msg = message_module.TestAllTypes() 463 msg.repeated_int32.append(1) 464 msg.repeated_int32.append(3) 465 msg.repeated_nested_message.add(bb=1) 466 msg.repeated_nested_message.add(bb=2) 467 other_msg = message_module.TestAllTypes() 468 other_msg.repeated_nested_message.add(bb=3) 469 other_msg.repeated_nested_message.add(bb=4) 470 other_msg.repeated_int32.append(5) 471 other_msg.repeated_int32.append(7) 472 473 msg.repeated_int32.MergeFrom(other_msg.repeated_int32) 474 self.assertEqual(4, len(msg.repeated_int32)) 475 476 msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message) 477 self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message]) 478 479 def testInternalMergeWithMissingRequiredField(self, message_module): 480 req = more_messages_pb2.RequiredField() 481 more_messages_pb2.RequiredWrapper(request=req) 482 483 def testMergeFromMissingRequiredField(self, message_module): 484 msg = more_messages_pb2.RequiredField() 485 message = more_messages_pb2.RequiredField() 486 message.MergeFrom(msg) 487 self.assertEqual(msg, message) 488 489 def testAddWrongRepeatedNestedField(self, message_module): 490 msg = message_module.TestAllTypes() 491 try: 492 msg.repeated_nested_message.add('wrong') 493 except TypeError: 494 pass 495 try: 496 msg.repeated_nested_message.add(value_field='wrong') 497 except ValueError: 498 pass 499 self.assertEqual(len(msg.repeated_nested_message), 0) 500 501 def testRepeatedContains(self, message_module): 502 msg = message_module.TestAllTypes() 503 msg.repeated_int32.extend([1, 2, 3]) 504 self.assertIn(2, msg.repeated_int32) 505 self.assertNotIn(0, msg.repeated_int32) 506 507 msg.repeated_nested_message.add(bb=1) 508 sub_msg1 = msg.repeated_nested_message[0] 509 sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2) 510 sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3) 511 msg.repeated_nested_message.append(sub_msg2) 512 msg.repeated_nested_message.insert(0, sub_msg3) 513 self.assertIn(sub_msg1, msg.repeated_nested_message) 514 self.assertIn(sub_msg2, msg.repeated_nested_message) 515 self.assertIn(sub_msg3, msg.repeated_nested_message) 516 517 def testRepeatedScalarIterable(self, message_module): 518 msg = message_module.TestAllTypes() 519 msg.repeated_int32.extend([1, 2, 3]) 520 add = 0 521 for item in msg.repeated_int32: 522 add += item 523 self.assertEqual(add, 6) 524 525 def testRepeatedNestedFieldIteration(self, message_module): 526 msg = message_module.TestAllTypes() 527 msg.repeated_nested_message.add(bb=1) 528 msg.repeated_nested_message.add(bb=2) 529 msg.repeated_nested_message.add(bb=3) 530 msg.repeated_nested_message.add(bb=4) 531 532 self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message]) 533 self.assertEqual([4, 3, 2, 1], 534 [m.bb for m in reversed(msg.repeated_nested_message)]) 535 self.assertEqual([4, 3, 2, 1], 536 [m.bb for m in msg.repeated_nested_message[::-1]]) 537 538 def testSortEmptyRepeated(self, message_module): 539 message = message_module.NestedTestAllTypes() 540 self.assertFalse(message.HasField('child')) 541 self.assertFalse(message.HasField('payload')) 542 message.child.repeated_child.sort() 543 message.payload.repeated_int32.sort() 544 self.assertFalse(message.HasField('child')) 545 self.assertFalse(message.HasField('payload')) 546 547 def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): 548 """Check some different types with the default comparator.""" 549 message = message_module.TestAllTypes() 550 551 # TODO: would testing more scalar types strengthen test? 552 message.repeated_int32.append(1) 553 message.repeated_int32.append(3) 554 message.repeated_int32.append(2) 555 message.repeated_int32.sort() 556 self.assertEqual(message.repeated_int32[0], 1) 557 self.assertEqual(message.repeated_int32[1], 2) 558 self.assertEqual(message.repeated_int32[2], 3) 559 self.assertEqual(str(message.repeated_int32), str([1, 2, 3])) 560 561 message.repeated_float.append(1.1) 562 message.repeated_float.append(1.3) 563 message.repeated_float.append(1.2) 564 message.repeated_float.sort() 565 self.assertAlmostEqual(message.repeated_float[0], 1.1) 566 self.assertAlmostEqual(message.repeated_float[1], 1.2) 567 self.assertAlmostEqual(message.repeated_float[2], 1.3) 568 569 message.repeated_string.append('a') 570 message.repeated_string.append('c') 571 message.repeated_string.append('b') 572 message.repeated_string.sort() 573 self.assertEqual(message.repeated_string[0], 'a') 574 self.assertEqual(message.repeated_string[1], 'b') 575 self.assertEqual(message.repeated_string[2], 'c') 576 self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c'])) 577 578 message.repeated_bytes.append(b'a') 579 message.repeated_bytes.append(b'c') 580 message.repeated_bytes.append(b'b') 581 message.repeated_bytes.sort() 582 self.assertEqual(message.repeated_bytes[0], b'a') 583 self.assertEqual(message.repeated_bytes[1], b'b') 584 self.assertEqual(message.repeated_bytes[2], b'c') 585 self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c'])) 586 587 def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): 588 """Check some different types with custom comparator.""" 589 message = message_module.TestAllTypes() 590 591 message.repeated_int32.append(-3) 592 message.repeated_int32.append(-2) 593 message.repeated_int32.append(-1) 594 message.repeated_int32.sort(key=abs) 595 self.assertEqual(message.repeated_int32[0], -1) 596 self.assertEqual(message.repeated_int32[1], -2) 597 self.assertEqual(message.repeated_int32[2], -3) 598 599 message.repeated_string.append('aaa') 600 message.repeated_string.append('bb') 601 message.repeated_string.append('c') 602 message.repeated_string.sort(key=len) 603 self.assertEqual(message.repeated_string[0], 'c') 604 self.assertEqual(message.repeated_string[1], 'bb') 605 self.assertEqual(message.repeated_string[2], 'aaa') 606 607 def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): 608 """Check passing a custom comparator to sort a repeated composite field.""" 609 message = message_module.TestAllTypes() 610 611 message.repeated_nested_message.add().bb = 1 612 message.repeated_nested_message.add().bb = 3 613 message.repeated_nested_message.add().bb = 2 614 message.repeated_nested_message.add().bb = 6 615 message.repeated_nested_message.add().bb = 5 616 message.repeated_nested_message.add().bb = 4 617 message.repeated_nested_message.sort(key=operator.attrgetter('bb')) 618 self.assertEqual(message.repeated_nested_message[0].bb, 1) 619 self.assertEqual(message.repeated_nested_message[1].bb, 2) 620 self.assertEqual(message.repeated_nested_message[2].bb, 3) 621 self.assertEqual(message.repeated_nested_message[3].bb, 4) 622 self.assertEqual(message.repeated_nested_message[4].bb, 5) 623 self.assertEqual(message.repeated_nested_message[5].bb, 6) 624 self.assertEqual( 625 str(message.repeated_nested_message), 626 '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]') 627 628 def testSortingRepeatedCompositeFieldsStable(self, message_module): 629 """Check passing a custom comparator to sort a repeated composite field.""" 630 message = message_module.TestAllTypes() 631 632 message.repeated_nested_message.add().bb = 21 633 message.repeated_nested_message.add().bb = 20 634 message.repeated_nested_message.add().bb = 13 635 message.repeated_nested_message.add().bb = 33 636 message.repeated_nested_message.add().bb = 11 637 message.repeated_nested_message.add().bb = 24 638 message.repeated_nested_message.add().bb = 10 639 message.repeated_nested_message.sort(key=lambda z: z.bb // 10) 640 self.assertEqual([13, 11, 10, 21, 20, 24, 33], 641 [n.bb for n in message.repeated_nested_message]) 642 643 # Make sure that for the C++ implementation, the underlying fields 644 # are actually reordered. 645 pb = message.SerializeToString() 646 message.Clear() 647 message.MergeFromString(pb) 648 self.assertEqual([13, 11, 10, 21, 20, 24, 33], 649 [n.bb for n in message.repeated_nested_message]) 650 651 def testRepeatedCompositeFieldSortArguments(self, message_module): 652 """Check sorting a repeated composite field using list.sort() arguments.""" 653 message = message_module.TestAllTypes() 654 655 get_bb = operator.attrgetter('bb') 656 message.repeated_nested_message.add().bb = 1 657 message.repeated_nested_message.add().bb = 3 658 message.repeated_nested_message.add().bb = 2 659 message.repeated_nested_message.add().bb = 6 660 message.repeated_nested_message.add().bb = 5 661 message.repeated_nested_message.add().bb = 4 662 message.repeated_nested_message.sort(key=get_bb) 663 self.assertEqual([k.bb for k in message.repeated_nested_message], 664 [1, 2, 3, 4, 5, 6]) 665 message.repeated_nested_message.sort(key=get_bb, reverse=True) 666 self.assertEqual([k.bb for k in message.repeated_nested_message], 667 [6, 5, 4, 3, 2, 1]) 668 669 def testRepeatedScalarFieldSortArguments(self, message_module): 670 """Check sorting a scalar field using list.sort() arguments.""" 671 message = message_module.TestAllTypes() 672 673 message.repeated_int32.append(-3) 674 message.repeated_int32.append(-2) 675 message.repeated_int32.append(-1) 676 message.repeated_int32.sort(key=abs) 677 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 678 message.repeated_int32.sort(key=abs, reverse=True) 679 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 680 681 message.repeated_string.append('aaa') 682 message.repeated_string.append('bb') 683 message.repeated_string.append('c') 684 message.repeated_string.sort(key=len) 685 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 686 message.repeated_string.sort(key=len, reverse=True) 687 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 688 689 def testRepeatedFieldsComparable(self, message_module): 690 m1 = message_module.TestAllTypes() 691 m2 = message_module.TestAllTypes() 692 m1.repeated_int32.append(0) 693 m1.repeated_int32.append(1) 694 m1.repeated_int32.append(2) 695 m2.repeated_int32.append(0) 696 m2.repeated_int32.append(1) 697 m2.repeated_int32.append(2) 698 m1.repeated_nested_message.add().bb = 1 699 m1.repeated_nested_message.add().bb = 2 700 m1.repeated_nested_message.add().bb = 3 701 m2.repeated_nested_message.add().bb = 1 702 m2.repeated_nested_message.add().bb = 2 703 m2.repeated_nested_message.add().bb = 3 704 705 def testRepeatedFieldsAreSequences(self, message_module): 706 m = message_module.TestAllTypes() 707 self.assertIsInstance(m.repeated_int32, collections.abc.MutableSequence) 708 self.assertIsInstance(m.repeated_nested_message, 709 collections.abc.MutableSequence) 710 711 def testRepeatedFieldsNotHashable(self, message_module): 712 m = message_module.TestAllTypes() 713 with self.assertRaises(TypeError): 714 hash(m.repeated_int32) 715 with self.assertRaises(TypeError): 716 hash(m.repeated_nested_message) 717 718 def testRepeatedFieldInsideNestedMessage(self, message_module): 719 m = message_module.NestedTestAllTypes() 720 m.payload.repeated_int32.extend([]) 721 self.assertTrue(m.HasField('payload')) 722 723 def testMergeFrom(self, message_module): 724 m1 = message_module.TestAllTypes() 725 m2 = message_module.TestAllTypes() 726 # Cpp extension will lazily create a sub message which is immutable. 727 nested = m1.optional_nested_message 728 self.assertEqual(0, nested.bb) 729 m2.optional_nested_message.bb = 1 730 # Make sure cmessage pointing to a mutable message after merge instead of 731 # the lazily created message. 732 m1.MergeFrom(m2) 733 self.assertEqual(1, nested.bb) 734 735 # Test more nested sub message. 736 msg1 = message_module.NestedTestAllTypes() 737 msg2 = message_module.NestedTestAllTypes() 738 nested = msg1.child.payload.optional_nested_message 739 self.assertEqual(0, nested.bb) 740 msg2.child.payload.optional_nested_message.bb = 1 741 msg1.MergeFrom(msg2) 742 self.assertEqual(1, nested.bb) 743 744 # Test repeated field. 745 self.assertEqual(msg1.payload.repeated_nested_message, 746 msg1.payload.repeated_nested_message) 747 nested = msg2.payload.repeated_nested_message.add() 748 nested.bb = 1 749 msg1.MergeFrom(msg2) 750 self.assertEqual(1, len(msg1.payload.repeated_nested_message)) 751 self.assertEqual(1, nested.bb) 752 753 def testMergeFromString(self, message_module): 754 m1 = message_module.TestAllTypes() 755 m2 = message_module.TestAllTypes() 756 # Cpp extension will lazily create a sub message which is immutable. 757 self.assertEqual(0, m1.optional_nested_message.bb) 758 m2.optional_nested_message.bb = 1 759 # Make sure cmessage pointing to a mutable message after merge instead of 760 # the lazily created message. 761 m1.MergeFromString(m2.SerializeToString()) 762 self.assertEqual(1, m1.optional_nested_message.bb) 763 764 def testMergeFromStringUsingMemoryView(self, message_module): 765 m2 = message_module.TestAllTypes() 766 m2.optional_string = 'scalar string' 767 m2.repeated_string.append('repeated string') 768 m2.optional_bytes = b'scalar bytes' 769 m2.repeated_bytes.append(b'repeated bytes') 770 771 serialized = m2.SerializeToString() 772 memview = memoryview(serialized) 773 m1 = message_module.TestAllTypes.FromString(memview) 774 775 self.assertEqual(m1.optional_bytes, b'scalar bytes') 776 self.assertEqual(m1.repeated_bytes, [b'repeated bytes']) 777 self.assertEqual(m1.optional_string, 'scalar string') 778 self.assertEqual(m1.repeated_string, ['repeated string']) 779 # Make sure that the memoryview was correctly converted to bytes, and 780 # that a sub-sliced memoryview is not being used. 781 self.assertIsInstance(m1.optional_bytes, bytes) 782 self.assertIsInstance(m1.repeated_bytes[0], bytes) 783 self.assertIsInstance(m1.optional_string, str) 784 self.assertIsInstance(m1.repeated_string[0], str) 785 786 def testMergeFromEmpty(self, message_module): 787 m1 = message_module.TestAllTypes() 788 # Cpp extension will lazily create a sub message which is immutable. 789 self.assertEqual(0, m1.optional_nested_message.bb) 790 self.assertFalse(m1.HasField('optional_nested_message')) 791 # Make sure the sub message is still immutable after merge from empty. 792 m1.MergeFromString(b'') # field state should not change 793 self.assertFalse(m1.HasField('optional_nested_message')) 794 795 def ensureNestedMessageExists(self, msg, attribute): 796 """Make sure that a nested message object exists. 797 798 As soon as a nested message attribute is accessed, it will be present in the 799 _fields dict, without being marked as actually being set. 800 """ 801 getattr(msg, attribute) 802 self.assertFalse(msg.HasField(attribute)) 803 804 def testOneofGetCaseNonexistingField(self, message_module): 805 m = message_module.TestAllTypes() 806 self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') 807 self.assertRaises(Exception, m.WhichOneof, 0) 808 809 def testOneofDefaultValues(self, message_module): 810 m = message_module.TestAllTypes() 811 self.assertIs(None, m.WhichOneof('oneof_field')) 812 self.assertFalse(m.HasField('oneof_field')) 813 self.assertFalse(m.HasField('oneof_uint32')) 814 815 # Oneof is set even when setting it to a default value. 816 m.oneof_uint32 = 0 817 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 818 self.assertTrue(m.HasField('oneof_field')) 819 self.assertTrue(m.HasField('oneof_uint32')) 820 self.assertFalse(m.HasField('oneof_string')) 821 822 m.oneof_string = '' 823 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 824 self.assertTrue(m.HasField('oneof_string')) 825 self.assertFalse(m.HasField('oneof_uint32')) 826 827 def testOneofSemantics(self, message_module): 828 m = message_module.TestAllTypes() 829 self.assertIs(None, m.WhichOneof('oneof_field')) 830 831 m.oneof_uint32 = 11 832 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 833 self.assertTrue(m.HasField('oneof_uint32')) 834 835 m.oneof_string = u'foo' 836 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 837 self.assertFalse(m.HasField('oneof_uint32')) 838 self.assertTrue(m.HasField('oneof_string')) 839 840 # Read nested message accessor without accessing submessage. 841 m.oneof_nested_message 842 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 843 self.assertTrue(m.HasField('oneof_string')) 844 self.assertFalse(m.HasField('oneof_nested_message')) 845 846 # Read accessor of nested message without accessing submessage. 847 m.oneof_nested_message.bb 848 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 849 self.assertTrue(m.HasField('oneof_string')) 850 self.assertFalse(m.HasField('oneof_nested_message')) 851 852 m.oneof_nested_message.bb = 11 853 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 854 self.assertFalse(m.HasField('oneof_string')) 855 self.assertTrue(m.HasField('oneof_nested_message')) 856 857 m.oneof_bytes = b'bb' 858 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 859 self.assertFalse(m.HasField('oneof_nested_message')) 860 self.assertTrue(m.HasField('oneof_bytes')) 861 862 def testOneofCompositeFieldReadAccess(self, message_module): 863 m = message_module.TestAllTypes() 864 m.oneof_uint32 = 11 865 866 self.ensureNestedMessageExists(m, 'oneof_nested_message') 867 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 868 self.assertEqual(11, m.oneof_uint32) 869 870 def testOneofWhichOneof(self, message_module): 871 m = message_module.TestAllTypes() 872 self.assertIs(None, m.WhichOneof('oneof_field')) 873 if message_module is unittest_pb2: 874 self.assertFalse(m.HasField('oneof_field')) 875 876 m.oneof_uint32 = 11 877 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 878 if message_module is unittest_pb2: 879 self.assertTrue(m.HasField('oneof_field')) 880 881 m.oneof_bytes = b'bb' 882 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 883 884 m.ClearField('oneof_bytes') 885 self.assertIs(None, m.WhichOneof('oneof_field')) 886 if message_module is unittest_pb2: 887 self.assertFalse(m.HasField('oneof_field')) 888 889 def testOneofClearField(self, message_module): 890 m = message_module.TestAllTypes() 891 m.ClearField('oneof_field') 892 m.oneof_uint32 = 11 893 m.ClearField('oneof_field') 894 if message_module is unittest_pb2: 895 self.assertFalse(m.HasField('oneof_field')) 896 self.assertFalse(m.HasField('oneof_uint32')) 897 self.assertIs(None, m.WhichOneof('oneof_field')) 898 899 def testOneofClearSetField(self, message_module): 900 m = message_module.TestAllTypes() 901 m.oneof_uint32 = 11 902 m.ClearField('oneof_uint32') 903 if message_module is unittest_pb2: 904 self.assertFalse(m.HasField('oneof_field')) 905 self.assertFalse(m.HasField('oneof_uint32')) 906 self.assertIs(None, m.WhichOneof('oneof_field')) 907 908 def testOneofClearUnsetField(self, message_module): 909 m = message_module.TestAllTypes() 910 m.oneof_uint32 = 11 911 self.ensureNestedMessageExists(m, 'oneof_nested_message') 912 m.ClearField('oneof_nested_message') 913 self.assertEqual(11, m.oneof_uint32) 914 if message_module is unittest_pb2: 915 self.assertTrue(m.HasField('oneof_field')) 916 self.assertTrue(m.HasField('oneof_uint32')) 917 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 918 919 def testOneofDeserialize(self, message_module): 920 m = message_module.TestAllTypes() 921 m.oneof_uint32 = 11 922 m2 = message_module.TestAllTypes() 923 m2.ParseFromString(m.SerializeToString()) 924 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 925 926 def testOneofCopyFrom(self, message_module): 927 m = message_module.TestAllTypes() 928 m.oneof_uint32 = 11 929 m2 = message_module.TestAllTypes() 930 m2.CopyFrom(m) 931 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 932 933 def testOneofNestedMergeFrom(self, message_module): 934 m = message_module.NestedTestAllTypes() 935 m.payload.oneof_uint32 = 11 936 m2 = message_module.NestedTestAllTypes() 937 m2.payload.oneof_bytes = b'bb' 938 m2.child.payload.oneof_bytes = b'bb' 939 m2.MergeFrom(m) 940 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) 941 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) 942 943 def testOneofMessageMergeFrom(self, message_module): 944 m = message_module.NestedTestAllTypes() 945 m.payload.oneof_nested_message.bb = 11 946 m.child.payload.oneof_nested_message.bb = 12 947 m2 = message_module.NestedTestAllTypes() 948 m2.payload.oneof_uint32 = 13 949 m2.MergeFrom(m) 950 self.assertEqual('oneof_nested_message', 951 m2.payload.WhichOneof('oneof_field')) 952 self.assertEqual('oneof_nested_message', 953 m2.child.payload.WhichOneof('oneof_field')) 954 955 def testOneofNestedMessageInit(self, message_module): 956 m = message_module.TestAllTypes( 957 oneof_nested_message=message_module.TestAllTypes.NestedMessage()) 958 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 959 960 def testOneofClear(self, message_module): 961 m = message_module.TestAllTypes() 962 m.oneof_uint32 = 11 963 m.Clear() 964 self.assertIsNone(m.WhichOneof('oneof_field')) 965 m.oneof_bytes = b'bb' 966 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 967 968 def testAssignByteStringToUnicodeField(self, message_module): 969 """Assigning a byte string to a string field should result 970 971 in the value being converted to a Unicode string. 972 """ 973 m = message_module.TestAllTypes() 974 m.optional_string = str('') 975 self.assertIsInstance(m.optional_string, str) 976 977 def testLongValuedSlice(self, message_module): 978 """It should be possible to use int-valued indices in slices. 979 980 This didn't used to work in the v2 C++ implementation. 981 """ 982 m = message_module.TestAllTypes() 983 984 # Repeated scalar 985 m.repeated_int32.append(1) 986 sl = m.repeated_int32[int(0):int(len(m.repeated_int32))] 987 self.assertEqual(len(m.repeated_int32), len(sl)) 988 989 # Repeated composite 990 m.repeated_nested_message.add().bb = 3 991 sl = m.repeated_nested_message[int(0):int(len(m.repeated_nested_message))] 992 self.assertEqual(len(m.repeated_nested_message), len(sl)) 993 994 def testExtendShouldNotSwallowExceptions(self, message_module): 995 """This didn't use to work in the v2 C++ implementation.""" 996 m = message_module.TestAllTypes() 997 with self.assertRaises(NameError) as _: 998 m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable 999 with self.assertRaises(NameError) as _: 1000 m.repeated_nested_enum.extend(a for i in range(10)) # pylint: disable=undefined-variable 1001 1002 FALSY_VALUES = [None, False, 0, 0.0] 1003 EMPTY_VALUES = [b'', u'', bytearray(), [], {}, set()] 1004 1005 def testExtendInt32WithNothing(self, message_module): 1006 """Test no-ops extending repeated int32 fields.""" 1007 m = message_module.TestAllTypes() 1008 self.assertSequenceEqual([], m.repeated_int32) 1009 1010 for falsy_value in MessageTest.FALSY_VALUES: 1011 with self.assertRaises(TypeError) as context: 1012 m.repeated_int32.extend(falsy_value) 1013 self.assertIn('iterable', str(context.exception)) 1014 self.assertSequenceEqual([], m.repeated_int32) 1015 1016 for empty_value in MessageTest.EMPTY_VALUES: 1017 m.repeated_int32.extend(empty_value) 1018 self.assertSequenceEqual([], m.repeated_int32) 1019 1020 def testExtendFloatWithNothing(self, message_module): 1021 """Test no-ops extending repeated float fields.""" 1022 m = message_module.TestAllTypes() 1023 self.assertSequenceEqual([], m.repeated_float) 1024 1025 for falsy_value in MessageTest.FALSY_VALUES: 1026 with self.assertRaises(TypeError) as context: 1027 m.repeated_float.extend(falsy_value) 1028 self.assertIn('iterable', str(context.exception)) 1029 self.assertSequenceEqual([], m.repeated_float) 1030 1031 for empty_value in MessageTest.EMPTY_VALUES: 1032 m.repeated_float.extend(empty_value) 1033 self.assertSequenceEqual([], m.repeated_float) 1034 1035 def testExtendStringWithNothing(self, message_module): 1036 """Test no-ops extending repeated string fields.""" 1037 m = message_module.TestAllTypes() 1038 self.assertSequenceEqual([], m.repeated_string) 1039 1040 for falsy_value in MessageTest.FALSY_VALUES: 1041 with self.assertRaises(TypeError) as context: 1042 m.repeated_string.extend(falsy_value) 1043 self.assertIn('iterable', str(context.exception)) 1044 self.assertSequenceEqual([], m.repeated_string) 1045 1046 for empty_value in MessageTest.EMPTY_VALUES: 1047 m.repeated_string.extend(empty_value) 1048 self.assertSequenceEqual([], m.repeated_string) 1049 1050 def testExtendInt32WithPythonList(self, message_module): 1051 """Test extending repeated int32 fields with python lists.""" 1052 m = message_module.TestAllTypes() 1053 self.assertSequenceEqual([], m.repeated_int32) 1054 m.repeated_int32.extend([0]) 1055 self.assertSequenceEqual([0], m.repeated_int32) 1056 m.repeated_int32.extend([1, 2]) 1057 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1058 m.repeated_int32.extend([3, 4]) 1059 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 1060 1061 def testExtendFloatWithPythonList(self, message_module): 1062 """Test extending repeated float fields with python lists.""" 1063 m = message_module.TestAllTypes() 1064 self.assertSequenceEqual([], m.repeated_float) 1065 m.repeated_float.extend([0.0]) 1066 self.assertSequenceEqual([0.0], m.repeated_float) 1067 m.repeated_float.extend([1.0, 2.0]) 1068 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1069 m.repeated_float.extend([3.0, 4.0]) 1070 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1071 1072 def testExtendStringWithPythonList(self, message_module): 1073 """Test extending repeated string fields with python lists.""" 1074 m = message_module.TestAllTypes() 1075 self.assertSequenceEqual([], m.repeated_string) 1076 m.repeated_string.extend(['']) 1077 self.assertSequenceEqual([''], m.repeated_string) 1078 m.repeated_string.extend(['11', '22']) 1079 self.assertSequenceEqual(['', '11', '22'], m.repeated_string) 1080 m.repeated_string.extend(['33', '44']) 1081 self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) 1082 1083 def testExtendStringWithString(self, message_module): 1084 """Test extending repeated string fields with characters from a string.""" 1085 m = message_module.TestAllTypes() 1086 self.assertSequenceEqual([], m.repeated_string) 1087 m.repeated_string.extend('abc') 1088 self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) 1089 1090 class TestIterable(object): 1091 """This iterable object mimics the behavior of numpy.array. 1092 1093 __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. 1094 1095 """ 1096 1097 def __init__(self, values=None): 1098 self._list = values or [] 1099 1100 def __nonzero__(self): 1101 size = len(self._list) 1102 if size == 0: 1103 return False 1104 if size == 1: 1105 return bool(self._list[0]) 1106 raise ValueError('Truth value is ambiguous.') 1107 1108 def __len__(self): 1109 return len(self._list) 1110 1111 def __iter__(self): 1112 return self._list.__iter__() 1113 1114 def testExtendInt32WithIterable(self, message_module): 1115 """Test extending repeated int32 fields with iterable.""" 1116 m = message_module.TestAllTypes() 1117 self.assertSequenceEqual([], m.repeated_int32) 1118 m.repeated_int32.extend(MessageTest.TestIterable([])) 1119 self.assertSequenceEqual([], m.repeated_int32) 1120 m.repeated_int32.extend(MessageTest.TestIterable([0])) 1121 self.assertSequenceEqual([0], m.repeated_int32) 1122 m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) 1123 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1124 m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) 1125 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 1126 1127 def testExtendFloatWithIterable(self, message_module): 1128 """Test extending repeated float fields with iterable.""" 1129 m = message_module.TestAllTypes() 1130 self.assertSequenceEqual([], m.repeated_float) 1131 m.repeated_float.extend(MessageTest.TestIterable([])) 1132 self.assertSequenceEqual([], m.repeated_float) 1133 m.repeated_float.extend(MessageTest.TestIterable([0.0])) 1134 self.assertSequenceEqual([0.0], m.repeated_float) 1135 m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) 1136 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1137 m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) 1138 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1139 1140 def testExtendStringWithIterable(self, message_module): 1141 """Test extending repeated string fields with iterable.""" 1142 m = message_module.TestAllTypes() 1143 self.assertSequenceEqual([], m.repeated_string) 1144 m.repeated_string.extend(MessageTest.TestIterable([])) 1145 self.assertSequenceEqual([], m.repeated_string) 1146 m.repeated_string.extend(MessageTest.TestIterable([''])) 1147 self.assertSequenceEqual([''], m.repeated_string) 1148 m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) 1149 self.assertSequenceEqual(['', '1', '2'], m.repeated_string) 1150 m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) 1151 self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) 1152 1153 class TestIndex(object): 1154 """This index object mimics the behavior of numpy.int64 and other types.""" 1155 1156 def __init__(self, value=None): 1157 self.value = value 1158 1159 def __index__(self): 1160 return self.value 1161 1162 def testRepeatedIndexingWithIntIndex(self, message_module): 1163 msg = message_module.TestAllTypes() 1164 msg.repeated_int32.extend([1, 2, 3]) 1165 self.assertEqual(1, msg.repeated_int32[MessageTest.TestIndex(0)]) 1166 1167 def testRepeatedIndexingWithNegative1IntIndex(self, message_module): 1168 msg = message_module.TestAllTypes() 1169 msg.repeated_int32.extend([1, 2, 3]) 1170 self.assertEqual(3, msg.repeated_int32[MessageTest.TestIndex(-1)]) 1171 1172 def testRepeatedIndexingWithNegative1Int(self, message_module): 1173 msg = message_module.TestAllTypes() 1174 msg.repeated_int32.extend([1, 2, 3]) 1175 self.assertEqual(3, msg.repeated_int32[-1]) 1176 1177 def testPickleRepeatedScalarContainer(self, message_module): 1178 # Pickle repeated scalar container is not supported. 1179 m = message_module.TestAllTypes() 1180 with self.assertRaises(pickle.PickleError) as _: 1181 pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) 1182 1183 def testSortEmptyRepeatedCompositeContainer(self, message_module): 1184 """Exercise a scenario that has led to segfaults in the past.""" 1185 m = message_module.TestAllTypes() 1186 m.repeated_nested_message.sort() 1187 1188 def testHasFieldOnRepeatedField(self, message_module): 1189 """Using HasField on a repeated field should raise an exception.""" 1190 m = message_module.TestAllTypes() 1191 with self.assertRaises(ValueError) as _: 1192 m.HasField('repeated_int32') 1193 1194 def testRepeatedScalarFieldPop(self, message_module): 1195 m = message_module.TestAllTypes() 1196 with self.assertRaises(IndexError) as _: 1197 m.repeated_int32.pop() 1198 m.repeated_int32.extend(range(5)) 1199 self.assertEqual(4, m.repeated_int32.pop()) 1200 self.assertEqual(0, m.repeated_int32.pop(0)) 1201 self.assertEqual(2, m.repeated_int32.pop(1)) 1202 self.assertEqual([1, 3], m.repeated_int32) 1203 1204 def testRepeatedCompositeFieldPop(self, message_module): 1205 m = message_module.TestAllTypes() 1206 with self.assertRaises(IndexError) as _: 1207 m.repeated_nested_message.pop() 1208 with self.assertRaises(TypeError) as _: 1209 m.repeated_nested_message.pop('0') 1210 for i in range(5): 1211 n = m.repeated_nested_message.add() 1212 n.bb = i 1213 self.assertEqual(4, m.repeated_nested_message.pop().bb) 1214 self.assertEqual(0, m.repeated_nested_message.pop(0).bb) 1215 self.assertEqual(2, m.repeated_nested_message.pop(1).bb) 1216 self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) 1217 1218 def testRepeatedCompareWithSelf(self, message_module): 1219 m = message_module.TestAllTypes() 1220 for i in range(5): 1221 m.repeated_int32.insert(i, i) 1222 n = m.repeated_nested_message.add() 1223 n.bb = i 1224 self.assertSequenceEqual(m.repeated_int32, m.repeated_int32) 1225 self.assertEqual(m.repeated_nested_message, m.repeated_nested_message) 1226 1227 def testReleasedNestedMessages(self, message_module): 1228 """A case that lead to a segfault when a message detached from its parent 1229 1230 container has itself a child container. 1231 """ 1232 m = message_module.NestedTestAllTypes() 1233 m = m.repeated_child.add() 1234 m = m.child 1235 m = m.repeated_child.add() 1236 self.assertEqual(m.payload.optional_int32, 0) 1237 1238 def testSetRepeatedComposite(self, message_module): 1239 m = message_module.TestAllTypes() 1240 with self.assertRaises(AttributeError): 1241 m.repeated_int32 = [] 1242 m.repeated_int32.append(1) 1243 with self.assertRaises(AttributeError): 1244 m.repeated_int32 = [] 1245 1246 def testReturningType(self, message_module): 1247 m = message_module.TestAllTypes() 1248 self.assertEqual(float, type(m.optional_float)) 1249 self.assertEqual(float, type(m.optional_double)) 1250 self.assertEqual(bool, type(m.optional_bool)) 1251 m.optional_float = 1 1252 m.optional_double = 1 1253 m.optional_bool = 1 1254 m.repeated_float.append(1) 1255 m.repeated_double.append(1) 1256 m.repeated_bool.append(1) 1257 m.ParseFromString(m.SerializeToString()) 1258 self.assertEqual(float, type(m.optional_float)) 1259 self.assertEqual(float, type(m.optional_double)) 1260 self.assertEqual('1.0', str(m.optional_double)) 1261 self.assertEqual(bool, type(m.optional_bool)) 1262 self.assertEqual(float, type(m.repeated_float[0])) 1263 self.assertEqual(float, type(m.repeated_double[0])) 1264 self.assertEqual(bool, type(m.repeated_bool[0])) 1265 self.assertEqual(True, m.repeated_bool[0]) 1266 1267 def testEquality(self, message_module): 1268 m = message_module.TestAllTypes() 1269 m2 = message_module.TestAllTypes() 1270 self.assertEqual(m, m) 1271 self.assertEqual(m, m2) 1272 self.assertEqual(m2, m) 1273 1274 different_m = message_module.TestAllTypes() 1275 different_m.repeated_float.append(1) 1276 self.assertNotEqual(m, different_m) 1277 self.assertNotEqual(different_m, m) 1278 1279 self.assertIsNotNone(m) 1280 self.assertIsNotNone(m) 1281 self.assertNotEqual(42, m) 1282 self.assertNotEqual(m, 42) 1283 self.assertNotEqual('foo', m) 1284 self.assertNotEqual(m, 'foo') 1285 1286 self.assertEqual(mock.ANY, m) 1287 self.assertEqual(m, mock.ANY) 1288 1289 class ComparesWithFoo(object): 1290 1291 def __eq__(self, other): 1292 if getattr(other, 'optional_string', 'not_foo') == 'foo': 1293 return True 1294 return NotImplemented 1295 1296 m.optional_string = 'foo' 1297 self.assertEqual(m, ComparesWithFoo()) 1298 self.assertEqual(ComparesWithFoo(), m) 1299 m.optional_string = 'bar' 1300 self.assertNotEqual(m, ComparesWithFoo()) 1301 self.assertNotEqual(ComparesWithFoo(), m) 1302 1303 def testTypeUnion(self, message_module): 1304 # Below python 3.10 you cannot create union types with the | operator, so we 1305 # skip testing for unions with old versions. 1306 if sys.version_info < (3, 10): 1307 return 1308 enum_type = enum_type_wrapper.EnumTypeWrapper( 1309 message_module.TestAllTypes.NestedEnum.DESCRIPTOR 1310 ) 1311 union_type = enum_type | int 1312 self.assertIsInstance(union_type, types.UnionType) 1313 1314 def get_union() -> union_type: 1315 return enum_type 1316 1317 union = get_union() 1318 self.assertIsInstance(union, enum_type_wrapper.EnumTypeWrapper) 1319 self.assertEqual( 1320 union.DESCRIPTOR, message_module.TestAllTypes.NestedEnum.DESCRIPTOR 1321 ) 1322 1323 def testIn(self, message_module): 1324 m = message_module.TestAllTypes() 1325 self.assertNotIn('optional_nested_message', m) 1326 self.assertNotIn('oneof_bytes', m) 1327 self.assertNotIn('oneof_string', m) 1328 with self.assertRaises(ValueError) as e: 1329 'repeated_int32' in m 1330 with self.assertRaises(ValueError) as e: 1331 'repeated_nested_message' in m 1332 with self.assertRaises(ValueError) as e: 1333 1 in m 1334 with self.assertRaises(ValueError) as e: 1335 'not_a_field' in m 1336 test_util.SetAllFields(m) 1337 self.assertIn('optional_nested_message', m) 1338 self.assertIn('oneof_bytes', m) 1339 self.assertNotIn('oneof_string', m) 1340 1341 1342@testing_refleaks.TestCase 1343class TestRecursiveGroup(unittest.TestCase): 1344 1345 def _MakeRecursiveGroupMessage(self, n): 1346 msg = self_recursive_pb2.SelfRecursive() 1347 sub = msg 1348 for _ in range(n): 1349 sub = sub.sub_group 1350 sub.i = 1 1351 return msg.SerializeToString() 1352 1353 def testRecursiveGroups(self): 1354 recurse_msg = self_recursive_pb2.SelfRecursive() 1355 data = self._MakeRecursiveGroupMessage(100) 1356 recurse_msg.ParseFromString(data) 1357 self.assertTrue(recurse_msg.HasField('sub_group')) 1358 1359 def testRecursiveGroupsException(self): 1360 if api_implementation.Type() != 'python': 1361 api_implementation._c_module.SetAllowOversizeProtos(False) 1362 recurse_msg = self_recursive_pb2.SelfRecursive() 1363 data = self._MakeRecursiveGroupMessage(300) 1364 with self.assertRaises(message.DecodeError) as context: 1365 recurse_msg.ParseFromString(data) 1366 self.assertIn('Error parsing message', str(context.exception)) 1367 if api_implementation.Type() == 'python': 1368 self.assertIn('too many levels of nesting', str(context.exception)) 1369 1370 def testRecursiveGroupsUnknownFields(self): 1371 if api_implementation.Type() != 'python': 1372 api_implementation._c_module.SetAllowOversizeProtos(False) 1373 test_msg = unittest_pb2.TestAllTypes() 1374 data = self._MakeRecursiveGroupMessage(300) # unknown to test_msg 1375 with self.assertRaises(message.DecodeError) as context: 1376 test_msg.ParseFromString(data) 1377 self.assertIn( 1378 'Error parsing message', 1379 str(context.exception), 1380 ) 1381 if api_implementation.Type() == 'python': 1382 self.assertIn('too many levels of nesting', str(context.exception)) 1383 decoder.SetRecursionLimit(310) 1384 test_msg.ParseFromString(data) 1385 decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT) 1386 1387 1388# Class to test proto2-only features (required, extensions, etc.) 1389@testing_refleaks.TestCase 1390class Proto2Test(unittest.TestCase): 1391 1392 def testFieldPresence(self): 1393 message = unittest_pb2.TestAllTypes() 1394 1395 self.assertFalse(message.HasField('optional_int32')) 1396 self.assertFalse(message.HasField('optional_bool')) 1397 self.assertFalse(message.HasField('optional_nested_message')) 1398 1399 with self.assertRaises(ValueError): 1400 message.HasField('field_doesnt_exist') 1401 1402 with self.assertRaises(ValueError): 1403 message.HasField('repeated_int32') 1404 with self.assertRaises(ValueError): 1405 message.HasField('repeated_nested_message') 1406 1407 self.assertEqual(0, message.optional_int32) 1408 self.assertEqual(False, message.optional_bool) 1409 self.assertEqual(0, message.optional_nested_message.bb) 1410 1411 # Fields are set even when setting the values to default values. 1412 message.optional_int32 = 0 1413 message.optional_bool = False 1414 message.optional_nested_message.bb = 0 1415 self.assertTrue(message.HasField('optional_int32')) 1416 self.assertTrue(message.HasField('optional_bool')) 1417 self.assertTrue(message.HasField('optional_nested_message')) 1418 self.assertIn('optional_int32', message) 1419 self.assertIn('optional_bool', message) 1420 self.assertIn('optional_nested_message', message) 1421 1422 # Set the fields to non-default values. 1423 message.optional_int32 = 5 1424 message.optional_bool = True 1425 message.optional_nested_message.bb = 15 1426 1427 self.assertTrue(message.HasField(u'optional_int32')) 1428 self.assertTrue(message.HasField('optional_bool')) 1429 self.assertTrue(message.HasField('optional_nested_message')) 1430 1431 # Clearing the fields unsets them and resets their value to default. 1432 message.ClearField('optional_int32') 1433 message.ClearField(u'optional_bool') 1434 message.ClearField('optional_nested_message') 1435 1436 self.assertFalse(message.HasField('optional_int32')) 1437 self.assertFalse(message.HasField('optional_bool')) 1438 self.assertFalse(message.HasField('optional_nested_message')) 1439 self.assertNotIn('optional_int32', message) 1440 self.assertNotIn('optional_bool', message) 1441 self.assertNotIn('optional_nested_message', message) 1442 self.assertEqual(0, message.optional_int32) 1443 self.assertEqual(False, message.optional_bool) 1444 self.assertEqual(0, message.optional_nested_message.bb) 1445 1446 def testDel(self): 1447 msg = unittest_pb2.TestAllTypes() 1448 1449 # Fields cannot be deleted. 1450 with self.assertRaises(AttributeError): 1451 del msg.optional_int32 1452 with self.assertRaises(AttributeError): 1453 del msg.optional_bool 1454 with self.assertRaises(AttributeError): 1455 del msg.repeated_nested_message 1456 1457 def testAssignInvalidEnum(self): 1458 """Assigning an invalid enum number is not allowed for closed enums.""" 1459 m = unittest_pb2.TestAllTypes() 1460 1461 # TODO Enable these once upb's behavior is made conformant. 1462 if api_implementation.Type() != 'upb': 1463 # Can not assign unknown enum to closed enums. 1464 with self.assertRaises(ValueError) as _: 1465 m.optional_nested_enum = 1234567 1466 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) 1467 # Assignment is a different code path than append for the C++ impl. 1468 m.repeated_nested_enum.append(2) 1469 m.repeated_nested_enum[0] = 2 1470 with self.assertRaises(ValueError): 1471 m.repeated_nested_enum[0] = 123456 1472 else: 1473 m.optional_nested_enum = 1234567 1474 m.repeated_nested_enum.append(1234567) 1475 m.repeated_nested_enum.append(2) 1476 m.repeated_nested_enum[0] = 2 1477 m.repeated_nested_enum[0] = 123456 1478 1479 # Unknown enum value can be parsed but is ignored. 1480 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1481 m2.optional_nested_enum = 1234567 1482 m2.repeated_nested_enum.append(7654321) 1483 serialized = m2.SerializeToString() 1484 1485 m3 = unittest_pb2.TestAllTypes() 1486 m3.ParseFromString(serialized) 1487 self.assertFalse(m3.HasField('optional_nested_enum')) 1488 # 1 is the default value for optional_nested_enum. 1489 self.assertEqual(1, m3.optional_nested_enum) 1490 self.assertEqual(0, len(m3.repeated_nested_enum)) 1491 m2.Clear() 1492 m2.ParseFromString(m3.SerializeToString()) 1493 self.assertEqual(1234567, m2.optional_nested_enum) 1494 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1495 1496 def testUnknownEnumMap(self): 1497 m = map_proto2_unittest_pb2.TestEnumMap() 1498 m.known_map_field[123] = 0 1499 with self.assertRaises(ValueError): 1500 m.unknown_map_field[1] = 123 1501 1502 def testDeepCopyClosedEnum(self): 1503 m = map_proto2_unittest_pb2.TestEnumMap() 1504 m.known_map_field[123] = 0 1505 m2 = copy.deepcopy(m) 1506 self.assertEqual(m, m2) 1507 1508 def testExtensionsErrors(self): 1509 msg = unittest_pb2.TestAllTypes() 1510 self.assertRaises(AttributeError, getattr, msg, 'Extensions') 1511 1512 def testMergeFromExtensions(self): 1513 msg1 = more_extensions_pb2.TopLevelMessage() 1514 msg2 = more_extensions_pb2.TopLevelMessage() 1515 # Cpp extension will lazily create a sub message which is immutable. 1516 self.assertEqual( 1517 0, 1518 msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension]) 1519 self.assertFalse(msg1.HasField('submessage')) 1520 msg2.submessage.Extensions[more_extensions_pb2.optional_int_extension] = 123 1521 # Make sure cmessage and extensions pointing to a mutable message 1522 # after merge instead of the lazily created message. 1523 msg1.MergeFrom(msg2) 1524 self.assertEqual( 1525 123, 1526 msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension]) 1527 1528 def testCopyFromAll(self): 1529 message = unittest_pb2.TestAllTypes() 1530 test_util.SetAllFields(message) 1531 copy = unittest_pb2.TestAllTypes() 1532 copy.CopyFrom(message) 1533 self.assertEqual(message, copy) 1534 message.repeated_nested_message.add().bb = 123 1535 self.assertNotEqual(message, copy) 1536 1537 def testCopyFromAllExtensions(self): 1538 all_set = unittest_pb2.TestAllExtensions() 1539 test_util.SetAllExtensions(all_set) 1540 copy = unittest_pb2.TestAllExtensions() 1541 copy.CopyFrom(all_set) 1542 self.assertEqual(all_set, copy) 1543 all_set.Extensions[unittest_pb2.repeatedgroup_extension].add().a = 321 1544 self.assertNotEqual(all_set, copy) 1545 1546 def testCopyFromAllPackedExtensions(self): 1547 all_set = unittest_pb2.TestPackedExtensions() 1548 test_util.SetAllPackedExtensions(all_set) 1549 copy = unittest_pb2.TestPackedExtensions() 1550 copy.CopyFrom(all_set) 1551 self.assertEqual(all_set, copy) 1552 all_set.Extensions[unittest_pb2.packed_float_extension].extend([61.0, 71.0]) 1553 self.assertNotEqual(all_set, copy) 1554 1555 def testPickleIncompleteProto(self): 1556 golden_message = unittest_pb2.TestRequired(a=1) 1557 pickled_message = pickle.dumps(golden_message) 1558 1559 unpickled_message = pickle.loads(pickled_message) 1560 self.assertEqual(unpickled_message, golden_message) 1561 self.assertEqual(unpickled_message.a, 1) 1562 # This is still an incomplete proto - so serializing should fail 1563 self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) 1564 1565 # TODO: this isn't really a proto2-specific test except that this 1566 # message has a required field in it. Should probably be factored out so 1567 # that we can test the other parts with proto3. 1568 def testParsingMerge(self): 1569 """Check the merge behavior when a required or optional field appears 1570 1571 multiple times in the input. 1572 """ 1573 messages = [ 1574 unittest_pb2.TestAllTypes(), 1575 unittest_pb2.TestAllTypes(), 1576 unittest_pb2.TestAllTypes() 1577 ] 1578 messages[0].optional_int32 = 1 1579 messages[1].optional_int64 = 2 1580 messages[2].optional_int32 = 3 1581 messages[2].optional_string = 'hello' 1582 1583 merged_message = unittest_pb2.TestAllTypes() 1584 merged_message.optional_int32 = 3 1585 merged_message.optional_int64 = 2 1586 merged_message.optional_string = 'hello' 1587 1588 generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() 1589 generator.field1.extend(messages) 1590 generator.field2.extend(messages) 1591 generator.field3.extend(messages) 1592 generator.ext1.extend(messages) 1593 generator.ext2.extend(messages) 1594 generator.group1.add().field1.MergeFrom(messages[0]) 1595 generator.group1.add().field1.MergeFrom(messages[1]) 1596 generator.group1.add().field1.MergeFrom(messages[2]) 1597 generator.group2.add().field1.MergeFrom(messages[0]) 1598 generator.group2.add().field1.MergeFrom(messages[1]) 1599 generator.group2.add().field1.MergeFrom(messages[2]) 1600 1601 data = generator.SerializeToString() 1602 parsing_merge = unittest_pb2.TestParsingMerge() 1603 parsing_merge.ParseFromString(data) 1604 1605 # Required and optional fields should be merged. 1606 self.assertEqual(parsing_merge.required_all_types, merged_message) 1607 self.assertEqual(parsing_merge.optional_all_types, merged_message) 1608 self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, 1609 merged_message) 1610 self.assertEqual( 1611 parsing_merge.Extensions[unittest_pb2.TestParsingMerge.optional_ext], 1612 merged_message) 1613 1614 # Repeated fields should not be merged. 1615 self.assertEqual(len(parsing_merge.repeated_all_types), 3) 1616 self.assertEqual(len(parsing_merge.repeatedgroup), 3) 1617 self.assertEqual( 1618 len(parsing_merge.Extensions[ 1619 unittest_pb2.TestParsingMerge.repeated_ext]), 3) 1620 1621 def testPythonicInit(self): 1622 message = unittest_pb2.TestAllTypes( 1623 optional_int32=100, 1624 optional_fixed32=200, 1625 optional_float=300.5, 1626 optional_bytes=b'x', 1627 optionalgroup={'a': 400}, 1628 optional_nested_message={'bb': 500}, 1629 optional_foreign_message={}, 1630 optional_nested_enum='BAZ', 1631 repeatedgroup=[{ 1632 'a': 600 1633 }, { 1634 'a': 700 1635 }], 1636 repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], 1637 default_int32=800, 1638 oneof_string='y') 1639 self.assertIsInstance(message, unittest_pb2.TestAllTypes) 1640 self.assertEqual(100, message.optional_int32) 1641 self.assertEqual(200, message.optional_fixed32) 1642 self.assertEqual(300.5, message.optional_float) 1643 self.assertEqual(b'x', message.optional_bytes) 1644 self.assertEqual(400, message.optionalgroup.a) 1645 self.assertIsInstance(message.optional_nested_message, 1646 unittest_pb2.TestAllTypes.NestedMessage) 1647 self.assertEqual(500, message.optional_nested_message.bb) 1648 self.assertTrue(message.HasField('optional_foreign_message')) 1649 self.assertEqual(message.optional_foreign_message, 1650 unittest_pb2.ForeignMessage()) 1651 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1652 message.optional_nested_enum) 1653 self.assertEqual(2, len(message.repeatedgroup)) 1654 self.assertEqual(600, message.repeatedgroup[0].a) 1655 self.assertEqual(700, message.repeatedgroup[1].a) 1656 self.assertEqual(2, len(message.repeated_nested_enum)) 1657 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 1658 message.repeated_nested_enum[0]) 1659 self.assertEqual(unittest_pb2.TestAllTypes.BAR, 1660 message.repeated_nested_enum[1]) 1661 self.assertEqual(800, message.default_int32) 1662 self.assertEqual('y', message.oneof_string) 1663 self.assertFalse(message.HasField('optional_int64')) 1664 self.assertEqual(0, len(message.repeated_float)) 1665 self.assertEqual(42, message.default_int64) 1666 1667 message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') 1668 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1669 message.optional_nested_enum) 1670 1671 with self.assertRaises(ValueError): 1672 unittest_pb2.TestAllTypes( 1673 optional_nested_message={'INVALID_NESTED_FIELD': 17}) 1674 1675 with self.assertRaises(TypeError): 1676 unittest_pb2.TestAllTypes( 1677 optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) 1678 1679 with self.assertRaises(ValueError): 1680 unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') 1681 1682 with self.assertRaises(ValueError): 1683 unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') 1684 1685 def testPythonicInitWithDict(self): 1686 # Both string/unicode field name keys should work. 1687 kwargs = { 1688 'optional_int32': 100, 1689 u'optional_fixed32': 200, 1690 } 1691 msg = unittest_pb2.TestAllTypes(**kwargs) 1692 self.assertEqual(100, msg.optional_int32) 1693 self.assertEqual(200, msg.optional_fixed32) 1694 1695 def test_documentation(self): 1696 # Also used by the interactive help() function. 1697 doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message') 1698 self.assertIn('class TestAllTypes', doc) 1699 self.assertIn('SerializePartialToString', doc) 1700 self.assertIn('repeated_float', doc) 1701 base = unittest_pb2.TestAllTypes.__bases__[0] 1702 self.assertRaises(AttributeError, getattr, base, '_extensions_by_name') 1703 1704 1705# Class to test proto3-only features/behavior (updated field presence & enums) 1706@testing_refleaks.TestCase 1707class Proto3Test(unittest.TestCase): 1708 1709 # Utility method for comparing equality with a map. 1710 def assertMapIterEquals(self, map_iter, dict_value): 1711 # Avoid mutating caller's copy. 1712 dict_value = dict(dict_value) 1713 1714 for k, v in map_iter: 1715 self.assertEqual(v, dict_value[k]) 1716 del dict_value[k] 1717 1718 self.assertEqual({}, dict_value) 1719 1720 def testFieldPresence(self): 1721 message = unittest_proto3_arena_pb2.TestAllTypes() 1722 1723 # We can't test presence of non-repeated, non-submessage fields. 1724 with self.assertRaises(ValueError): 1725 message.HasField('optional_int32') 1726 with self.assertRaises(ValueError): 1727 message.HasField('optional_float') 1728 with self.assertRaises(ValueError): 1729 message.HasField('optional_string') 1730 with self.assertRaises(ValueError): 1731 message.HasField('optional_bool') 1732 1733 # But we can still test presence of submessage fields. 1734 self.assertFalse(message.HasField('optional_nested_message')) 1735 1736 # As with proto2, we can't test presence of fields that don't exist, or 1737 # repeated fields. 1738 with self.assertRaises(ValueError): 1739 message.HasField('field_doesnt_exist') 1740 1741 with self.assertRaises(ValueError): 1742 message.HasField('repeated_int32') 1743 with self.assertRaises(ValueError): 1744 message.HasField('repeated_nested_message') 1745 1746 # Can not test "in" operator. 1747 with self.assertRaises(ValueError): 1748 'repeated_int32' in message 1749 with self.assertRaises(ValueError): 1750 'repeated_nested_message' in message 1751 1752 # Fields should default to their type-specific default. 1753 self.assertEqual(0, message.optional_int32) 1754 self.assertEqual(0, message.optional_float) 1755 self.assertEqual('', message.optional_string) 1756 self.assertEqual(False, message.optional_bool) 1757 self.assertEqual(0, message.optional_nested_message.bb) 1758 1759 # Setting a submessage should still return proper presence information. 1760 message.optional_nested_message.bb = 0 1761 self.assertTrue(message.HasField('optional_nested_message')) 1762 self.assertIn('optional_nested_message', message) 1763 1764 # Set the fields to non-default values. 1765 message.optional_int32 = 5 1766 message.optional_float = 1.1 1767 message.optional_string = 'abc' 1768 message.optional_bool = True 1769 message.optional_nested_message.bb = 15 1770 1771 # Clearing the fields unsets them and resets their value to default. 1772 message.ClearField('optional_int32') 1773 message.ClearField('optional_float') 1774 message.ClearField('optional_string') 1775 message.ClearField('optional_bool') 1776 message.ClearField('optional_nested_message') 1777 1778 self.assertEqual(0, message.optional_int32) 1779 self.assertEqual(0, message.optional_float) 1780 self.assertEqual('', message.optional_string) 1781 self.assertEqual(False, message.optional_bool) 1782 self.assertEqual(0, message.optional_nested_message.bb) 1783 1784 def testProto3ParserDropDefaultScalar(self): 1785 message_proto2 = unittest_pb2.TestAllTypes() 1786 message_proto2.optional_int32 = 0 1787 message_proto2.optional_string = '' 1788 message_proto2.optional_bytes = b'' 1789 self.assertEqual(len(message_proto2.ListFields()), 3) 1790 1791 message_proto3 = unittest_proto3_arena_pb2.TestAllTypes() 1792 message_proto3.ParseFromString(message_proto2.SerializeToString()) 1793 self.assertEqual(len(message_proto3.ListFields()), 0) 1794 1795 def testProto3Optional(self): 1796 msg = test_proto3_optional_pb2.TestProto3Optional() 1797 self.assertFalse(msg.HasField('optional_int32')) 1798 self.assertFalse(msg.HasField('optional_float')) 1799 self.assertFalse(msg.HasField('optional_string')) 1800 self.assertFalse(msg.HasField('optional_nested_message')) 1801 self.assertFalse(msg.optional_nested_message.HasField('bb')) 1802 1803 # Set fields. 1804 msg.optional_int32 = 1 1805 msg.optional_float = 1.0 1806 msg.optional_string = '123' 1807 msg.optional_nested_message.bb = 1 1808 self.assertTrue(msg.HasField('optional_int32')) 1809 self.assertTrue(msg.HasField('optional_float')) 1810 self.assertTrue(msg.HasField('optional_string')) 1811 self.assertTrue(msg.HasField('optional_nested_message')) 1812 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1813 # Set to default value does not clear the fields 1814 msg.optional_int32 = 0 1815 msg.optional_float = 0.0 1816 msg.optional_string = '' 1817 msg.optional_nested_message.bb = 0 1818 self.assertTrue(msg.HasField('optional_int32')) 1819 self.assertTrue(msg.HasField('optional_float')) 1820 self.assertTrue(msg.HasField('optional_string')) 1821 self.assertTrue(msg.HasField('optional_nested_message')) 1822 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1823 1824 # Test serialize 1825 msg2 = test_proto3_optional_pb2.TestProto3Optional() 1826 msg2.ParseFromString(msg.SerializeToString()) 1827 self.assertTrue(msg2.HasField('optional_int32')) 1828 self.assertTrue(msg2.HasField('optional_float')) 1829 self.assertTrue(msg2.HasField('optional_string')) 1830 self.assertTrue(msg2.HasField('optional_nested_message')) 1831 self.assertTrue(msg2.optional_nested_message.HasField('bb')) 1832 1833 self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32') 1834 1835 # Clear these fields. 1836 msg.ClearField('optional_int32') 1837 msg.ClearField('optional_float') 1838 msg.ClearField('optional_string') 1839 msg.ClearField('optional_nested_message') 1840 self.assertFalse(msg.HasField('optional_int32')) 1841 self.assertFalse(msg.HasField('optional_float')) 1842 self.assertFalse(msg.HasField('optional_string')) 1843 self.assertFalse(msg.HasField('optional_nested_message')) 1844 self.assertFalse(msg.optional_nested_message.HasField('bb')) 1845 1846 self.assertEqual(msg.WhichOneof('_optional_int32'), None) 1847 1848 # Test has presence: 1849 for field in test_proto3_optional_pb2.TestProto3Optional.DESCRIPTOR.fields: 1850 if field.name.startswith('optional_'): 1851 self.assertTrue(field.has_presence) 1852 for field in unittest_pb2.TestAllTypes.DESCRIPTOR.fields: 1853 if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1854 self.assertFalse(field.has_presence) 1855 else: 1856 self.assertTrue(field.has_presence) 1857 proto3_descriptor = unittest_proto3_arena_pb2.TestAllTypes.DESCRIPTOR 1858 repeated_field = proto3_descriptor.fields_by_name['repeated_int32'] 1859 self.assertFalse(repeated_field.has_presence) 1860 singular_field = proto3_descriptor.fields_by_name['optional_int32'] 1861 self.assertFalse(singular_field.has_presence) 1862 optional_field = proto3_descriptor.fields_by_name['proto3_optional_int32'] 1863 self.assertTrue(optional_field.has_presence) 1864 message_field = proto3_descriptor.fields_by_name['optional_nested_message'] 1865 self.assertTrue(message_field.has_presence) 1866 oneof_field = proto3_descriptor.fields_by_name['oneof_uint32'] 1867 self.assertTrue(oneof_field.has_presence) 1868 1869 def testAssignUnknownEnum(self): 1870 """Assigning an unknown enum value is allowed and preserves the value.""" 1871 m = unittest_proto3_arena_pb2.TestAllTypes() 1872 1873 # Proto3 can assign unknown enums. 1874 m.optional_nested_enum = 1234567 1875 self.assertEqual(1234567, m.optional_nested_enum) 1876 m.repeated_nested_enum.append(22334455) 1877 self.assertEqual(22334455, m.repeated_nested_enum[0]) 1878 # Assignment is a different code path than append for the C++ impl. 1879 m.repeated_nested_enum[0] = 7654321 1880 self.assertEqual(7654321, m.repeated_nested_enum[0]) 1881 serialized = m.SerializeToString() 1882 1883 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1884 m2.ParseFromString(serialized) 1885 self.assertEqual(1234567, m2.optional_nested_enum) 1886 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1887 1888 # Map isn't really a proto3-only feature. But there is no proto2 equivalent 1889 # of google/protobuf/map_unittest.proto right now, so it's not easy to 1890 # test both with the same test like we do for the other proto2/proto3 tests. 1891 # (google/protobuf/map_proto2_unittest.proto is very different in the set 1892 # of messages and fields it contains). 1893 def testScalarMapDefaults(self): 1894 msg = map_unittest_pb2.TestMap() 1895 1896 # Scalars start out unset. 1897 self.assertFalse(-123 in msg.map_int32_int32) 1898 self.assertFalse(-2**33 in msg.map_int64_int64) 1899 self.assertFalse(123 in msg.map_uint32_uint32) 1900 self.assertFalse(2**33 in msg.map_uint64_uint64) 1901 self.assertFalse(123 in msg.map_int32_double) 1902 self.assertFalse(False in msg.map_bool_bool) 1903 self.assertFalse('abc' in msg.map_string_string) 1904 self.assertFalse(111 in msg.map_int32_bytes) 1905 self.assertFalse(888 in msg.map_int32_enum) 1906 1907 # Accessing an unset key returns the default. 1908 self.assertEqual(0, msg.map_int32_int32[-123]) 1909 self.assertEqual(0, msg.map_int64_int64[-2**33]) 1910 self.assertEqual(0, msg.map_uint32_uint32[123]) 1911 self.assertEqual(0, msg.map_uint64_uint64[2**33]) 1912 self.assertEqual(0.0, msg.map_int32_double[123]) 1913 self.assertTrue(isinstance(msg.map_int32_double[123], float)) 1914 self.assertEqual(False, msg.map_bool_bool[False]) 1915 self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) 1916 self.assertEqual('', msg.map_string_string['abc']) 1917 self.assertEqual(b'', msg.map_int32_bytes[111]) 1918 self.assertEqual(0, msg.map_int32_enum[888]) 1919 1920 # It also sets the value in the map 1921 self.assertTrue(-123 in msg.map_int32_int32) 1922 self.assertTrue(-2**33 in msg.map_int64_int64) 1923 self.assertTrue(123 in msg.map_uint32_uint32) 1924 self.assertTrue(2**33 in msg.map_uint64_uint64) 1925 self.assertTrue(123 in msg.map_int32_double) 1926 self.assertTrue(False in msg.map_bool_bool) 1927 self.assertTrue('abc' in msg.map_string_string) 1928 self.assertTrue(111 in msg.map_int32_bytes) 1929 self.assertTrue(888 in msg.map_int32_enum) 1930 1931 self.assertIsInstance(msg.map_string_string['abc'], str) 1932 1933 # Accessing an unset key still throws TypeError if the type of the key 1934 # is incorrect. 1935 with self.assertRaises(TypeError): 1936 msg.map_string_string[123] 1937 1938 with self.assertRaises(TypeError): 1939 123 in msg.map_string_string 1940 1941 with self.assertRaises(TypeError): 1942 msg.map_string_string.__contains__(123) 1943 1944 def testScalarMapComparison(self): 1945 msg1 = map_unittest_pb2.TestMap() 1946 msg2 = map_unittest_pb2.TestMap() 1947 1948 self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32) 1949 1950 def testMessageMapComparison(self): 1951 msg1 = map_unittest_pb2.TestMap() 1952 msg2 = map_unittest_pb2.TestMap() 1953 1954 self.assertEqual(msg1.map_int32_foreign_message, 1955 msg2.map_int32_foreign_message) 1956 1957 def testMapGet(self): 1958 # Need to test that get() properly returns the default, even though the dict 1959 # has defaultdict-like semantics. 1960 msg = map_unittest_pb2.TestMap() 1961 1962 self.assertIsNone(msg.map_int32_int32.get(5)) 1963 self.assertEqual(10, msg.map_int32_int32.get(5, 10)) 1964 self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10)) 1965 self.assertIsNone(msg.map_int32_int32.get(5)) 1966 1967 msg.map_int32_int32[5] = 15 1968 self.assertEqual(15, msg.map_int32_int32.get(5)) 1969 self.assertEqual(15, msg.map_int32_int32.get(5)) 1970 with self.assertRaises(TypeError): 1971 msg.map_int32_int32.get('') 1972 1973 self.assertIsNone(msg.map_int32_foreign_message.get(5)) 1974 self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) 1975 self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10)) 1976 1977 submsg = msg.map_int32_foreign_message[5] 1978 self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) 1979 with self.assertRaises(TypeError): 1980 msg.map_int32_foreign_message.get('') 1981 1982 def testScalarMap(self): 1983 msg = map_unittest_pb2.TestMap() 1984 1985 self.assertEqual(0, len(msg.map_int32_int32)) 1986 self.assertFalse(5 in msg.map_int32_int32) 1987 1988 msg.map_int32_int32[-123] = -456 1989 msg.map_int64_int64[-2**33] = -2**34 1990 msg.map_uint32_uint32[123] = 456 1991 msg.map_uint64_uint64[2**33] = 2**34 1992 msg.map_int32_float[2] = 1.2 1993 msg.map_int32_double[1] = 3.3 1994 msg.map_string_string['abc'] = '123' 1995 msg.map_bool_bool[True] = True 1996 msg.map_int32_enum[888] = 2 1997 # Unknown numeric enum is supported in proto3. 1998 msg.map_int32_enum[123] = 456 1999 2000 self.assertEqual([], msg.FindInitializationErrors()) 2001 2002 self.assertEqual(1, len(msg.map_string_string)) 2003 2004 # Bad key. 2005 with self.assertRaises(TypeError): 2006 msg.map_string_string[123] = '123' 2007 2008 # Verify that trying to assign a bad key doesn't actually add a member to 2009 # the map. 2010 self.assertEqual(1, len(msg.map_string_string)) 2011 2012 # Bad value. 2013 with self.assertRaises(TypeError): 2014 msg.map_string_string['123'] = 123 2015 2016 serialized = msg.SerializeToString() 2017 msg2 = map_unittest_pb2.TestMap() 2018 msg2.ParseFromString(serialized) 2019 2020 # Bad key. 2021 with self.assertRaises(TypeError): 2022 msg2.map_string_string[123] = '123' 2023 2024 # Bad value. 2025 with self.assertRaises(TypeError): 2026 msg2.map_string_string['123'] = 123 2027 2028 self.assertEqual(-456, msg2.map_int32_int32[-123]) 2029 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 2030 self.assertEqual(456, msg2.map_uint32_uint32[123]) 2031 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 2032 self.assertAlmostEqual(1.2, msg.map_int32_float[2]) 2033 self.assertEqual(3.3, msg.map_int32_double[1]) 2034 self.assertEqual('123', msg2.map_string_string['abc']) 2035 self.assertEqual(True, msg2.map_bool_bool[True]) 2036 self.assertEqual(2, msg2.map_int32_enum[888]) 2037 self.assertEqual(456, msg2.map_int32_enum[123]) 2038 self.assertEqual('{-123: -456}', str(msg2.map_int32_int32)) 2039 2040 def testMapEntryAlwaysSerialized(self): 2041 msg = map_unittest_pb2.TestMap() 2042 msg.map_int32_int32[0] = 0 2043 msg.map_string_string[''] = '' 2044 self.assertEqual(msg.ByteSize(), 12) 2045 self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00', 2046 msg.SerializeToString()) 2047 2048 def testStringUnicodeConversionInMap(self): 2049 msg = map_unittest_pb2.TestMap() 2050 2051 unicode_obj = u'\u1234' 2052 bytes_obj = unicode_obj.encode('utf8') 2053 2054 msg.map_string_string[bytes_obj] = bytes_obj 2055 2056 (key, value) = list(msg.map_string_string.items())[0] 2057 2058 self.assertEqual(key, unicode_obj) 2059 self.assertEqual(value, unicode_obj) 2060 2061 self.assertIsInstance(key, str) 2062 self.assertIsInstance(value, str) 2063 2064 def testMessageMap(self): 2065 msg = map_unittest_pb2.TestMap() 2066 2067 self.assertEqual(0, len(msg.map_int32_foreign_message)) 2068 self.assertFalse(5 in msg.map_int32_foreign_message) 2069 2070 msg.map_int32_foreign_message[123] 2071 # get_or_create() is an alias for getitem. 2072 msg.map_int32_foreign_message.get_or_create(-456) 2073 2074 self.assertEqual(2, len(msg.map_int32_foreign_message)) 2075 self.assertIn(123, msg.map_int32_foreign_message) 2076 self.assertIn(-456, msg.map_int32_foreign_message) 2077 self.assertEqual(2, len(msg.map_int32_foreign_message)) 2078 2079 # Bad key. 2080 with self.assertRaises(TypeError): 2081 msg.map_int32_foreign_message['123'] 2082 2083 with self.assertRaises(TypeError): 2084 '123' in msg.map_int32_foreign_message 2085 2086 with self.assertRaises(TypeError): 2087 msg.map_int32_foreign_message.__contains__('123') 2088 2089 # Can't assign directly to submessage. 2090 with self.assertRaises(ValueError): 2091 msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] 2092 2093 # Verify that trying to assign a bad key doesn't actually add a member to 2094 # the map. 2095 self.assertEqual(2, len(msg.map_int32_foreign_message)) 2096 2097 serialized = msg.SerializeToString() 2098 msg2 = map_unittest_pb2.TestMap() 2099 msg2.ParseFromString(serialized) 2100 2101 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 2102 self.assertIn(123, msg2.map_int32_foreign_message) 2103 self.assertIn(-456, msg2.map_int32_foreign_message) 2104 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 2105 msg2.map_int32_foreign_message[123].c = 1 2106 # TODO: Fix text format for message map. 2107 self.assertIn( 2108 str(msg2.map_int32_foreign_message), 2109 ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) 2110 2111 def testNestedMessageMapItemDelete(self): 2112 msg = map_unittest_pb2.TestMap() 2113 msg.map_int32_all_types[1].optional_nested_message.bb = 1 2114 del msg.map_int32_all_types[1] 2115 msg.map_int32_all_types[2].optional_nested_message.bb = 2 2116 self.assertEqual(1, len(msg.map_int32_all_types)) 2117 msg.map_int32_all_types[1].optional_nested_message.bb = 1 2118 self.assertEqual(2, len(msg.map_int32_all_types)) 2119 2120 serialized = msg.SerializeToString() 2121 msg2 = map_unittest_pb2.TestMap() 2122 msg2.ParseFromString(serialized) 2123 keys = [1, 2] 2124 # The loop triggers PyErr_Occurred() in c extension. 2125 for key in keys: 2126 del msg2.map_int32_all_types[key] 2127 2128 def testMapByteSize(self): 2129 msg = map_unittest_pb2.TestMap() 2130 msg.map_int32_int32[1] = 1 2131 size = msg.ByteSize() 2132 msg.map_int32_int32[1] = 128 2133 self.assertEqual(msg.ByteSize(), size + 1) 2134 2135 msg.map_int32_foreign_message[19].c = 1 2136 size = msg.ByteSize() 2137 msg.map_int32_foreign_message[19].c = 128 2138 self.assertEqual(msg.ByteSize(), size + 1) 2139 2140 def testMergeFrom(self): 2141 msg = map_unittest_pb2.TestMap() 2142 msg.map_int32_int32[12] = 34 2143 msg.map_int32_int32[56] = 78 2144 msg.map_int64_int64[22] = 33 2145 msg.map_int32_foreign_message[111].c = 5 2146 msg.map_int32_foreign_message[222].c = 10 2147 2148 msg2 = map_unittest_pb2.TestMap() 2149 msg2.map_int32_int32[12] = 55 2150 msg2.map_int64_int64[88] = 99 2151 msg2.map_int32_foreign_message[222].c = 15 2152 msg2.map_int32_foreign_message[222].d = 20 2153 old_map_value = msg2.map_int32_foreign_message[222] 2154 2155 msg2.MergeFrom(msg) 2156 # Compare with expected message instead of call 2157 # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not 2158 # sync with repeated field and there is no duplicated keys. 2159 expected_msg = map_unittest_pb2.TestMap() 2160 expected_msg.CopyFrom(msg) 2161 expected_msg.map_int64_int64[88] = 99 2162 self.assertEqual(msg2, expected_msg) 2163 2164 self.assertEqual(34, msg2.map_int32_int32[12]) 2165 self.assertEqual(78, msg2.map_int32_int32[56]) 2166 self.assertEqual(33, msg2.map_int64_int64[22]) 2167 self.assertEqual(99, msg2.map_int64_int64[88]) 2168 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 2169 self.assertEqual(10, msg2.map_int32_foreign_message[222].c) 2170 self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) 2171 if api_implementation.Type() != 'cpp': 2172 # During the call to MergeFrom(), the C++ implementation will have 2173 # deallocated the underlying message, but this is very difficult to detect 2174 # properly. The line below is likely to cause a segmentation fault. 2175 # With the Python implementation, old_map_value is just 'detached' from 2176 # the main message. Using it will not crash of course, but since it still 2177 # have a reference to the parent message I'm sure we can find interesting 2178 # ways to cause inconsistencies. 2179 self.assertEqual(15, old_map_value.c) 2180 2181 # Verify that there is only one entry per key, even though the MergeFrom 2182 # may have internally created multiple entries for a single key in the 2183 # list representation. 2184 as_dict = {} 2185 for key in msg2.map_int32_foreign_message: 2186 self.assertFalse(key in as_dict) 2187 as_dict[key] = msg2.map_int32_foreign_message[key].c 2188 2189 self.assertEqual({111: 5, 222: 10}, as_dict) 2190 2191 # Special case: test that delete of item really removes the item, even if 2192 # there might have physically been duplicate keys due to the previous merge. 2193 # This is only a special case for the C++ implementation which stores the 2194 # map as an array. 2195 del msg2.map_int32_int32[12] 2196 self.assertFalse(12 in msg2.map_int32_int32) 2197 2198 del msg2.map_int32_foreign_message[222] 2199 self.assertFalse(222 in msg2.map_int32_foreign_message) 2200 with self.assertRaises(TypeError): 2201 del msg2.map_int32_foreign_message[''] 2202 2203 def testMapMergeFrom(self): 2204 msg = map_unittest_pb2.TestMap() 2205 msg.map_int32_int32[12] = 34 2206 msg.map_int32_int32[56] = 78 2207 msg.map_int64_int64[22] = 33 2208 msg.map_int32_foreign_message[111].c = 5 2209 msg.map_int32_foreign_message[222].c = 10 2210 2211 msg2 = map_unittest_pb2.TestMap() 2212 msg2.map_int32_int32[12] = 55 2213 msg2.map_int64_int64[88] = 99 2214 msg2.map_int32_foreign_message[222].c = 15 2215 msg2.map_int32_foreign_message[222].d = 20 2216 2217 msg2.map_int32_int32.MergeFrom(msg.map_int32_int32) 2218 self.assertEqual(34, msg2.map_int32_int32[12]) 2219 self.assertEqual(78, msg2.map_int32_int32[56]) 2220 2221 msg2.map_int64_int64.MergeFrom(msg.map_int64_int64) 2222 self.assertEqual(33, msg2.map_int64_int64[22]) 2223 self.assertEqual(99, msg2.map_int64_int64[88]) 2224 2225 msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message) 2226 # Compare with expected message instead of call 2227 # msg.map_int32_foreign_message[222] to make sure MergeFrom does not 2228 # sync with repeated field and no duplicated keys. 2229 expected_msg = map_unittest_pb2.TestMap() 2230 expected_msg.CopyFrom(msg) 2231 expected_msg.map_int64_int64[88] = 99 2232 self.assertEqual(msg2, expected_msg) 2233 2234 # Test when cpp extension cache a map. 2235 m1 = map_unittest_pb2.TestMap() 2236 m2 = map_unittest_pb2.TestMap() 2237 self.assertEqual(m1.map_int32_foreign_message, m1.map_int32_foreign_message) 2238 m2.map_int32_foreign_message[123].c = 10 2239 m1.MergeFrom(m2) 2240 self.assertEqual(10, m2.map_int32_foreign_message[123].c) 2241 2242 # Test merge maps within different message types. 2243 m1 = map_unittest_pb2.TestMap() 2244 m2 = map_unittest_pb2.TestMessageMap() 2245 m2.map_int32_message[123].optional_int32 = 10 2246 m1.map_int32_all_types.MergeFrom(m2.map_int32_message) 2247 self.assertEqual(10, m1.map_int32_all_types[123].optional_int32) 2248 2249 # Test overwrite message value map 2250 msg = map_unittest_pb2.TestMap() 2251 msg.map_int32_foreign_message[222].c = 123 2252 msg2 = map_unittest_pb2.TestMap() 2253 msg2.map_int32_foreign_message[222].d = 20 2254 msg.MergeFromString(msg2.SerializeToString()) 2255 self.assertEqual(msg.map_int32_foreign_message[222].d, 20) 2256 self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123) 2257 2258 # Merge a dict to map field is not accepted 2259 with self.assertRaises(AttributeError): 2260 m1.map_int32_all_types.MergeFrom( 2261 {1: unittest_proto3_arena_pb2.TestAllTypes()}) 2262 2263 def testMergeFromBadType(self): 2264 msg = map_unittest_pb2.TestMap() 2265 with self.assertRaisesRegex( 2266 TypeError, 2267 r'Parameter to MergeFrom\(\) must be instance of same class: expected ' 2268 r'.+TestMap got int\.'): 2269 msg.MergeFrom(1) 2270 2271 def testCopyFromBadType(self): 2272 msg = map_unittest_pb2.TestMap() 2273 with self.assertRaisesRegex( 2274 TypeError, 2275 r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' 2276 r'expected .+TestMap got int\.'): 2277 msg.CopyFrom(1) 2278 2279 def testIntegerMapWithLongs(self): 2280 msg = map_unittest_pb2.TestMap() 2281 msg.map_int32_int32[int(-123)] = int(-456) 2282 msg.map_int64_int64[int(-2**33)] = int(-2**34) 2283 msg.map_uint32_uint32[int(123)] = int(456) 2284 msg.map_uint64_uint64[int(2**33)] = int(2**34) 2285 2286 serialized = msg.SerializeToString() 2287 msg2 = map_unittest_pb2.TestMap() 2288 msg2.ParseFromString(serialized) 2289 2290 self.assertEqual(-456, msg2.map_int32_int32[-123]) 2291 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 2292 self.assertEqual(456, msg2.map_uint32_uint32[123]) 2293 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 2294 2295 def testMapAssignmentCausesPresence(self): 2296 msg = map_unittest_pb2.TestMapSubmessage() 2297 msg.test_map.map_int32_int32[123] = 456 2298 2299 serialized = msg.SerializeToString() 2300 msg2 = map_unittest_pb2.TestMapSubmessage() 2301 msg2.ParseFromString(serialized) 2302 2303 self.assertEqual(msg, msg2) 2304 2305 # Now test that various mutations of the map properly invalidate the 2306 # cached size of the submessage. 2307 msg.test_map.map_int32_int32[888] = 999 2308 serialized = msg.SerializeToString() 2309 msg2.ParseFromString(serialized) 2310 self.assertEqual(msg, msg2) 2311 2312 msg.test_map.map_int32_int32.clear() 2313 serialized = msg.SerializeToString() 2314 msg2.ParseFromString(serialized) 2315 self.assertEqual(msg, msg2) 2316 2317 def testMapAssignmentCausesPresenceForSubmessages(self): 2318 msg = map_unittest_pb2.TestMapSubmessage() 2319 msg.test_map.map_int32_foreign_message[123].c = 5 2320 2321 serialized = msg.SerializeToString() 2322 msg2 = map_unittest_pb2.TestMapSubmessage() 2323 msg2.ParseFromString(serialized) 2324 2325 self.assertEqual(msg, msg2) 2326 2327 # Now test that various mutations of the map properly invalidate the 2328 # cached size of the submessage. 2329 msg.test_map.map_int32_foreign_message[888].c = 7 2330 serialized = msg.SerializeToString() 2331 msg2.ParseFromString(serialized) 2332 self.assertEqual(msg, msg2) 2333 2334 msg.test_map.map_int32_foreign_message[888].MergeFrom( 2335 msg.test_map.map_int32_foreign_message[123]) 2336 serialized = msg.SerializeToString() 2337 msg2.ParseFromString(serialized) 2338 self.assertEqual(msg, msg2) 2339 2340 msg.test_map.map_int32_foreign_message.clear() 2341 serialized = msg.SerializeToString() 2342 msg2.ParseFromString(serialized) 2343 self.assertEqual(msg, msg2) 2344 2345 def testModifyMapWhileIterating(self): 2346 msg = map_unittest_pb2.TestMap() 2347 2348 string_string_iter = iter(msg.map_string_string) 2349 int32_foreign_iter = iter(msg.map_int32_foreign_message) 2350 2351 msg.map_string_string['abc'] = '123' 2352 msg.map_int32_foreign_message[5].c = 5 2353 2354 with self.assertRaises(RuntimeError): 2355 for key in string_string_iter: 2356 pass 2357 2358 with self.assertRaises(RuntimeError): 2359 for key in int32_foreign_iter: 2360 pass 2361 2362 def testModifyMapEntryWhileIterating(self): 2363 msg = map_unittest_pb2.TestMap() 2364 2365 msg.map_string_string['abc'] = '123' 2366 msg.map_string_string['def'] = '456' 2367 msg.map_string_string['ghi'] = '789' 2368 2369 msg.map_int32_foreign_message[5].c = 5 2370 msg.map_int32_foreign_message[6].c = 6 2371 msg.map_int32_foreign_message[7].c = 7 2372 2373 string_string_keys = list(msg.map_string_string.keys()) 2374 int32_foreign_keys = list(msg.map_int32_foreign_message.keys()) 2375 2376 keys = [] 2377 for key in msg.map_string_string: 2378 keys.append(key) 2379 msg.map_string_string[key] = '000' 2380 self.assertEqual(keys, string_string_keys) 2381 self.assertEqual(keys, list(msg.map_string_string.keys())) 2382 2383 keys = [] 2384 for key in msg.map_int32_foreign_message: 2385 keys.append(key) 2386 msg.map_int32_foreign_message[key].c = 0 2387 self.assertEqual(keys, int32_foreign_keys) 2388 self.assertEqual(keys, list(msg.map_int32_foreign_message.keys())) 2389 2390 def testSubmessageMap(self): 2391 msg = map_unittest_pb2.TestMap() 2392 2393 submsg = msg.map_int32_foreign_message[111] 2394 self.assertIs(submsg, msg.map_int32_foreign_message[111]) 2395 self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) 2396 2397 submsg.c = 5 2398 2399 serialized = msg.SerializeToString() 2400 msg2 = map_unittest_pb2.TestMap() 2401 msg2.ParseFromString(serialized) 2402 2403 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 2404 2405 # Doesn't allow direct submessage assignment. 2406 with self.assertRaises(ValueError): 2407 msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() 2408 2409 def testMapIteration(self): 2410 msg = map_unittest_pb2.TestMap() 2411 2412 for k, v in msg.map_int32_int32.items(): 2413 # Should not be reached. 2414 self.assertTrue(False) 2415 2416 msg.map_int32_int32[2] = 4 2417 msg.map_int32_int32[3] = 6 2418 msg.map_int32_int32[4] = 8 2419 self.assertEqual(3, len(msg.map_int32_int32)) 2420 2421 matching_dict = {2: 4, 3: 6, 4: 8} 2422 self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) 2423 2424 def testMapItems(self): 2425 # Map items used to have strange behaviors when use c extension. Because 2426 # [] may reorder the map and invalidate any existing iterators. 2427 # TODO: Check if [] reordering the map is a bug or intended 2428 # behavior. 2429 msg = map_unittest_pb2.TestMap() 2430 msg.map_string_string['local_init_op'] = '' 2431 msg.map_string_string['trainable_variables'] = '' 2432 msg.map_string_string['variables'] = '' 2433 msg.map_string_string['init_op'] = '' 2434 msg.map_string_string['summaries'] = '' 2435 items1 = msg.map_string_string.items() 2436 items2 = msg.map_string_string.items() 2437 self.assertEqual(items1, items2) 2438 2439 def testMapDeterministicSerialization(self): 2440 golden_data = (b'r\x0c\n\x07init_op\x12\x01d' 2441 b'r\n\n\x05item1\x12\x01e' 2442 b'r\n\n\x05item2\x12\x01f' 2443 b'r\n\n\x05item3\x12\x01g' 2444 b'r\x0b\n\x05item4\x12\x02QQ' 2445 b'r\x12\n\rlocal_init_op\x12\x01a' 2446 b'r\x0e\n\tsummaries\x12\x01e' 2447 b'r\x18\n\x13trainable_variables\x12\x01b' 2448 b'r\x0e\n\tvariables\x12\x01c') 2449 msg = map_unittest_pb2.TestMap() 2450 msg.map_string_string['local_init_op'] = 'a' 2451 msg.map_string_string['trainable_variables'] = 'b' 2452 msg.map_string_string['variables'] = 'c' 2453 msg.map_string_string['init_op'] = 'd' 2454 msg.map_string_string['summaries'] = 'e' 2455 msg.map_string_string['item1'] = 'e' 2456 msg.map_string_string['item2'] = 'f' 2457 msg.map_string_string['item3'] = 'g' 2458 msg.map_string_string['item4'] = 'QQ' 2459 2460 # If deterministic serialization is not working correctly, this will be 2461 # "flaky" depending on the exact python dict hash seed. 2462 # 2463 # Fortunately, there are enough items in this map that it is extremely 2464 # unlikely to ever hit the "right" in-order combination, so the test 2465 # itself should fail reliably. 2466 self.assertEqual(golden_data, msg.SerializeToString(deterministic=True)) 2467 2468 def testMapIterationClearMessage(self): 2469 # Iterator needs to work even if message and map are deleted. 2470 msg = map_unittest_pb2.TestMap() 2471 2472 msg.map_int32_int32[2] = 4 2473 msg.map_int32_int32[3] = 6 2474 msg.map_int32_int32[4] = 8 2475 2476 it = msg.map_int32_int32.items() 2477 del msg 2478 2479 matching_dict = {2: 4, 3: 6, 4: 8} 2480 self.assertMapIterEquals(it, matching_dict) 2481 2482 def testMapConstruction(self): 2483 msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) 2484 self.assertEqual(2, msg.map_int32_int32[1]) 2485 self.assertEqual(4, msg.map_int32_int32[3]) 2486 2487 msg = map_unittest_pb2.TestMap( 2488 map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) 2489 self.assertEqual(5, msg.map_int32_foreign_message[3].c) 2490 2491 def testMapScalarFieldConstruction(self): 2492 msg1 = map_unittest_pb2.TestMap() 2493 msg1.map_int32_int32[1] = 42 2494 msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32) 2495 self.assertEqual(42, msg2.map_int32_int32[1]) 2496 2497 def testMapMessageFieldConstruction(self): 2498 msg1 = map_unittest_pb2.TestMap() 2499 msg1.map_string_foreign_message['test'].c = 42 2500 msg2 = map_unittest_pb2.TestMap( 2501 map_string_foreign_message=msg1.map_string_foreign_message) 2502 self.assertEqual(42, msg2.map_string_foreign_message['test'].c) 2503 2504 def testMapFieldRaisesCorrectError(self): 2505 # Should raise a TypeError when given a non-iterable. 2506 with self.assertRaises(TypeError): 2507 map_unittest_pb2.TestMap(map_string_foreign_message=1) 2508 2509 def testMapValidAfterFieldCleared(self): 2510 # Map needs to work even if field is cleared. 2511 # For the C++ implementation this tests the correctness of 2512 # MapContainer::Release() 2513 msg = map_unittest_pb2.TestMap() 2514 int32_map = msg.map_int32_int32 2515 2516 int32_map[2] = 4 2517 int32_map[3] = 6 2518 int32_map[4] = 8 2519 2520 msg.ClearField('map_int32_int32') 2521 self.assertEqual(b'', msg.SerializeToString()) 2522 matching_dict = {2: 4, 3: 6, 4: 8} 2523 self.assertMapIterEquals(int32_map.items(), matching_dict) 2524 2525 def testMessageMapValidAfterFieldCleared(self): 2526 # Map needs to work even if field is cleared. 2527 # For the C++ implementation this tests the correctness of 2528 # MapContainer::Release() 2529 msg = map_unittest_pb2.TestMap() 2530 int32_foreign_message = msg.map_int32_foreign_message 2531 2532 int32_foreign_message[2].c = 5 2533 2534 msg.ClearField('map_int32_foreign_message') 2535 self.assertEqual(b'', msg.SerializeToString()) 2536 self.assertTrue(2 in int32_foreign_message.keys()) 2537 2538 def testMessageMapItemValidAfterTopMessageCleared(self): 2539 # Message map item needs to work even if it is cleared. 2540 # For the C++ implementation this tests the correctness of 2541 # MapContainer::Release() 2542 msg = map_unittest_pb2.TestMap() 2543 msg.map_int32_all_types[2].optional_string = 'bar' 2544 2545 if api_implementation.Type() == 'cpp': 2546 # Need to keep the map reference because of b/27942626. 2547 # TODO: Remove it. 2548 unused_map = msg.map_int32_all_types # pylint: disable=unused-variable 2549 msg_value = msg.map_int32_all_types[2] 2550 msg.Clear() 2551 2552 # Reset to trigger sync between repeated field and map in c++. 2553 msg.map_int32_all_types[3].optional_string = 'foo' 2554 self.assertEqual(msg_value.optional_string, 'bar') 2555 2556 def testMapIterInvalidatedByClearField(self): 2557 # Map iterator is invalidated when field is cleared. 2558 # But this case does need to not crash the interpreter. 2559 # For the C++ implementation this tests the correctness of 2560 # ScalarMapContainer::Release() 2561 msg = map_unittest_pb2.TestMap() 2562 2563 it = iter(msg.map_int32_int32) 2564 2565 msg.ClearField('map_int32_int32') 2566 with self.assertRaises(RuntimeError): 2567 for _ in it: 2568 pass 2569 2570 it = iter(msg.map_int32_foreign_message) 2571 msg.ClearField('map_int32_foreign_message') 2572 with self.assertRaises(RuntimeError): 2573 for _ in it: 2574 pass 2575 2576 def testMapDelete(self): 2577 msg = map_unittest_pb2.TestMap() 2578 2579 self.assertEqual(0, len(msg.map_int32_int32)) 2580 2581 msg.map_int32_int32[4] = 6 2582 self.assertEqual(1, len(msg.map_int32_int32)) 2583 2584 with self.assertRaises(KeyError): 2585 del msg.map_int32_int32[88] 2586 2587 del msg.map_int32_int32[4] 2588 self.assertEqual(0, len(msg.map_int32_int32)) 2589 2590 with self.assertRaises(KeyError): 2591 del msg.map_int32_all_types[32] 2592 2593 def testMapsAreMapping(self): 2594 msg = map_unittest_pb2.TestMap() 2595 self.assertIsInstance(msg.map_int32_int32, collections.abc.Mapping) 2596 self.assertIsInstance(msg.map_int32_int32, collections.abc.MutableMapping) 2597 self.assertIsInstance(msg.map_int32_foreign_message, 2598 collections.abc.Mapping) 2599 self.assertIsInstance(msg.map_int32_foreign_message, 2600 collections.abc.MutableMapping) 2601 2602 def testMapsCompare(self): 2603 msg = map_unittest_pb2.TestMap() 2604 msg.map_int32_int32[-123] = -456 2605 self.assertEqual(msg.map_int32_int32, msg.map_int32_int32) 2606 self.assertEqual(msg.map_int32_foreign_message, 2607 msg.map_int32_foreign_message) 2608 self.assertNotEqual(msg.map_int32_int32, 0) 2609 2610 def testMapFindInitializationErrorsSmokeTest(self): 2611 msg = map_unittest_pb2.TestMap() 2612 msg.map_string_string['abc'] = '123' 2613 msg.map_int32_int32[35] = 64 2614 msg.map_string_foreign_message['foo'].c = 5 2615 self.assertEqual(0, len(msg.FindInitializationErrors())) 2616 2617 @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2') 2618 def testStrictUtf8Check(self): 2619 # Test u'\ud801' is rejected at parser in both python2 and python3. 2620 serialized = (b'r\x03\xed\xa0\x81') 2621 msg = unittest_proto3_arena_pb2.TestAllTypes() 2622 with self.assertRaises(Exception) as context: 2623 msg.MergeFromString(serialized) 2624 if api_implementation.Type() == 'python': 2625 self.assertIn('optional_string', str(context.exception)) 2626 else: 2627 self.assertIn('Error parsing message', str(context.exception)) 2628 2629 # Test optional_string=u'' is accepted. 2630 serialized = unittest_proto3_arena_pb2.TestAllTypes( 2631 optional_string=u'').SerializeToString() 2632 msg2 = unittest_proto3_arena_pb2.TestAllTypes() 2633 msg2.MergeFromString(serialized) 2634 self.assertEqual(msg2.optional_string, u'') 2635 2636 msg = unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud001') 2637 self.assertEqual(msg.optional_string, u'\ud001') 2638 2639 def testSurrogatesInPython3(self): 2640 # Surrogates are rejected at setters in Python3. 2641 with self.assertRaises(ValueError): 2642 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\udc01') 2643 with self.assertRaises(ValueError): 2644 unittest_proto3_arena_pb2.TestAllTypes(optional_string=b'\xed\xa0\x81') 2645 with self.assertRaises(ValueError): 2646 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801') 2647 with self.assertRaises(ValueError): 2648 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801') 2649 2650 def testCrashNullAA(self): 2651 self.assertEqual( 2652 unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(), 2653 unittest_proto3_arena_pb2.TestAllTypes.NestedMessage()) 2654 2655 def testCrashNullAB(self): 2656 self.assertEqual( 2657 unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(), 2658 unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message) 2659 2660 def testCrashNullBA(self): 2661 self.assertEqual( 2662 unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message, 2663 unittest_proto3_arena_pb2.TestAllTypes.NestedMessage()) 2664 2665 def testCrashNullBB(self): 2666 self.assertEqual( 2667 unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message, 2668 unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message) 2669 2670 2671@testing_refleaks.TestCase 2672class ValidTypeNamesTest(unittest.TestCase): 2673 2674 def assertImportFromName(self, msg, base_name): 2675 # Parse <type 'module.class_name'> to extra 'some.name' as a string. 2676 tp_name = str(type(msg)).split("'")[1] 2677 valid_names = ('Repeated%sContainer' % base_name, 2678 'Repeated%sFieldContainer' % base_name) 2679 self.assertTrue( 2680 any(tp_name.endswith(v) for v in valid_names), 2681 '%r does end with any of %r' % (tp_name, valid_names)) 2682 2683 parts = tp_name.split('.') 2684 class_name = parts[-1] 2685 module_name = '.'.join(parts[:-1]) 2686 __import__(module_name, fromlist=[class_name]) 2687 2688 def testTypeNamesCanBeImported(self): 2689 # If import doesn't work, pickling won't work either. 2690 pb = unittest_pb2.TestAllTypes() 2691 self.assertImportFromName(pb.repeated_int32, 'Scalar') 2692 self.assertImportFromName(pb.repeated_nested_message, 'Composite') 2693 2694 2695# We can only test this case under proto2, because proto3 will reject invalid 2696# UTF-8 in the parser, so there should be no way of creating a string field 2697# that contains invalid UTF-8. 2698# 2699# We also can't test it in pure-Python, which validates all string fields for 2700# UTF-8 even when the spec says it shouldn't. 2701@unittest.skipIf(api_implementation.Type() == 'python', 2702 'Python can\'t create invalid UTF-8 strings') 2703@testing_refleaks.TestCase 2704class InvalidUtf8Test(unittest.TestCase): 2705 2706 def testInvalidUtf8Printing(self): 2707 one_bytes = unittest_pb2.OneBytes() 2708 one_bytes.data = b'ABC\xff123' 2709 one_string = unittest_pb2.OneString() 2710 one_string.ParseFromString(one_bytes.SerializeToString()) 2711 self.assertIn('data: "ABC\\377123"', str(one_string)) 2712 2713 def testValidUtf8Printing(self): 2714 self.assertIn('data: "€"', str(unittest_pb2.OneString(data='€'))) # 2 byte 2715 self.assertIn('data: "£"', str(unittest_pb2.OneString(data='£'))) # 3 byte 2716 self.assertIn('data: ""', str(unittest_pb2.OneString(data=''))) # 4 byte 2717 2718 2719@testing_refleaks.TestCase 2720class PackedFieldTest(unittest.TestCase): 2721 2722 def setMessage(self, message): 2723 message.repeated_int32.append(1) 2724 message.repeated_int64.append(1) 2725 message.repeated_uint32.append(1) 2726 message.repeated_uint64.append(1) 2727 message.repeated_sint32.append(1) 2728 message.repeated_sint64.append(1) 2729 message.repeated_fixed32.append(1) 2730 message.repeated_fixed64.append(1) 2731 message.repeated_sfixed32.append(1) 2732 message.repeated_sfixed64.append(1) 2733 message.repeated_float.append(1.0) 2734 message.repeated_double.append(1.0) 2735 message.repeated_bool.append(True) 2736 message.repeated_nested_enum.append(1) 2737 2738 def testPackedFields(self): 2739 message = packed_field_test_pb2.TestPackedTypes() 2740 self.setMessage(message) 2741 golden_data = (b'\x0A\x01\x01' 2742 b'\x12\x01\x01' 2743 b'\x1A\x01\x01' 2744 b'\x22\x01\x01' 2745 b'\x2A\x01\x02' 2746 b'\x32\x01\x02' 2747 b'\x3A\x04\x01\x00\x00\x00' 2748 b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2749 b'\x4A\x04\x01\x00\x00\x00' 2750 b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2751 b'\x5A\x04\x00\x00\x80\x3f' 2752 b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' 2753 b'\x6A\x01\x01' 2754 b'\x72\x01\x01') 2755 self.assertEqual(golden_data, message.SerializeToString()) 2756 2757 def testUnpackedFields(self): 2758 message = packed_field_test_pb2.TestUnpackedTypes() 2759 self.setMessage(message) 2760 golden_data = (b'\x08\x01' 2761 b'\x10\x01' 2762 b'\x18\x01' 2763 b'\x20\x01' 2764 b'\x28\x02' 2765 b'\x30\x02' 2766 b'\x3D\x01\x00\x00\x00' 2767 b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' 2768 b'\x4D\x01\x00\x00\x00' 2769 b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' 2770 b'\x5D\x00\x00\x80\x3f' 2771 b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' 2772 b'\x68\x01' 2773 b'\x70\x01') 2774 self.assertEqual(golden_data, message.SerializeToString()) 2775 2776 2777 2778@testing_refleaks.TestCase 2779class OversizeProtosTest(unittest.TestCase): 2780 2781 def GenerateNestedProto(self, n): 2782 msg = unittest_pb2.TestRecursiveMessage() 2783 sub = msg 2784 for _ in range(n): 2785 sub = sub.a 2786 sub.i = 0 2787 return msg.SerializeToString() 2788 2789 def testSucceedOkSizedProto(self): 2790 msg = unittest_pb2.TestRecursiveMessage() 2791 msg.ParseFromString(self.GenerateNestedProto(100)) 2792 2793 def testAssertOversizeProto(self): 2794 if api_implementation.Type() != 'python': 2795 api_implementation._c_module.SetAllowOversizeProtos(False) 2796 msg = unittest_pb2.TestRecursiveMessage() 2797 with self.assertRaises(message.DecodeError) as context: 2798 msg.ParseFromString(self.GenerateNestedProto(101)) 2799 self.assertIn('Error parsing message', str(context.exception)) 2800 2801 def testSucceedOversizeProto(self): 2802 if api_implementation.Type() == 'python': 2803 decoder.SetRecursionLimit(310) 2804 else: 2805 api_implementation._c_module.SetAllowOversizeProtos(True) 2806 msg = unittest_pb2.TestRecursiveMessage() 2807 msg.ParseFromString(self.GenerateNestedProto(101)) 2808 decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT) 2809 2810 2811if __name__ == '__main__': 2812 unittest.main() 2813