# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Tests python protocol buffers against the golden message.

Note that the golden messages exercise every known field type, thus this
test ends up exercising and verifying nearly all of the parsing and
serialization code in the whole library.

TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
sense to call this a test of the "message" module, which only declares an
abstract interface.
"""

__author__ = 'gps@google.com (Gregory P. Smith)'

# Import the submodule explicitly: collections.abc is used below and is not
# guaranteed to be loaded by a bare `import collections`. This still binds
# the name `collections`.
import collections.abc
import copy
import math
import operator
import pickle
import pydoc
import sys
import unittest
import warnings

# Three-way comparison stand-in for Python 2's builtin cmp().
cmp = lambda x, y: (x > y) - (x < y)

from google.protobuf import map_proto2_unittest_pb2
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import test_proto3_optional_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf import message
from google.protobuf.internal import _parameterized

# Largest code point representable in UCS-2 (0xFFFF).
UCS2_MAXUNICODE = 65535

# Escalate DeprecationWarnings to errors so any deprecated usage fails a test.
warnings.simplefilter('error', DeprecationWarning)


@_parameterized.named_parameters(('_proto2', unittest_pb2),
                                 ('_proto3', unittest_proto3_arena_pb2))
@testing_refleaks.TestCase
class MessageTest(unittest.TestCase):
  """Message behavior tests shared by the proto2 and proto3 test protos.

  Each test method runs once per generated module listed in the
  ``named_parameters`` decorator and receives it as ``message_module``.
  """

  def testBadUtf8String(self, message_module):
    """Bad UTF-8 in a string field raises UnicodeDecodeError (python impl)."""
    if api_implementation.Type() != 'python':
      self.skipTest('Skipping testBadUtf8String, currently only the python '
                    'api implementation raises UnicodeDecodeError when a '
                    'string field contains bad utf-8.')
    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
    with self.assertRaises(UnicodeDecodeError) as context:
      message_module.TestAllTypes.FromString(bad_utf8_data)
    # The error message must name the offending field.
    self.assertIn('TestAllTypes.optional_string', str(context.exception))

  def testGoldenMessage(self, message_module):
    """Golden data parses and re-serializes byte-for-byte, also via deepcopy."""
    # Proto3 doesn't have the "default_foo" members or foreign enums,
    # and doesn't preserve unknown fields, so for proto3 we use a golden
    # message that doesn't have these fields set.
    if message_module is unittest_pb2:
      golden_data = test_util.GoldenFileData(
          'golden_message_oneof_implemented')
    else:
      golden_data = test_util.GoldenFileData('golden_message_proto3')

    golden_message = message_module.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    if message_module is unittest_pb2:
      test_util.ExpectAllFieldsSet(self, golden_message)
    self.assertEqual(golden_data, golden_message.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testGoldenPackedMessage(self, message_module):
    """Packed golden data equals SetAllPackedFields output and round-trips."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    golden_message = message_module.TestPackedTypes()
    parsed_bytes = golden_message.ParseFromString(golden_data)
    all_set = message_module.TestPackedTypes()
    test_util.SetAllPackedFields(all_set)
    # The whole input must have been consumed by the parse.
    self.assertEqual(parsed_bytes, len(golden_data))
    self.assertEqual(all_set, golden_message)
    self.assertEqual(golden_data, all_set.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testParseErrors(self, message_module):
    """FromString rejects non-bytes input and malformed wire data."""
    msg = message_module.TestAllTypes()
    self.assertRaises(TypeError, msg.FromString, 0)
    self.assertRaises(Exception, msg.FromString, '0')
    # TODO(jieluo): Fix cpp extension to raise error instead of warning.
    # b/27494216
    end_tag = encoder.TagBytes(1, 4)
    if api_implementation.Type() == 'python':
      with self.assertRaises(message.DecodeError) as context:
        msg.FromString(end_tag)
      self.assertEqual('Unexpected end-group tag.', str(context.exception))

    # Field number 0 is illegal.
    self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')

  def testDeterminismParameters(self, message_module):
    """SerializeToString accepts None/False/True for `deterministic`."""
    # This message is always deterministically serialized, even if determinism
    # is disabled, so we can use it to verify that all the determinism
    # parameters work correctly.
    golden_data = (b'\xe2\x02\nOne string'
                   b'\xe2\x02\nTwo string'
                   b'\xe2\x02\nRed string'
                   b'\xe2\x02\x0bBlue string')
    golden_message = message_module.TestAllTypes()
    golden_message.repeated_string.extend([
        'One string',
        'Two string',
        'Red string',
        'Blue string',
    ])
    self.assertEqual(golden_data,
                     golden_message.SerializeToString(deterministic=None))
    self.assertEqual(golden_data,
                     golden_message.SerializeToString(deterministic=False))
    self.assertEqual(golden_data,
                     golden_message.SerializeToString(deterministic=True))

    class BadArgError(Exception):
      pass

    class BadArg(object):
      # Taking the truth value of this object raises, so the assertRaises
      # below shows that SerializeToString evaluates `deterministic` as a
      # bool and lets the error propagate.

      def __bool__(self):
        raise BadArgError()

    with self.assertRaises(BadArgError):
      golden_message.SerializeToString(deterministic=BadArg())

  def testPickleSupport(self, message_module):
    """A parsed golden message survives a pickle round trip."""
    golden_data = test_util.GoldenFileData('golden_message')
    golden_message = message_module.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    pickled_message = pickle.dumps(golden_message)

    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)

  def testPickleNestedMessage(self, message_module):
    """Instances of a nested message class survive a pickle round trip."""
    golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1)
    pickled_message = pickle.dumps(golden_message)
    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)

  def testPickleNestedNestedMessage(self, message_module):
    """Instances of a doubly nested message class survive a pickle round trip."""
    cls = message_module.TestPickleNestedMessage.NestedMessage
    golden_message = cls.NestedNestedMessage(cc=1)
    pickled_message = pickle.dumps(golden_message)
    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)

  def testPositiveInfinity(self, message_module):
    """+inf float/double values parse and re-serialize unchanged."""
    # The proto3 variant encodes the repeated float/double fields (41 and 42)
    # as length-delimited (packed) records, so its golden bytes differ.
    if message_module is unittest_pb2:
      golden_data = (b'\x5D\x00\x00\x80\x7F'
                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
                     b'\xCD\x02\x00\x00\x80\x7F'
                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
    else:
      golden_data = (b'\x5D\x00\x00\x80\x7F'
                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
                     b'\xCA\x02\x04\x00\x00\x80\x7F'
                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')

    golden_message = message_module.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertEqual(golden_message.optional_float, math.inf)
    self.assertEqual(golden_message.optional_double, math.inf)
    self.assertEqual(golden_message.repeated_float[0], math.inf)
    self.assertEqual(golden_message.repeated_double[0], math.inf)
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNegativeInfinity(self, message_module):
    """-inf float/double values parse and re-serialize unchanged."""
    # The proto3 variant encodes the repeated float/double fields (41 and 42)
    # as length-delimited (packed) records, so its golden bytes differ.
    if message_module is unittest_pb2:
      golden_data = (b'\x5D\x00\x00\x80\xFF'
                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                     b'\xCD\x02\x00\x00\x80\xFF'
                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
    else:
      golden_data = (b'\x5D\x00\x00\x80\xFF'
                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                     b'\xCA\x02\x04\x00\x00\x80\xFF'
                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')

    golden_message = message_module.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertEqual(golden_message.optional_float, -math.inf)
    self.assertEqual(golden_message.optional_double, -math.inf)
    self.assertEqual(golden_message.repeated_float[0], -math.inf)
    self.assertEqual(golden_message.repeated_double[0], -math.inf)
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNotANumber(self, message_module):
    """NaN float/double values parse and survive a serialize/parse cycle."""
    golden_data = (b'\x5D\x00\x00\xC0\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                   b'\xCD\x02\x00\x00\xC0\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = message_module.TestAllTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(math.isnan(golden_message.optional_float))
    self.assertTrue(math.isnan(golden_message.optional_double))
    self.assertTrue(math.isnan(golden_message.repeated_float[0]))
    self.assertTrue(math.isnan(golden_message.repeated_double[0]))

    # The protocol buffer may serialize to any one of multiple different
    # representations of a NaN. Rather than verify a specific representation,
    # verify the serialized string can be converted into a correctly
    # behaving protocol buffer.
    serialized = golden_message.SerializeToString()
    message = message_module.TestAllTypes()
    message.ParseFromString(serialized)
    self.assertTrue(math.isnan(message.optional_float))
    self.assertTrue(math.isnan(message.optional_double))
    self.assertTrue(math.isnan(message.repeated_float[0]))
    self.assertTrue(math.isnan(message.repeated_double[0]))

  def testPositiveInfinityPacked(self, message_module):
    """+inf in packed float/double fields parses and re-serializes unchanged."""
    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
    golden_message = message_module.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertEqual(golden_message.packed_float[0], math.inf)
    self.assertEqual(golden_message.packed_double[0], math.inf)
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNegativeInfinityPacked(self, message_module):
    """-inf in packed float/double fields parses and re-serializes unchanged."""
    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
    golden_message = message_module.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertEqual(golden_message.packed_float[0], -math.inf)
    self.assertEqual(golden_message.packed_double[0], -math.inf)
    self.assertEqual(golden_data, golden_message.SerializeToString())

  def testNotANumberPacked(self, message_module):
    """NaN in packed float/double fields survives a serialize/parse cycle."""
    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
    golden_message = message_module.TestPackedTypes()
    golden_message.ParseFromString(golden_data)
    self.assertTrue(math.isnan(golden_message.packed_float[0]))
    self.assertTrue(math.isnan(golden_message.packed_double[0]))

    # As in testNotANumber, only check that the NaN survives; the exact
    # serialized NaN bit pattern is not pinned.
    serialized = golden_message.SerializeToString()
    message = message_module.TestPackedTypes()
    message.ParseFromString(serialized)
    self.assertTrue(math.isnan(message.packed_float[0]))
    self.assertTrue(math.isnan(message.packed_double[0]))

  def testExtremeFloatValues(self, message_module):
    """Extreme float32 values round-trip; too-large doubles become +/-inf."""
    message = message_module.TestAllTypes()

    # Most positive exponent, no significand bits set.
    kMostPosExponentNoSigBits = math.pow(2, 127)
    message.optional_float = kMostPosExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, kMostPosExponentNoSigBits)

    # Most positive exponent, one significand bit set.
    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
    message.optional_float = kMostPosExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, kMostPosExponentOneSigBit)

    # Repeat last two cases with values of same magnitude, but negative.
    message.optional_float = -kMostPosExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, -kMostPosExponentNoSigBits)

    message.optional_float = -kMostPosExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, -kMostPosExponentOneSigBit)

    # Most negative exponent, no significand bits set.
    kMostNegExponentNoSigBits = math.pow(2, -127)
    message.optional_float = kMostNegExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, kMostNegExponentNoSigBits)

    # Most negative exponent, one significand bit set.
    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
    message.optional_float = kMostNegExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, kMostNegExponentOneSigBit)

    # Repeat last two cases with values of the same magnitude, but negative.
    message.optional_float = -kMostNegExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, -kMostNegExponentNoSigBits)

    message.optional_float = -kMostNegExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_float, -kMostNegExponentOneSigBit)

    # Max 4 bytes float value
    max_float = float.fromhex('0x1.fffffep+127')
    message.optional_float = max_float
    self.assertAlmostEqual(message.optional_float, max_float)
    serialized_data = message.SerializeToString()
    message.ParseFromString(serialized_data)
    self.assertAlmostEqual(message.optional_float, max_float)

    # Test set double to float field.
    message.optional_float = 3.4028235e+39
    self.assertEqual(message.optional_float, float('inf'))
    serialized_data = message.SerializeToString()
    message.ParseFromString(serialized_data)
    self.assertEqual(message.optional_float, float('inf'))

    message.optional_float = -3.4028235e+39
    self.assertEqual(message.optional_float, float('-inf'))

    message.optional_float = 1.4028235e-39
    self.assertAlmostEqual(message.optional_float, 1.4028235e-39)

  def testExtremeDoubleValues(self, message_module):
    """Extreme-exponent double values survive a serialize/parse round trip."""
    message = message_module.TestAllTypes()

    # Most positive exponent, no significand bits set.
    kMostPosExponentNoSigBits = math.pow(2, 1023)
    message.optional_double = kMostPosExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, kMostPosExponentNoSigBits)

    # Most positive exponent, one significand bit set.
    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
    message.optional_double = kMostPosExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, kMostPosExponentOneSigBit)

    # Repeat last two cases with values of same magnitude, but negative.
    message.optional_double = -kMostPosExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, -kMostPosExponentNoSigBits)

    message.optional_double = -kMostPosExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, -kMostPosExponentOneSigBit)

    # Most negative exponent, no significand bits set.
    kMostNegExponentNoSigBits = math.pow(2, -1023)
    message.optional_double = kMostNegExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, kMostNegExponentNoSigBits)

    # Most negative exponent, one significand bit set.
    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
    message.optional_double = kMostNegExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, kMostNegExponentOneSigBit)

    # Repeat last two cases with values of the same magnitude, but negative.
    message.optional_double = -kMostNegExponentNoSigBits
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, -kMostNegExponentNoSigBits)

    message.optional_double = -kMostNegExponentOneSigBit
    message.ParseFromString(message.SerializeToString())
    self.assertEqual(message.optional_double, -kMostNegExponentOneSigBit)

  def testFloatPrinting(self, message_module):
    """str() renders a float field in text format."""
    message = message_module.TestAllTypes()
    message.optional_float = 2.0
    self.assertEqual(str(message), 'optional_float: 2.0\n')

  def testHighPrecisionFloatPrinting(self, message_module):
    """A stored float value is unchanged by a serialize/parse round trip."""
    msg = message_module.TestAllTypes()
    msg.optional_float = 0.12345678912345678
    old_float = msg.optional_float
    msg.ParseFromString(msg.SerializeToString())
    self.assertEqual(old_float, msg.optional_float)

  def testHighPrecisionDoublePrinting(self, message_module):
    """str() prints a double field without losing digits."""
    msg = message_module.TestAllTypes()
    msg.optional_double = 0.12345678912345678
    self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')

  def testUnknownFieldPrinting(self, message_module):
    """Unknown fields are not included in str() output."""
    populated = message_module.TestAllTypes()
    test_util.SetAllNonLazyFields(populated)
    empty = message_module.TestEmptyMessage()
    empty.ParseFromString(populated.SerializeToString())
    self.assertEqual(str(empty), '')

  def testAppendRepeatedCompositeField(self, message_module):
    """append() adds messages; a non-message value is not added."""
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.append(
        message_module.TestAllTypes.NestedMessage(bb=1))
    nested = message_module.TestAllTypes.NestedMessage(bb=2)
    msg.repeated_nested_message.append(nested)
    # A TypeError is tolerated but not required here; the length check
    # below is what verifies that nothing was appended.
    try:
      msg.repeated_nested_message.append(1)
    except TypeError:
      pass
    self.assertEqual(2, len(msg.repeated_nested_message))
    self.assertEqual([1, 2], [m.bb for m in msg.repeated_nested_message])

  def testInsertRepeatedCompositeField(self, message_module):
    """insert() follows list.insert index rules and keeps references valid."""
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.insert(
        -1, message_module.TestAllTypes.NestedMessage(bb=1))
    sub_msg = msg.repeated_nested_message[0]
    msg.repeated_nested_message.insert(
        0, message_module.TestAllTypes.NestedMessage(bb=2))
    # Out-of-range positive and negative indices clamp to the ends.
    msg.repeated_nested_message.insert(
        99, message_module.TestAllTypes.NestedMessage(bb=3))
    msg.repeated_nested_message.insert(
        -2, message_module.TestAllTypes.NestedMessage(bb=-1))
    msg.repeated_nested_message.insert(
        -1000, message_module.TestAllTypes.NestedMessage(bb=-1000))
    # A TypeError is tolerated but not required; the length check verifies
    # that the non-message value was not inserted.
    try:
      msg.repeated_nested_message.insert(1, 999)
    except TypeError:
      pass
    self.assertEqual(5, len(msg.repeated_nested_message))
    self.assertEqual([-1000, 2, -1, 1, 3],
                     [m.bb for m in msg.repeated_nested_message])
    self.assertEqual(
        str(msg), 'repeated_nested_message {\n'
        '  bb: -1000\n'
        '}\n'
        'repeated_nested_message {\n'
        '  bb: 2\n'
        '}\n'
        'repeated_nested_message {\n'
        '  bb: -1\n'
        '}\n'
        'repeated_nested_message {\n'
        '  bb: 1\n'
        '}\n'
        'repeated_nested_message {\n'
        '  bb: 3\n'
        '}\n')
    # The reference taken before the later inserts still sees its element.
    self.assertEqual(sub_msg.bb, 1)

  def testMergeFromRepeatedField(self, message_module):
    """MergeFrom on a repeated field appends the other field's elements."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.append(1)
    msg.repeated_int32.append(3)
    msg.repeated_nested_message.add(bb=1)
    msg.repeated_nested_message.add(bb=2)
    other_msg = message_module.TestAllTypes()
    other_msg.repeated_nested_message.add(bb=3)
    other_msg.repeated_nested_message.add(bb=4)
    other_msg.repeated_int32.append(5)
    other_msg.repeated_int32.append(7)

    msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
    self.assertEqual(4, len(msg.repeated_int32))

    msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])

  def testAddWrongRepeatedNestedField(self, message_module):
    """add() with invalid arguments does not add an element."""
    msg = message_module.TestAllTypes()
    try:
      msg.repeated_nested_message.add('wrong')
    except TypeError:
      pass
    try:
      msg.repeated_nested_message.add(value_field='wrong')
    except ValueError:
      pass
    self.assertEqual(len(msg.repeated_nested_message), 0)

  def testRepeatedContains(self, message_module):
    """`in` works for repeated scalar and repeated composite fields."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.extend([1, 2, 3])
    self.assertIn(2, msg.repeated_int32)
    self.assertNotIn(0, msg.repeated_int32)

    msg.repeated_nested_message.add(bb=1)
    sub_msg1 = msg.repeated_nested_message[0]
    sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2)
    sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3)
    msg.repeated_nested_message.append(sub_msg2)
    msg.repeated_nested_message.insert(0, sub_msg3)
    self.assertIn(sub_msg1, msg.repeated_nested_message)
    self.assertIn(sub_msg2, msg.repeated_nested_message)
    self.assertIn(sub_msg3, msg.repeated_nested_message)

  def testRepeatedScalarIterable(self, message_module):
    """A repeated scalar field can be iterated with a for loop."""
    msg = message_module.TestAllTypes()
    msg.repeated_int32.extend([1, 2, 3])
    add = 0
    for item in msg.repeated_int32:
      add += item
    self.assertEqual(add, 6)

  def testRepeatedNestedFieldIteration(self, message_module):
    """Repeated composite fields support iteration, reversed() and slicing."""
    msg = message_module.TestAllTypes()
    msg.repeated_nested_message.add(bb=1)
    msg.repeated_nested_message.add(bb=2)
    msg.repeated_nested_message.add(bb=3)
    msg.repeated_nested_message.add(bb=4)

    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
    self.assertEqual([4, 3, 2, 1],
                     [m.bb for m in reversed(msg.repeated_nested_message)])
    self.assertEqual([4, 3, 2, 1],
                     [m.bb for m in msg.repeated_nested_message[::-1]])

  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
    """Check some different types with the default comparator."""
    message = message_module.TestAllTypes()

    # TODO(mattp): would testing more scalar types strengthen test?
    message.repeated_int32.append(1)
    message.repeated_int32.append(3)
    message.repeated_int32.append(2)
    message.repeated_int32.sort()
    self.assertEqual(message.repeated_int32[0], 1)
    self.assertEqual(message.repeated_int32[1], 2)
    self.assertEqual(message.repeated_int32[2], 3)
    self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))

    message.repeated_float.append(1.1)
    message.repeated_float.append(1.3)
    message.repeated_float.append(1.2)
    message.repeated_float.sort()
    # Almost-equal because the values are stored with float32 precision.
    self.assertAlmostEqual(message.repeated_float[0], 1.1)
    self.assertAlmostEqual(message.repeated_float[1], 1.2)
    self.assertAlmostEqual(message.repeated_float[2], 1.3)

    message.repeated_string.append('a')
    message.repeated_string.append('c')
    message.repeated_string.append('b')
    message.repeated_string.sort()
    self.assertEqual(message.repeated_string[0], 'a')
    self.assertEqual(message.repeated_string[1], 'b')
    self.assertEqual(message.repeated_string[2], 'c')
    self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))

    message.repeated_bytes.append(b'a')
    message.repeated_bytes.append(b'c')
    message.repeated_bytes.append(b'b')
    message.repeated_bytes.sort()
    self.assertEqual(message.repeated_bytes[0], b'a')
    self.assertEqual(message.repeated_bytes[1], b'b')
    self.assertEqual(message.repeated_bytes[2], b'c')
    self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))

  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
    """Check sorting repeated scalar fields with a custom key function."""
    message = message_module.TestAllTypes()

    message.repeated_int32.append(-3)
    message.repeated_int32.append(-2)
    message.repeated_int32.append(-1)
    message.repeated_int32.sort(key=abs)
    self.assertEqual(message.repeated_int32[0], -1)
    self.assertEqual(message.repeated_int32[1], -2)
    self.assertEqual(message.repeated_int32[2], -3)

    message.repeated_string.append('aaa')
    message.repeated_string.append('bb')
    message.repeated_string.append('c')
    message.repeated_string.sort(key=len)
    self.assertEqual(message.repeated_string[0], 'c')
    self.assertEqual(message.repeated_string[1], 'bb')
    self.assertEqual(message.repeated_string[2], 'aaa')

  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
    """Check sorting a repeated composite field with a custom key function."""
    message = message_module.TestAllTypes()

    message.repeated_nested_message.add().bb = 1
    message.repeated_nested_message.add().bb = 3
    message.repeated_nested_message.add().bb = 2
    message.repeated_nested_message.add().bb = 6
    message.repeated_nested_message.add().bb = 5
    message.repeated_nested_message.add().bb = 4
    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
    self.assertEqual(message.repeated_nested_message[0].bb, 1)
    self.assertEqual(message.repeated_nested_message[1].bb, 2)
    self.assertEqual(message.repeated_nested_message[2].bb, 3)
    self.assertEqual(message.repeated_nested_message[3].bb, 4)
    self.assertEqual(message.repeated_nested_message[4].bb, 5)
    self.assertEqual(message.repeated_nested_message[5].bb, 6)
    self.assertEqual(
        str(message.repeated_nested_message),
        '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')

  def testSortingRepeatedCompositeFieldsStable(self, message_module):
    """Check that sorting a repeated composite field by a key is stable."""
    message = message_module.TestAllTypes()

    message.repeated_nested_message.add().bb = 21
    message.repeated_nested_message.add().bb = 20
    message.repeated_nested_message.add().bb = 13
    message.repeated_nested_message.add().bb = 33
    message.repeated_nested_message.add().bb = 11
    message.repeated_nested_message.add().bb = 24
    message.repeated_nested_message.add().bb = 10
    # Sort by tens digit only; elements sharing a key keep insertion order.
    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
                     [n.bb for n in message.repeated_nested_message])

    # Make sure that for the C++ implementation, the underlying fields
    # are actually reordered.
    pb = message.SerializeToString()
    message.Clear()
    message.MergeFromString(pb)
    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
                     [n.bb for n in message.repeated_nested_message])

  def testRepeatedCompositeFieldSortArguments(self, message_module):
    """Check sorting a repeated composite field using list.sort() arguments."""
    message = message_module.TestAllTypes()

    get_bb = operator.attrgetter('bb')
    message.repeated_nested_message.add().bb = 1
    message.repeated_nested_message.add().bb = 3
    message.repeated_nested_message.add().bb = 2
    message.repeated_nested_message.add().bb = 6
    message.repeated_nested_message.add().bb = 5
    message.repeated_nested_message.add().bb = 4
    message.repeated_nested_message.sort(key=get_bb)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [1, 2, 3, 4, 5, 6])
    message.repeated_nested_message.sort(key=get_bb, reverse=True)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [6, 5, 4, 3, 2, 1])

  def testRepeatedScalarFieldSortArguments(self, message_module):
    """Check sorting a scalar field using list.sort() arguments."""
    message = message_module.TestAllTypes()

    message.repeated_int32.append(-3)
    message.repeated_int32.append(-2)
    message.repeated_int32.append(-1)
    message.repeated_int32.sort(key=abs)
    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
    message.repeated_int32.sort(key=abs, reverse=True)
    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])

    message.repeated_string.append('aaa')
    message.repeated_string.append('bb')
    message.repeated_string.append('c')
    message.repeated_string.sort(key=len)
    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
    message.repeated_string.sort(key=len, reverse=True)
    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])

  def testRepeatedFieldsComparable(self, message_module):
    """Repeated fields with equal contents compare equal, and unequal after."""
    m1 = message_module.TestAllTypes()
    m2 = message_module.TestAllTypes()
    m1.repeated_int32.append(0)
    m1.repeated_int32.append(1)
    m1.repeated_int32.append(2)
    m2.repeated_int32.append(0)
    m2.repeated_int32.append(1)
    m2.repeated_int32.append(2)
    m1.repeated_nested_message.add().bb = 1
    m1.repeated_nested_message.add().bb = 2
    m1.repeated_nested_message.add().bb = 3
    m2.repeated_nested_message.add().bb = 1
    m2.repeated_nested_message.add().bb = 2
    m2.repeated_nested_message.add().bb = 3

    # Previously the test stopped here and asserted nothing.
    self.assertEqual(m1.repeated_int32, m2.repeated_int32)
    self.assertEqual(m1.repeated_nested_message, m2.repeated_nested_message)
    self.assertEqual(m1, m2)

    # Diverging contents must make both the field and the message unequal.
    m2.repeated_int32.append(3)
    self.assertNotEqual(m1.repeated_int32, m2.repeated_int32)
    self.assertNotEqual(m1, m2)

  def testRepeatedFieldsAreSequences(self, message_module):
    """Repeated fields are instances of collections.abc.MutableSequence."""
    m = message_module.TestAllTypes()
    self.assertIsInstance(m.repeated_int32, collections.abc.MutableSequence)
    self.assertIsInstance(m.repeated_nested_message,
                          collections.abc.MutableSequence)

  def testRepeatedFieldsNotHashable(self, message_module):
    """hash() of a repeated field raises TypeError."""
    m = message_module.TestAllTypes()
    with self.assertRaises(TypeError):
      hash(m.repeated_int32)
    with self.assertRaises(TypeError):
      hash(m.repeated_nested_message)

  def testRepeatedFieldInsideNestedMessage(self, message_module):
    """Extending a nested repeated field, even with nothing, sets the parent."""
    m = message_module.NestedTestAllTypes()
    m.payload.repeated_int32.extend([])
    self.assertTrue(m.HasField('payload'))

  def testMergeFrom(self, message_module):
    """MergeFrom updates sub-message objects that were read before the merge."""
    m1 = message_module.TestAllTypes()
    m2 = message_module.TestAllTypes()
    # Cpp extension will lazily create a sub message which is immutable.
    nested = m1.optional_nested_message
    self.assertEqual(0, nested.bb)
    m2.optional_nested_message.bb = 1
    # Make sure the cmessage points to a mutable message after the merge
    # instead of the lazily created one.
    m1.MergeFrom(m2)
    self.assertEqual(1, nested.bb)

    # Test more nested sub message.
    msg1 = message_module.NestedTestAllTypes()
    msg2 = message_module.NestedTestAllTypes()
    nested = msg1.child.payload.optional_nested_message
    self.assertEqual(0, nested.bb)
    msg2.child.payload.optional_nested_message.bb = 1
    msg1.MergeFrom(msg2)
    self.assertEqual(1, nested.bb)

    # Test repeated field.
    # NOTE(review): this self-comparison only reads the field, presumably to
    # make the repeated container exist before the merge.
    self.assertEqual(msg1.payload.repeated_nested_message,
                     msg1.payload.repeated_nested_message)
    nested = msg2.payload.repeated_nested_message.add()
    nested.bb = 1
    msg1.MergeFrom(msg2)
    self.assertEqual(1, len(msg1.payload.repeated_nested_message))
    self.assertEqual(1, nested.bb)

  def testMergeFromString(self, message_module):
    """MergeFromString updates a sub-message that was read before the merge."""
    m1 = message_module.TestAllTypes()
    m2 = message_module.TestAllTypes()
    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, m1.optional_nested_message.bb)
    m2.optional_nested_message.bb = 1
    # Make sure the cmessage points to a mutable message after the merge
    # instead of the lazily created one.
    m1.MergeFromString(m2.SerializeToString())
    self.assertEqual(1, m1.optional_nested_message.bb)

  def testMergeFromStringUsingMemoryView(self, message_module):
    """FromString accepts a memoryview and yields real bytes/str values."""
    m2 = message_module.TestAllTypes()
    m2.optional_string = 'scalar string'
    m2.repeated_string.append('repeated string')
    m2.optional_bytes = b'scalar bytes'
    m2.repeated_bytes.append(b'repeated bytes')

    serialized = m2.SerializeToString()
    memview = memoryview(serialized)
    m1 = message_module.TestAllTypes.FromString(memview)

    self.assertEqual(m1.optional_bytes, b'scalar bytes')
    self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
    self.assertEqual(m1.optional_string, 'scalar string')
    self.assertEqual(m1.repeated_string, ['repeated string'])
    # Make sure that the memoryview was correctly converted to bytes, and
    # that a sub-sliced memoryview is not being used.
    self.assertIsInstance(m1.optional_bytes, bytes)
    self.assertIsInstance(m1.repeated_bytes[0], bytes)
    self.assertIsInstance(m1.optional_string, str)
    self.assertIsInstance(m1.repeated_string[0], str)

  def testMergeFromEmpty(self, message_module):
    """Merging empty data leaves a read-but-unset sub-message unset."""
    m1 = message_module.TestAllTypes()
    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, m1.optional_nested_message.bb)
    self.assertFalse(m1.HasField('optional_nested_message'))
    # Make sure the sub message is still immutable after merge from empty.
    m1.MergeFromString(b'')  # field state should not change
    self.assertFalse(m1.HasField('optional_nested_message'))

  def ensureNestedMessageExists(self, msg, attribute):
    """Make sure that a nested message object exists.

    As soon as a nested message attribute is accessed, it will be present in the
    _fields dict, without being marked as actually being set.
    """
    getattr(msg, attribute)
    self.assertFalse(msg.HasField(attribute))

  def testOneofGetCaseNonexistingField(self, message_module):
    """WhichOneof rejects unknown oneof names and non-string arguments."""
    m = message_module.TestAllTypes()
    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
    self.assertRaises(Exception, m.WhichOneof, 0)

  def testOneofDefaultValues(self, message_module):
    """Setting a oneof member to its default value still selects it."""
    m = message_module.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))

    # Oneof is set even when setting it to a default value.
    m.oneof_uint32 = 0
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))
    self.assertFalse(m.HasField('oneof_string'))

    m.oneof_string = ''
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_string'))
    self.assertFalse(m.HasField('oneof_uint32'))

  def testOneofSemantics(self, message_module):
    """Setting one oneof member clears the previously set member."""
    m = message_module.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))

    m.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))

    m.oneof_string = u'foo'
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertTrue(m.HasField('oneof_string'))

    # Read nested message accessor without accessing submessage.
    m.oneof_nested_message
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_string'))
    self.assertFalse(m.HasField('oneof_nested_message'))

    # Read accessor of nested message without accessing submessage.
847 m.oneof_nested_message.bb 848 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 849 self.assertTrue(m.HasField('oneof_string')) 850 self.assertFalse(m.HasField('oneof_nested_message')) 851 852 m.oneof_nested_message.bb = 11 853 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 854 self.assertFalse(m.HasField('oneof_string')) 855 self.assertTrue(m.HasField('oneof_nested_message')) 856 857 m.oneof_bytes = b'bb' 858 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 859 self.assertFalse(m.HasField('oneof_nested_message')) 860 self.assertTrue(m.HasField('oneof_bytes')) 861 862 def testOneofCompositeFieldReadAccess(self, message_module): 863 m = message_module.TestAllTypes() 864 m.oneof_uint32 = 11 865 866 self.ensureNestedMessageExists(m, 'oneof_nested_message') 867 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 868 self.assertEqual(11, m.oneof_uint32) 869 870 def testOneofWhichOneof(self, message_module): 871 m = message_module.TestAllTypes() 872 self.assertIs(None, m.WhichOneof('oneof_field')) 873 if message_module is unittest_pb2: 874 self.assertFalse(m.HasField('oneof_field')) 875 876 m.oneof_uint32 = 11 877 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 878 if message_module is unittest_pb2: 879 self.assertTrue(m.HasField('oneof_field')) 880 881 m.oneof_bytes = b'bb' 882 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 883 884 m.ClearField('oneof_bytes') 885 self.assertIs(None, m.WhichOneof('oneof_field')) 886 if message_module is unittest_pb2: 887 self.assertFalse(m.HasField('oneof_field')) 888 889 def testOneofClearField(self, message_module): 890 m = message_module.TestAllTypes() 891 m.oneof_uint32 = 11 892 m.ClearField('oneof_field') 893 if message_module is unittest_pb2: 894 self.assertFalse(m.HasField('oneof_field')) 895 self.assertFalse(m.HasField('oneof_uint32')) 896 self.assertIs(None, m.WhichOneof('oneof_field')) 897 898 def testOneofClearSetField(self, 
message_module): 899 m = message_module.TestAllTypes() 900 m.oneof_uint32 = 11 901 m.ClearField('oneof_uint32') 902 if message_module is unittest_pb2: 903 self.assertFalse(m.HasField('oneof_field')) 904 self.assertFalse(m.HasField('oneof_uint32')) 905 self.assertIs(None, m.WhichOneof('oneof_field')) 906 907 def testOneofClearUnsetField(self, message_module): 908 m = message_module.TestAllTypes() 909 m.oneof_uint32 = 11 910 self.ensureNestedMessageExists(m, 'oneof_nested_message') 911 m.ClearField('oneof_nested_message') 912 self.assertEqual(11, m.oneof_uint32) 913 if message_module is unittest_pb2: 914 self.assertTrue(m.HasField('oneof_field')) 915 self.assertTrue(m.HasField('oneof_uint32')) 916 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 917 918 def testOneofDeserialize(self, message_module): 919 m = message_module.TestAllTypes() 920 m.oneof_uint32 = 11 921 m2 = message_module.TestAllTypes() 922 m2.ParseFromString(m.SerializeToString()) 923 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 924 925 def testOneofCopyFrom(self, message_module): 926 m = message_module.TestAllTypes() 927 m.oneof_uint32 = 11 928 m2 = message_module.TestAllTypes() 929 m2.CopyFrom(m) 930 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 931 932 def testOneofNestedMergeFrom(self, message_module): 933 m = message_module.NestedTestAllTypes() 934 m.payload.oneof_uint32 = 11 935 m2 = message_module.NestedTestAllTypes() 936 m2.payload.oneof_bytes = b'bb' 937 m2.child.payload.oneof_bytes = b'bb' 938 m2.MergeFrom(m) 939 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) 940 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) 941 942 def testOneofMessageMergeFrom(self, message_module): 943 m = message_module.NestedTestAllTypes() 944 m.payload.oneof_nested_message.bb = 11 945 m.child.payload.oneof_nested_message.bb = 12 946 m2 = message_module.NestedTestAllTypes() 947 m2.payload.oneof_uint32 = 13 948 
m2.MergeFrom(m) 949 self.assertEqual('oneof_nested_message', 950 m2.payload.WhichOneof('oneof_field')) 951 self.assertEqual('oneof_nested_message', 952 m2.child.payload.WhichOneof('oneof_field')) 953 954 def testOneofNestedMessageInit(self, message_module): 955 m = message_module.TestAllTypes( 956 oneof_nested_message=message_module.TestAllTypes.NestedMessage()) 957 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 958 959 def testOneofClear(self, message_module): 960 m = message_module.TestAllTypes() 961 m.oneof_uint32 = 11 962 m.Clear() 963 self.assertIsNone(m.WhichOneof('oneof_field')) 964 m.oneof_bytes = b'bb' 965 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 966 967 def testAssignByteStringToUnicodeField(self, message_module): 968 """Assigning a byte string to a string field should result 969 970 in the value being converted to a Unicode string. 971 """ 972 m = message_module.TestAllTypes() 973 m.optional_string = str('') 974 self.assertIsInstance(m.optional_string, str) 975 976 def testLongValuedSlice(self, message_module): 977 """It should be possible to use int-valued indices in slices. 978 979 This didn't used to work in the v2 C++ implementation. 
980 """ 981 m = message_module.TestAllTypes() 982 983 # Repeated scalar 984 m.repeated_int32.append(1) 985 sl = m.repeated_int32[int(0):int(len(m.repeated_int32))] 986 self.assertEqual(len(m.repeated_int32), len(sl)) 987 988 # Repeated composite 989 m.repeated_nested_message.add().bb = 3 990 sl = m.repeated_nested_message[int(0):int(len(m.repeated_nested_message))] 991 self.assertEqual(len(m.repeated_nested_message), len(sl)) 992 993 def testExtendShouldNotSwallowExceptions(self, message_module): 994 """This didn't use to work in the v2 C++ implementation.""" 995 m = message_module.TestAllTypes() 996 with self.assertRaises(NameError) as _: 997 m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable 998 with self.assertRaises(NameError) as _: 999 m.repeated_nested_enum.extend(a for i in range(10)) # pylint: disable=undefined-variable 1000 1001 FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] 1002 1003 def testExtendInt32WithNothing(self, message_module): 1004 """Test no-ops extending repeated int32 fields.""" 1005 m = message_module.TestAllTypes() 1006 self.assertSequenceEqual([], m.repeated_int32) 1007 1008 # TODO(ptucker): Deprecate this behavior. b/18413862 1009 for falsy_value in MessageTest.FALSY_VALUES: 1010 m.repeated_int32.extend(falsy_value) 1011 self.assertSequenceEqual([], m.repeated_int32) 1012 1013 m.repeated_int32.extend([]) 1014 self.assertSequenceEqual([], m.repeated_int32) 1015 1016 def testExtendFloatWithNothing(self, message_module): 1017 """Test no-ops extending repeated float fields.""" 1018 m = message_module.TestAllTypes() 1019 self.assertSequenceEqual([], m.repeated_float) 1020 1021 # TODO(ptucker): Deprecate this behavior. 
b/18413862 1022 for falsy_value in MessageTest.FALSY_VALUES: 1023 m.repeated_float.extend(falsy_value) 1024 self.assertSequenceEqual([], m.repeated_float) 1025 1026 m.repeated_float.extend([]) 1027 self.assertSequenceEqual([], m.repeated_float) 1028 1029 def testExtendStringWithNothing(self, message_module): 1030 """Test no-ops extending repeated string fields.""" 1031 m = message_module.TestAllTypes() 1032 self.assertSequenceEqual([], m.repeated_string) 1033 1034 # TODO(ptucker): Deprecate this behavior. b/18413862 1035 for falsy_value in MessageTest.FALSY_VALUES: 1036 m.repeated_string.extend(falsy_value) 1037 self.assertSequenceEqual([], m.repeated_string) 1038 1039 m.repeated_string.extend([]) 1040 self.assertSequenceEqual([], m.repeated_string) 1041 1042 def testExtendInt32WithPythonList(self, message_module): 1043 """Test extending repeated int32 fields with python lists.""" 1044 m = message_module.TestAllTypes() 1045 self.assertSequenceEqual([], m.repeated_int32) 1046 m.repeated_int32.extend([0]) 1047 self.assertSequenceEqual([0], m.repeated_int32) 1048 m.repeated_int32.extend([1, 2]) 1049 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1050 m.repeated_int32.extend([3, 4]) 1051 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 1052 1053 def testExtendFloatWithPythonList(self, message_module): 1054 """Test extending repeated float fields with python lists.""" 1055 m = message_module.TestAllTypes() 1056 self.assertSequenceEqual([], m.repeated_float) 1057 m.repeated_float.extend([0.0]) 1058 self.assertSequenceEqual([0.0], m.repeated_float) 1059 m.repeated_float.extend([1.0, 2.0]) 1060 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1061 m.repeated_float.extend([3.0, 4.0]) 1062 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1063 1064 def testExtendStringWithPythonList(self, message_module): 1065 """Test extending repeated string fields with python lists.""" 1066 m = message_module.TestAllTypes() 1067 
self.assertSequenceEqual([], m.repeated_string) 1068 m.repeated_string.extend(['']) 1069 self.assertSequenceEqual([''], m.repeated_string) 1070 m.repeated_string.extend(['11', '22']) 1071 self.assertSequenceEqual(['', '11', '22'], m.repeated_string) 1072 m.repeated_string.extend(['33', '44']) 1073 self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) 1074 1075 def testExtendStringWithString(self, message_module): 1076 """Test extending repeated string fields with characters from a string.""" 1077 m = message_module.TestAllTypes() 1078 self.assertSequenceEqual([], m.repeated_string) 1079 m.repeated_string.extend('abc') 1080 self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) 1081 1082 class TestIterable(object): 1083 """This iterable object mimics the behavior of numpy.array. 1084 1085 __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. 1086 1087 """ 1088 1089 def __init__(self, values=None): 1090 self._list = values or [] 1091 1092 def __nonzero__(self): 1093 size = len(self._list) 1094 if size == 0: 1095 return False 1096 if size == 1: 1097 return bool(self._list[0]) 1098 raise ValueError('Truth value is ambiguous.') 1099 1100 def __len__(self): 1101 return len(self._list) 1102 1103 def __iter__(self): 1104 return self._list.__iter__() 1105 1106 def testExtendInt32WithIterable(self, message_module): 1107 """Test extending repeated int32 fields with iterable.""" 1108 m = message_module.TestAllTypes() 1109 self.assertSequenceEqual([], m.repeated_int32) 1110 m.repeated_int32.extend(MessageTest.TestIterable([])) 1111 self.assertSequenceEqual([], m.repeated_int32) 1112 m.repeated_int32.extend(MessageTest.TestIterable([0])) 1113 self.assertSequenceEqual([0], m.repeated_int32) 1114 m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) 1115 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1116 m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) 1117 self.assertSequenceEqual([0, 1, 2, 3, 4], 
m.repeated_int32) 1118 1119 def testExtendFloatWithIterable(self, message_module): 1120 """Test extending repeated float fields with iterable.""" 1121 m = message_module.TestAllTypes() 1122 self.assertSequenceEqual([], m.repeated_float) 1123 m.repeated_float.extend(MessageTest.TestIterable([])) 1124 self.assertSequenceEqual([], m.repeated_float) 1125 m.repeated_float.extend(MessageTest.TestIterable([0.0])) 1126 self.assertSequenceEqual([0.0], m.repeated_float) 1127 m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) 1128 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1129 m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) 1130 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1131 1132 def testExtendStringWithIterable(self, message_module): 1133 """Test extending repeated string fields with iterable.""" 1134 m = message_module.TestAllTypes() 1135 self.assertSequenceEqual([], m.repeated_string) 1136 m.repeated_string.extend(MessageTest.TestIterable([])) 1137 self.assertSequenceEqual([], m.repeated_string) 1138 m.repeated_string.extend(MessageTest.TestIterable([''])) 1139 self.assertSequenceEqual([''], m.repeated_string) 1140 m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) 1141 self.assertSequenceEqual(['', '1', '2'], m.repeated_string) 1142 m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) 1143 self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) 1144 1145 class TestIndex(object): 1146 """This index object mimics the behavior of numpy.int64 and other types.""" 1147 1148 def __init__(self, value=None): 1149 self.value = value 1150 1151 def __index__(self): 1152 return self.value 1153 1154 def testRepeatedIndexingWithIntIndex(self, message_module): 1155 msg = message_module.TestAllTypes() 1156 msg.repeated_int32.extend([1, 2, 3]) 1157 self.assertEqual(1, msg.repeated_int32[MessageTest.TestIndex(0)]) 1158 1159 def testRepeatedIndexingWithNegative1IntIndex(self, 
message_module): 1160 msg = message_module.TestAllTypes() 1161 msg.repeated_int32.extend([1, 2, 3]) 1162 self.assertEqual(3, msg.repeated_int32[MessageTest.TestIndex(-1)]) 1163 1164 def testRepeatedIndexingWithNegative1Int(self, message_module): 1165 msg = message_module.TestAllTypes() 1166 msg.repeated_int32.extend([1, 2, 3]) 1167 self.assertEqual(3, msg.repeated_int32[-1]) 1168 1169 def testPickleRepeatedScalarContainer(self, message_module): 1170 # Pickle repeated scalar container is not supported. 1171 m = message_module.TestAllTypes() 1172 with self.assertRaises(pickle.PickleError) as _: 1173 pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) 1174 1175 def testSortEmptyRepeatedCompositeContainer(self, message_module): 1176 """Exercise a scenario that has led to segfaults in the past.""" 1177 m = message_module.TestAllTypes() 1178 m.repeated_nested_message.sort() 1179 1180 def testHasFieldOnRepeatedField(self, message_module): 1181 """Using HasField on a repeated field should raise an exception.""" 1182 m = message_module.TestAllTypes() 1183 with self.assertRaises(ValueError) as _: 1184 m.HasField('repeated_int32') 1185 1186 def testRepeatedScalarFieldPop(self, message_module): 1187 m = message_module.TestAllTypes() 1188 with self.assertRaises(IndexError) as _: 1189 m.repeated_int32.pop() 1190 m.repeated_int32.extend(range(5)) 1191 self.assertEqual(4, m.repeated_int32.pop()) 1192 self.assertEqual(0, m.repeated_int32.pop(0)) 1193 self.assertEqual(2, m.repeated_int32.pop(1)) 1194 self.assertEqual([1, 3], m.repeated_int32) 1195 1196 def testRepeatedCompositeFieldPop(self, message_module): 1197 m = message_module.TestAllTypes() 1198 with self.assertRaises(IndexError) as _: 1199 m.repeated_nested_message.pop() 1200 with self.assertRaises(TypeError) as _: 1201 m.repeated_nested_message.pop('0') 1202 for i in range(5): 1203 n = m.repeated_nested_message.add() 1204 n.bb = i 1205 self.assertEqual(4, m.repeated_nested_message.pop().bb) 1206 self.assertEqual(0, 
m.repeated_nested_message.pop(0).bb) 1207 self.assertEqual(2, m.repeated_nested_message.pop(1).bb) 1208 self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) 1209 1210 def testRepeatedCompareWithSelf(self, message_module): 1211 m = message_module.TestAllTypes() 1212 for i in range(5): 1213 m.repeated_int32.insert(i, i) 1214 n = m.repeated_nested_message.add() 1215 n.bb = i 1216 self.assertSequenceEqual(m.repeated_int32, m.repeated_int32) 1217 self.assertEqual(m.repeated_nested_message, m.repeated_nested_message) 1218 1219 def testReleasedNestedMessages(self, message_module): 1220 """A case that lead to a segfault when a message detached from its parent 1221 1222 container has itself a child container. 1223 """ 1224 m = message_module.NestedTestAllTypes() 1225 m = m.repeated_child.add() 1226 m = m.child 1227 m = m.repeated_child.add() 1228 self.assertEqual(m.payload.optional_int32, 0) 1229 1230 def testSetRepeatedComposite(self, message_module): 1231 m = message_module.TestAllTypes() 1232 with self.assertRaises(AttributeError): 1233 m.repeated_int32 = [] 1234 m.repeated_int32.append(1) 1235 with self.assertRaises(AttributeError): 1236 m.repeated_int32 = [] 1237 1238 def testReturningType(self, message_module): 1239 m = message_module.TestAllTypes() 1240 self.assertEqual(float, type(m.optional_float)) 1241 self.assertEqual(float, type(m.optional_double)) 1242 self.assertEqual(bool, type(m.optional_bool)) 1243 m.optional_float = 1 1244 m.optional_double = 1 1245 m.optional_bool = 1 1246 m.repeated_float.append(1) 1247 m.repeated_double.append(1) 1248 m.repeated_bool.append(1) 1249 m.ParseFromString(m.SerializeToString()) 1250 self.assertEqual(float, type(m.optional_float)) 1251 self.assertEqual(float, type(m.optional_double)) 1252 self.assertEqual('1.0', str(m.optional_double)) 1253 self.assertEqual(bool, type(m.optional_bool)) 1254 self.assertEqual(float, type(m.repeated_float[0])) 1255 self.assertEqual(float, type(m.repeated_double[0])) 1256 
self.assertEqual(bool, type(m.repeated_bool[0])) 1257 self.assertEqual(True, m.repeated_bool[0]) 1258 1259 1260# Class to test proto2-only features (required, extensions, etc.) 1261@testing_refleaks.TestCase 1262class Proto2Test(unittest.TestCase): 1263 1264 def testFieldPresence(self): 1265 message = unittest_pb2.TestAllTypes() 1266 1267 self.assertFalse(message.HasField('optional_int32')) 1268 self.assertFalse(message.HasField('optional_bool')) 1269 self.assertFalse(message.HasField('optional_nested_message')) 1270 1271 with self.assertRaises(ValueError): 1272 message.HasField('field_doesnt_exist') 1273 1274 with self.assertRaises(ValueError): 1275 message.HasField('repeated_int32') 1276 with self.assertRaises(ValueError): 1277 message.HasField('repeated_nested_message') 1278 1279 self.assertEqual(0, message.optional_int32) 1280 self.assertEqual(False, message.optional_bool) 1281 self.assertEqual(0, message.optional_nested_message.bb) 1282 1283 # Fields are set even when setting the values to default values. 1284 message.optional_int32 = 0 1285 message.optional_bool = False 1286 message.optional_nested_message.bb = 0 1287 self.assertTrue(message.HasField('optional_int32')) 1288 self.assertTrue(message.HasField('optional_bool')) 1289 self.assertTrue(message.HasField('optional_nested_message')) 1290 1291 # Set the fields to non-default values. 1292 message.optional_int32 = 5 1293 message.optional_bool = True 1294 message.optional_nested_message.bb = 15 1295 1296 self.assertTrue(message.HasField(u'optional_int32')) 1297 self.assertTrue(message.HasField('optional_bool')) 1298 self.assertTrue(message.HasField('optional_nested_message')) 1299 1300 # Clearing the fields unsets them and resets their value to default. 
1301 message.ClearField('optional_int32') 1302 message.ClearField(u'optional_bool') 1303 message.ClearField('optional_nested_message') 1304 1305 self.assertFalse(message.HasField('optional_int32')) 1306 self.assertFalse(message.HasField('optional_bool')) 1307 self.assertFalse(message.HasField('optional_nested_message')) 1308 self.assertEqual(0, message.optional_int32) 1309 self.assertEqual(False, message.optional_bool) 1310 self.assertEqual(0, message.optional_nested_message.bb) 1311 1312 def testAssignInvalidEnum(self): 1313 """Assigning an invalid enum number is not allowed in proto2.""" 1314 m = unittest_pb2.TestAllTypes() 1315 1316 # Proto2 can not assign unknown enum. 1317 with self.assertRaises(ValueError) as _: 1318 m.optional_nested_enum = 1234567 1319 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) 1320 # Assignment is a different code path than append for the C++ impl. 1321 m.repeated_nested_enum.append(2) 1322 m.repeated_nested_enum[0] = 2 1323 with self.assertRaises(ValueError): 1324 m.repeated_nested_enum[0] = 123456 1325 1326 # Unknown enum value can be parsed but is ignored. 1327 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1328 m2.optional_nested_enum = 1234567 1329 m2.repeated_nested_enum.append(7654321) 1330 serialized = m2.SerializeToString() 1331 1332 m3 = unittest_pb2.TestAllTypes() 1333 m3.ParseFromString(serialized) 1334 self.assertFalse(m3.HasField('optional_nested_enum')) 1335 # 1 is the default value for optional_nested_enum. 
1336 self.assertEqual(1, m3.optional_nested_enum) 1337 self.assertEqual(0, len(m3.repeated_nested_enum)) 1338 m2.Clear() 1339 m2.ParseFromString(m3.SerializeToString()) 1340 self.assertEqual(1234567, m2.optional_nested_enum) 1341 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1342 1343 def testUnknownEnumMap(self): 1344 m = map_proto2_unittest_pb2.TestEnumMap() 1345 m.known_map_field[123] = 0 1346 with self.assertRaises(ValueError): 1347 m.unknown_map_field[1] = 123 1348 1349 def testExtensionsErrors(self): 1350 msg = unittest_pb2.TestAllTypes() 1351 self.assertRaises(AttributeError, getattr, msg, 'Extensions') 1352 1353 def testMergeFromExtensions(self): 1354 msg1 = more_extensions_pb2.TopLevelMessage() 1355 msg2 = more_extensions_pb2.TopLevelMessage() 1356 # Cpp extension will lazily create a sub message which is immutable. 1357 self.assertEqual( 1358 0, 1359 msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension]) 1360 self.assertFalse(msg1.HasField('submessage')) 1361 msg2.submessage.Extensions[more_extensions_pb2.optional_int_extension] = 123 1362 # Make sure cmessage and extensions pointing to a mutable message 1363 # after merge instead of the lazily created message. 
1364 msg1.MergeFrom(msg2) 1365 self.assertEqual( 1366 123, 1367 msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension]) 1368 1369 def testGoldenExtensions(self): 1370 golden_data = test_util.GoldenFileData('golden_message') 1371 golden_message = unittest_pb2.TestAllExtensions() 1372 golden_message.ParseFromString(golden_data) 1373 all_set = unittest_pb2.TestAllExtensions() 1374 test_util.SetAllExtensions(all_set) 1375 self.assertEqual(all_set, golden_message) 1376 self.assertEqual(golden_data, golden_message.SerializeToString()) 1377 golden_copy = copy.deepcopy(golden_message) 1378 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1379 1380 def testGoldenPackedExtensions(self): 1381 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 1382 golden_message = unittest_pb2.TestPackedExtensions() 1383 golden_message.ParseFromString(golden_data) 1384 all_set = unittest_pb2.TestPackedExtensions() 1385 test_util.SetAllPackedExtensions(all_set) 1386 self.assertEqual(all_set, golden_message) 1387 self.assertEqual(golden_data, all_set.SerializeToString()) 1388 golden_copy = copy.deepcopy(golden_message) 1389 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1390 1391 def testPickleIncompleteProto(self): 1392 golden_message = unittest_pb2.TestRequired(a=1) 1393 pickled_message = pickle.dumps(golden_message) 1394 1395 unpickled_message = pickle.loads(pickled_message) 1396 self.assertEqual(unpickled_message, golden_message) 1397 self.assertEqual(unpickled_message.a, 1) 1398 # This is still an incomplete proto - so serializing should fail 1399 self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) 1400 1401 # TODO(haberman): this isn't really a proto2-specific test except that this 1402 # message has a required field in it. Should probably be factored out so 1403 # that we can test the other parts with proto3. 
  def testParsingMerge(self):
    """Check the merge behavior when a field appears multiple times in input.

    Repeated occurrences of a singular (required or optional) field, including
    groups and extensions, must be merged into one value, while occurrences
    of a repeated field must each be kept as a separate element.
    """
    messages = [
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes()
    ]
    messages[0].optional_int32 = 1
    messages[1].optional_int64 = 2
    messages[2].optional_int32 = 3
    messages[2].optional_string = 'hello'

    # Expected result of merging messages[0..2] in order: for a scalar set in
    # several inputs (optional_int32), the last value wins.
    merged_message = unittest_pb2.TestAllTypes()
    merged_message.optional_int32 = 3
    merged_message.optional_int64 = 2
    merged_message.optional_string = 'hello'

    # NOTE(review): the generator presumably emits these repeated fields under
    # the field numbers of TestParsingMerge's singular/repeated fields, so its
    # bytes parse as multiple occurrences — confirm against the .proto.
    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    generator.field1.extend(messages)
    generator.field2.extend(messages)
    generator.field3.extend(messages)
    generator.ext1.extend(messages)
    generator.ext2.extend(messages)
    generator.group1.add().field1.MergeFrom(messages[0])
    generator.group1.add().field1.MergeFrom(messages[1])
    generator.group1.add().field1.MergeFrom(messages[2])
    generator.group2.add().field1.MergeFrom(messages[0])
    generator.group2.add().field1.MergeFrom(messages[1])
    generator.group2.add().field1.MergeFrom(messages[2])

    data = generator.SerializeToString()
    parsing_merge = unittest_pb2.TestParsingMerge()
    parsing_merge.ParseFromString(data)

    # Required and optional fields should be merged.
    self.assertEqual(parsing_merge.required_all_types, merged_message)
    self.assertEqual(parsing_merge.optional_all_types, merged_message)
    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
                     merged_message)
    self.assertEqual(
        parsing_merge.Extensions[unittest_pb2.TestParsingMerge.optional_ext],
        merged_message)

    # Repeated fields should not be merged.
    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
    self.assertEqual(
        len(parsing_merge.Extensions[
            unittest_pb2.TestParsingMerge.repeated_ext]), 3)

  def testPythonicInit(self):
    """Constructor keyword arguments initialize fields of every kind.

    Sub-messages and groups may be given as dicts, and enums as label strings
    or numbers. Unknown nested field names, wrongly typed nested values and
    unknown enum labels are rejected, as is a bare string for a repeated enum.
    """
    message = unittest_pb2.TestAllTypes(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_foreign_message={},
        optional_nested_enum='BAZ',
        repeatedgroup=[{
            'a': 600
        }, {
            'a': 700
        }],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
    self.assertEqual(100, message.optional_int32)
    self.assertEqual(200, message.optional_fixed32)
    self.assertEqual(300.5, message.optional_float)
    self.assertEqual(b'x', message.optional_bytes)
    self.assertEqual(400, message.optionalgroup.a)
    self.assertIsInstance(message.optional_nested_message,
                          unittest_pb2.TestAllTypes.NestedMessage)
    self.assertEqual(500, message.optional_nested_message.bb)
    # An empty dict still marks the sub-message field as present.
    self.assertTrue(message.HasField('optional_foreign_message'))
    self.assertEqual(message.optional_foreign_message,
                     unittest_pb2.ForeignMessage())
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     message.optional_nested_enum)
    self.assertEqual(2, len(message.repeatedgroup))
    self.assertEqual(600, message.repeatedgroup[0].a)
    self.assertEqual(700, message.repeatedgroup[1].a)
    self.assertEqual(2, len(message.repeated_nested_enum))
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     message.repeated_nested_enum[0])
    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
                     message.repeated_nested_enum[1])
    self.assertEqual(800, message.default_int32)
    self.assertEqual('y', message.oneof_string)
    # Fields that were not passed keep their unset/default state.
    self.assertFalse(message.HasField('optional_int64'))
    self.assertEqual(0, len(message.repeated_float))
    self.assertEqual(42, message.default_int64)

    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     message.optional_nested_enum)

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'INVALID_NESTED_FIELD': 17})

    with self.assertRaises(TypeError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')

    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')

  def testPythonicInitWithDict(self):
    """Field values may be passed to the constructor via **kwargs."""
    # Both string/unicode field name keys should work (they are the same type
    # in Python 3; the u'' prefix is kept from the Python 2 version).
    kwargs = {
        'optional_int32': 100,
        u'optional_fixed32': 200,
    }
    msg = unittest_pb2.TestAllTypes(**kwargs)
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)


  def test_documentation(self):
    """pydoc can render HTML documentation for a generated message class."""
    # Also used by the interactive help() function.
    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
    self.assertIn('class TestAllTypes', doc)
    self.assertIn('SerializePartialToString', doc)
    self.assertIn('repeated_float', doc)
    # The generated class's direct base must not expose _extensions_by_name.
    base = unittest_pb2.TestAllTypes.__bases__[0]
    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')


# Class to test proto3-only features/behavior (updated field presence & enums)
@testing_refleaks.TestCase
class Proto3Test(unittest.TestCase):

  # Utility method for comparing equality with a map.
  def assertMapIterEquals(self, map_iter, dict_value):
    """Assert that the (key, value) pairs of map_iter exactly match dict_value.

    Every key yielded must be in dict_value with an equal value; a key yielded
    twice fails with KeyError on its second occurrence. Every key of
    dict_value must be yielded. The caller's dict_value is not modified.
    """
    # Avoid mutating caller's copy.
    dict_value = dict(dict_value)

    for k, v in map_iter:
      self.assertEqual(v, dict_value[k])
      # Remove matched keys so leftovers reveal keys map_iter never yielded.
      del dict_value[k]

    self.assertEqual({}, dict_value)

  def testFieldPresence(self):
    """In proto3, HasField raises for plain scalars but works for messages."""
    message = unittest_proto3_arena_pb2.TestAllTypes()

    # We can't test presence of non-repeated, non-submessage fields.
    with self.assertRaises(ValueError):
      message.HasField('optional_int32')
    with self.assertRaises(ValueError):
      message.HasField('optional_float')
    with self.assertRaises(ValueError):
      message.HasField('optional_string')
    with self.assertRaises(ValueError):
      message.HasField('optional_bool')

    # But we can still test presence of submessage fields.
    self.assertFalse(message.HasField('optional_nested_message'))

    # As with proto2, we can't test presence of fields that don't exist, or
    # repeated fields.
    with self.assertRaises(ValueError):
      message.HasField('field_doesnt_exist')

    with self.assertRaises(ValueError):
      message.HasField('repeated_int32')
    with self.assertRaises(ValueError):
      message.HasField('repeated_nested_message')

    # Fields should default to their type-specific default.
    self.assertEqual(0, message.optional_int32)
    self.assertEqual(0, message.optional_float)
    self.assertEqual('', message.optional_string)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

    # Setting a submessage should still return proper presence information.
    message.optional_nested_message.bb = 0
    self.assertTrue(message.HasField('optional_nested_message'))

    # Set the fields to non-default values.
    message.optional_int32 = 5
    message.optional_float = 1.1
    message.optional_string = 'abc'
    message.optional_bool = True
    message.optional_nested_message.bb = 15

    # Clearing the fields unsets them and resets their value to default.
    message.ClearField('optional_int32')
    message.ClearField('optional_float')
    message.ClearField('optional_string')
    message.ClearField('optional_bool')
    message.ClearField('optional_nested_message')

    self.assertEqual(0, message.optional_int32)
    self.assertEqual(0, message.optional_float)
    self.assertEqual('', message.optional_string)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

  def testProto3ParserDropDefaultScalar(self):
    """Default-valued scalars parsed into proto3 are not listed by ListFields."""
    # In proto2, explicitly set default values count as present fields.
    message_proto2 = unittest_pb2.TestAllTypes()
    message_proto2.optional_int32 = 0
    message_proto2.optional_string = ''
    message_proto2.optional_bytes = b''
    self.assertEqual(len(message_proto2.ListFields()), 3)

    message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
    message_proto3.ParseFromString(message_proto2.SerializeToString())
    self.assertEqual(len(message_proto3.ListFields()), 0)

  def testProto3Optional(self):
    msg = test_proto3_optional_pb2.TestProto3Optional()
    self.assertFalse(msg.HasField('optional_int32'))
    self.assertFalse(msg.HasField('optional_float'))
    self.assertFalse(msg.HasField('optional_string'))
    self.assertFalse(msg.HasField('optional_nested_message'))
    self.assertFalse(msg.optional_nested_message.HasField('bb'))

    # Set fields.
1634 msg.optional_int32 = 1 1635 msg.optional_float = 1.0 1636 msg.optional_string = '123' 1637 msg.optional_nested_message.bb = 1 1638 self.assertTrue(msg.HasField('optional_int32')) 1639 self.assertTrue(msg.HasField('optional_float')) 1640 self.assertTrue(msg.HasField('optional_string')) 1641 self.assertTrue(msg.HasField('optional_nested_message')) 1642 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1643 # Set to default value does not clear the fields 1644 msg.optional_int32 = 0 1645 msg.optional_float = 0.0 1646 msg.optional_string = '' 1647 msg.optional_nested_message.bb = 0 1648 self.assertTrue(msg.HasField('optional_int32')) 1649 self.assertTrue(msg.HasField('optional_float')) 1650 self.assertTrue(msg.HasField('optional_string')) 1651 self.assertTrue(msg.HasField('optional_nested_message')) 1652 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1653 1654 # Test serialize 1655 msg2 = test_proto3_optional_pb2.TestProto3Optional() 1656 msg2.ParseFromString(msg.SerializeToString()) 1657 self.assertTrue(msg2.HasField('optional_int32')) 1658 self.assertTrue(msg2.HasField('optional_float')) 1659 self.assertTrue(msg2.HasField('optional_string')) 1660 self.assertTrue(msg2.HasField('optional_nested_message')) 1661 self.assertTrue(msg2.optional_nested_message.HasField('bb')) 1662 1663 self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32') 1664 1665 # Clear these fields. 
1666 msg.ClearField('optional_int32') 1667 msg.ClearField('optional_float') 1668 msg.ClearField('optional_string') 1669 msg.ClearField('optional_nested_message') 1670 self.assertFalse(msg.HasField('optional_int32')) 1671 self.assertFalse(msg.HasField('optional_float')) 1672 self.assertFalse(msg.HasField('optional_string')) 1673 self.assertFalse(msg.HasField('optional_nested_message')) 1674 self.assertFalse(msg.optional_nested_message.HasField('bb')) 1675 1676 self.assertEqual(msg.WhichOneof('_optional_int32'), None) 1677 1678 # Test has presence: 1679 for field in test_proto3_optional_pb2.TestProto3Optional.DESCRIPTOR.fields: 1680 self.assertTrue(field.has_presence) 1681 for field in unittest_pb2.TestAllTypes.DESCRIPTOR.fields: 1682 if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1683 self.assertFalse(field.has_presence) 1684 else: 1685 self.assertTrue(field.has_presence) 1686 proto3_descriptor = unittest_proto3_arena_pb2.TestAllTypes.DESCRIPTOR 1687 repeated_field = proto3_descriptor.fields_by_name['repeated_int32'] 1688 self.assertFalse(repeated_field.has_presence) 1689 singular_field = proto3_descriptor.fields_by_name['optional_int32'] 1690 self.assertFalse(singular_field.has_presence) 1691 optional_field = proto3_descriptor.fields_by_name['proto3_optional_int32'] 1692 self.assertTrue(optional_field.has_presence) 1693 message_field = proto3_descriptor.fields_by_name['optional_nested_message'] 1694 self.assertTrue(message_field.has_presence) 1695 oneof_field = proto3_descriptor.fields_by_name['oneof_uint32'] 1696 self.assertTrue(oneof_field.has_presence) 1697 1698 def testAssignUnknownEnum(self): 1699 """Assigning an unknown enum value is allowed and preserves the value.""" 1700 m = unittest_proto3_arena_pb2.TestAllTypes() 1701 1702 # Proto3 can assign unknown enums. 
1703 m.optional_nested_enum = 1234567 1704 self.assertEqual(1234567, m.optional_nested_enum) 1705 m.repeated_nested_enum.append(22334455) 1706 self.assertEqual(22334455, m.repeated_nested_enum[0]) 1707 # Assignment is a different code path than append for the C++ impl. 1708 m.repeated_nested_enum[0] = 7654321 1709 self.assertEqual(7654321, m.repeated_nested_enum[0]) 1710 serialized = m.SerializeToString() 1711 1712 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1713 m2.ParseFromString(serialized) 1714 self.assertEqual(1234567, m2.optional_nested_enum) 1715 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1716 1717 # Map isn't really a proto3-only feature. But there is no proto2 equivalent 1718 # of google/protobuf/map_unittest.proto right now, so it's not easy to 1719 # test both with the same test like we do for the other proto2/proto3 tests. 1720 # (google/protobuf/map_proto2_unittest.proto is very different in the set 1721 # of messages and fields it contains). 1722 def testScalarMapDefaults(self): 1723 msg = map_unittest_pb2.TestMap() 1724 1725 # Scalars start out unset. 1726 self.assertFalse(-123 in msg.map_int32_int32) 1727 self.assertFalse(-2**33 in msg.map_int64_int64) 1728 self.assertFalse(123 in msg.map_uint32_uint32) 1729 self.assertFalse(2**33 in msg.map_uint64_uint64) 1730 self.assertFalse(123 in msg.map_int32_double) 1731 self.assertFalse(False in msg.map_bool_bool) 1732 self.assertFalse('abc' in msg.map_string_string) 1733 self.assertFalse(111 in msg.map_int32_bytes) 1734 self.assertFalse(888 in msg.map_int32_enum) 1735 1736 # Accessing an unset key returns the default. 
1737 self.assertEqual(0, msg.map_int32_int32[-123]) 1738 self.assertEqual(0, msg.map_int64_int64[-2**33]) 1739 self.assertEqual(0, msg.map_uint32_uint32[123]) 1740 self.assertEqual(0, msg.map_uint64_uint64[2**33]) 1741 self.assertEqual(0.0, msg.map_int32_double[123]) 1742 self.assertTrue(isinstance(msg.map_int32_double[123], float)) 1743 self.assertEqual(False, msg.map_bool_bool[False]) 1744 self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) 1745 self.assertEqual('', msg.map_string_string['abc']) 1746 self.assertEqual(b'', msg.map_int32_bytes[111]) 1747 self.assertEqual(0, msg.map_int32_enum[888]) 1748 1749 # It also sets the value in the map 1750 self.assertTrue(-123 in msg.map_int32_int32) 1751 self.assertTrue(-2**33 in msg.map_int64_int64) 1752 self.assertTrue(123 in msg.map_uint32_uint32) 1753 self.assertTrue(2**33 in msg.map_uint64_uint64) 1754 self.assertTrue(123 in msg.map_int32_double) 1755 self.assertTrue(False in msg.map_bool_bool) 1756 self.assertTrue('abc' in msg.map_string_string) 1757 self.assertTrue(111 in msg.map_int32_bytes) 1758 self.assertTrue(888 in msg.map_int32_enum) 1759 1760 self.assertIsInstance(msg.map_string_string['abc'], str) 1761 1762 # Accessing an unset key still throws TypeError if the type of the key 1763 # is incorrect. 1764 with self.assertRaises(TypeError): 1765 msg.map_string_string[123] 1766 1767 with self.assertRaises(TypeError): 1768 123 in msg.map_string_string 1769 1770 def testMapGet(self): 1771 # Need to test that get() properly returns the default, even though the dict 1772 # has defaultdict-like semantics. 
1773 msg = map_unittest_pb2.TestMap() 1774 1775 self.assertIsNone(msg.map_int32_int32.get(5)) 1776 self.assertEqual(10, msg.map_int32_int32.get(5, 10)) 1777 self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10)) 1778 self.assertIsNone(msg.map_int32_int32.get(5)) 1779 1780 msg.map_int32_int32[5] = 15 1781 self.assertEqual(15, msg.map_int32_int32.get(5)) 1782 self.assertEqual(15, msg.map_int32_int32.get(5)) 1783 with self.assertRaises(TypeError): 1784 msg.map_int32_int32.get('') 1785 1786 self.assertIsNone(msg.map_int32_foreign_message.get(5)) 1787 self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) 1788 self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10)) 1789 1790 submsg = msg.map_int32_foreign_message[5] 1791 self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) 1792 with self.assertRaises(TypeError): 1793 msg.map_int32_foreign_message.get('') 1794 1795 def testScalarMap(self): 1796 msg = map_unittest_pb2.TestMap() 1797 1798 self.assertEqual(0, len(msg.map_int32_int32)) 1799 self.assertFalse(5 in msg.map_int32_int32) 1800 1801 msg.map_int32_int32[-123] = -456 1802 msg.map_int64_int64[-2**33] = -2**34 1803 msg.map_uint32_uint32[123] = 456 1804 msg.map_uint64_uint64[2**33] = 2**34 1805 msg.map_int32_float[2] = 1.2 1806 msg.map_int32_double[1] = 3.3 1807 msg.map_string_string['abc'] = '123' 1808 msg.map_bool_bool[True] = True 1809 msg.map_int32_enum[888] = 2 1810 # Unknown numeric enum is supported in proto3. 1811 msg.map_int32_enum[123] = 456 1812 1813 self.assertEqual([], msg.FindInitializationErrors()) 1814 1815 self.assertEqual(1, len(msg.map_string_string)) 1816 1817 # Bad key. 1818 with self.assertRaises(TypeError): 1819 msg.map_string_string[123] = '123' 1820 1821 # Verify that trying to assign a bad key doesn't actually add a member to 1822 # the map. 1823 self.assertEqual(1, len(msg.map_string_string)) 1824 1825 # Bad value. 
1826 with self.assertRaises(TypeError): 1827 msg.map_string_string['123'] = 123 1828 1829 serialized = msg.SerializeToString() 1830 msg2 = map_unittest_pb2.TestMap() 1831 msg2.ParseFromString(serialized) 1832 1833 # Bad key. 1834 with self.assertRaises(TypeError): 1835 msg2.map_string_string[123] = '123' 1836 1837 # Bad value. 1838 with self.assertRaises(TypeError): 1839 msg2.map_string_string['123'] = 123 1840 1841 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1842 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1843 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1844 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1845 self.assertAlmostEqual(1.2, msg.map_int32_float[2]) 1846 self.assertEqual(3.3, msg.map_int32_double[1]) 1847 self.assertEqual('123', msg2.map_string_string['abc']) 1848 self.assertEqual(True, msg2.map_bool_bool[True]) 1849 self.assertEqual(2, msg2.map_int32_enum[888]) 1850 self.assertEqual(456, msg2.map_int32_enum[123]) 1851 self.assertEqual('{-123: -456}', str(msg2.map_int32_int32)) 1852 1853 def testMapEntryAlwaysSerialized(self): 1854 msg = map_unittest_pb2.TestMap() 1855 msg.map_int32_int32[0] = 0 1856 msg.map_string_string[''] = '' 1857 self.assertEqual(msg.ByteSize(), 12) 1858 self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00', 1859 msg.SerializeToString()) 1860 1861 def testStringUnicodeConversionInMap(self): 1862 msg = map_unittest_pb2.TestMap() 1863 1864 unicode_obj = u'\u1234' 1865 bytes_obj = unicode_obj.encode('utf8') 1866 1867 msg.map_string_string[bytes_obj] = bytes_obj 1868 1869 (key, value) = list(msg.map_string_string.items())[0] 1870 1871 self.assertEqual(key, unicode_obj) 1872 self.assertEqual(value, unicode_obj) 1873 1874 self.assertIsInstance(key, str) 1875 self.assertIsInstance(value, str) 1876 1877 def testMessageMap(self): 1878 msg = map_unittest_pb2.TestMap() 1879 1880 self.assertEqual(0, len(msg.map_int32_foreign_message)) 1881 self.assertFalse(5 in msg.map_int32_foreign_message) 1882 1883 
msg.map_int32_foreign_message[123] 1884 # get_or_create() is an alias for getitem. 1885 msg.map_int32_foreign_message.get_or_create(-456) 1886 1887 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1888 self.assertIn(123, msg.map_int32_foreign_message) 1889 self.assertIn(-456, msg.map_int32_foreign_message) 1890 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1891 1892 # Bad key. 1893 with self.assertRaises(TypeError): 1894 msg.map_int32_foreign_message['123'] 1895 1896 # Can't assign directly to submessage. 1897 with self.assertRaises(ValueError): 1898 msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] 1899 1900 # Verify that trying to assign a bad key doesn't actually add a member to 1901 # the map. 1902 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1903 1904 serialized = msg.SerializeToString() 1905 msg2 = map_unittest_pb2.TestMap() 1906 msg2.ParseFromString(serialized) 1907 1908 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1909 self.assertIn(123, msg2.map_int32_foreign_message) 1910 self.assertIn(-456, msg2.map_int32_foreign_message) 1911 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1912 msg2.map_int32_foreign_message[123].c = 1 1913 # TODO(jieluo): Fix text format for message map. 
1914 self.assertIn( 1915 str(msg2.map_int32_foreign_message), 1916 ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) 1917 1918 def testNestedMessageMapItemDelete(self): 1919 msg = map_unittest_pb2.TestMap() 1920 msg.map_int32_all_types[1].optional_nested_message.bb = 1 1921 del msg.map_int32_all_types[1] 1922 msg.map_int32_all_types[2].optional_nested_message.bb = 2 1923 self.assertEqual(1, len(msg.map_int32_all_types)) 1924 msg.map_int32_all_types[1].optional_nested_message.bb = 1 1925 self.assertEqual(2, len(msg.map_int32_all_types)) 1926 1927 serialized = msg.SerializeToString() 1928 msg2 = map_unittest_pb2.TestMap() 1929 msg2.ParseFromString(serialized) 1930 keys = [1, 2] 1931 # The loop triggers PyErr_Occurred() in c extension. 1932 for key in keys: 1933 del msg2.map_int32_all_types[key] 1934 1935 def testMapByteSize(self): 1936 msg = map_unittest_pb2.TestMap() 1937 msg.map_int32_int32[1] = 1 1938 size = msg.ByteSize() 1939 msg.map_int32_int32[1] = 128 1940 self.assertEqual(msg.ByteSize(), size + 1) 1941 1942 msg.map_int32_foreign_message[19].c = 1 1943 size = msg.ByteSize() 1944 msg.map_int32_foreign_message[19].c = 128 1945 self.assertEqual(msg.ByteSize(), size + 1) 1946 1947 def testMergeFrom(self): 1948 msg = map_unittest_pb2.TestMap() 1949 msg.map_int32_int32[12] = 34 1950 msg.map_int32_int32[56] = 78 1951 msg.map_int64_int64[22] = 33 1952 msg.map_int32_foreign_message[111].c = 5 1953 msg.map_int32_foreign_message[222].c = 10 1954 1955 msg2 = map_unittest_pb2.TestMap() 1956 msg2.map_int32_int32[12] = 55 1957 msg2.map_int64_int64[88] = 99 1958 msg2.map_int32_foreign_message[222].c = 15 1959 msg2.map_int32_foreign_message[222].d = 20 1960 old_map_value = msg2.map_int32_foreign_message[222] 1961 1962 msg2.MergeFrom(msg) 1963 # Compare with expected message instead of call 1964 # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not 1965 # sync with repeated field and there is no duplicated keys. 
1966 expected_msg = map_unittest_pb2.TestMap() 1967 expected_msg.CopyFrom(msg) 1968 expected_msg.map_int64_int64[88] = 99 1969 self.assertEqual(msg2, expected_msg) 1970 1971 self.assertEqual(34, msg2.map_int32_int32[12]) 1972 self.assertEqual(78, msg2.map_int32_int32[56]) 1973 self.assertEqual(33, msg2.map_int64_int64[22]) 1974 self.assertEqual(99, msg2.map_int64_int64[88]) 1975 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 1976 self.assertEqual(10, msg2.map_int32_foreign_message[222].c) 1977 self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) 1978 if api_implementation.Type() != 'cpp': 1979 # During the call to MergeFrom(), the C++ implementation will have 1980 # deallocated the underlying message, but this is very difficult to detect 1981 # properly. The line below is likely to cause a segmentation fault. 1982 # With the Python implementation, old_map_value is just 'detached' from 1983 # the main message. Using it will not crash of course, but since it still 1984 # have a reference to the parent message I'm sure we can find interesting 1985 # ways to cause inconsistencies. 1986 self.assertEqual(15, old_map_value.c) 1987 1988 # Verify that there is only one entry per key, even though the MergeFrom 1989 # may have internally created multiple entries for a single key in the 1990 # list representation. 1991 as_dict = {} 1992 for key in msg2.map_int32_foreign_message: 1993 self.assertFalse(key in as_dict) 1994 as_dict[key] = msg2.map_int32_foreign_message[key].c 1995 1996 self.assertEqual({111: 5, 222: 10}, as_dict) 1997 1998 # Special case: test that delete of item really removes the item, even if 1999 # there might have physically been duplicate keys due to the previous merge. 2000 # This is only a special case for the C++ implementation which stores the 2001 # map as an array. 
2002 del msg2.map_int32_int32[12] 2003 self.assertFalse(12 in msg2.map_int32_int32) 2004 2005 del msg2.map_int32_foreign_message[222] 2006 self.assertFalse(222 in msg2.map_int32_foreign_message) 2007 with self.assertRaises(TypeError): 2008 del msg2.map_int32_foreign_message[''] 2009 2010 def testMapMergeFrom(self): 2011 msg = map_unittest_pb2.TestMap() 2012 msg.map_int32_int32[12] = 34 2013 msg.map_int32_int32[56] = 78 2014 msg.map_int64_int64[22] = 33 2015 msg.map_int32_foreign_message[111].c = 5 2016 msg.map_int32_foreign_message[222].c = 10 2017 2018 msg2 = map_unittest_pb2.TestMap() 2019 msg2.map_int32_int32[12] = 55 2020 msg2.map_int64_int64[88] = 99 2021 msg2.map_int32_foreign_message[222].c = 15 2022 msg2.map_int32_foreign_message[222].d = 20 2023 2024 msg2.map_int32_int32.MergeFrom(msg.map_int32_int32) 2025 self.assertEqual(34, msg2.map_int32_int32[12]) 2026 self.assertEqual(78, msg2.map_int32_int32[56]) 2027 2028 msg2.map_int64_int64.MergeFrom(msg.map_int64_int64) 2029 self.assertEqual(33, msg2.map_int64_int64[22]) 2030 self.assertEqual(99, msg2.map_int64_int64[88]) 2031 2032 msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message) 2033 # Compare with expected message instead of call 2034 # msg.map_int32_foreign_message[222] to make sure MergeFrom does not 2035 # sync with repeated field and no duplicated keys. 2036 expected_msg = map_unittest_pb2.TestMap() 2037 expected_msg.CopyFrom(msg) 2038 expected_msg.map_int64_int64[88] = 99 2039 self.assertEqual(msg2, expected_msg) 2040 2041 # Test when cpp extension cache a map. 2042 m1 = map_unittest_pb2.TestMap() 2043 m2 = map_unittest_pb2.TestMap() 2044 self.assertEqual(m1.map_int32_foreign_message, m1.map_int32_foreign_message) 2045 m2.map_int32_foreign_message[123].c = 10 2046 m1.MergeFrom(m2) 2047 self.assertEqual(10, m2.map_int32_foreign_message[123].c) 2048 2049 # Test merge maps within different message types. 
2050 m1 = map_unittest_pb2.TestMap() 2051 m2 = map_unittest_pb2.TestMessageMap() 2052 m2.map_int32_message[123].optional_int32 = 10 2053 m1.map_int32_all_types.MergeFrom(m2.map_int32_message) 2054 self.assertEqual(10, m1.map_int32_all_types[123].optional_int32) 2055 2056 # Test overwrite message value map 2057 msg = map_unittest_pb2.TestMap() 2058 msg.map_int32_foreign_message[222].c = 123 2059 msg2 = map_unittest_pb2.TestMap() 2060 msg2.map_int32_foreign_message[222].d = 20 2061 msg.MergeFromString(msg2.SerializeToString()) 2062 self.assertEqual(msg.map_int32_foreign_message[222].d, 20) 2063 self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123) 2064 2065 # Merge a dict to map field is not accepted 2066 with self.assertRaises(AttributeError): 2067 m1.map_int32_all_types.MergeFrom( 2068 {1: unittest_proto3_arena_pb2.TestAllTypes()}) 2069 2070 def testMergeFromBadType(self): 2071 msg = map_unittest_pb2.TestMap() 2072 with self.assertRaisesRegex( 2073 TypeError, 2074 r'Parameter to MergeFrom\(\) must be instance of same class: expected ' 2075 r'.+TestMap got int\.'): 2076 msg.MergeFrom(1) 2077 2078 def testCopyFromBadType(self): 2079 msg = map_unittest_pb2.TestMap() 2080 with self.assertRaisesRegex( 2081 TypeError, 2082 r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' 2083 r'expected .+TestMap got int\.'): 2084 msg.CopyFrom(1) 2085 2086 def testIntegerMapWithLongs(self): 2087 msg = map_unittest_pb2.TestMap() 2088 msg.map_int32_int32[int(-123)] = int(-456) 2089 msg.map_int64_int64[int(-2**33)] = int(-2**34) 2090 msg.map_uint32_uint32[int(123)] = int(456) 2091 msg.map_uint64_uint64[int(2**33)] = int(2**34) 2092 2093 serialized = msg.SerializeToString() 2094 msg2 = map_unittest_pb2.TestMap() 2095 msg2.ParseFromString(serialized) 2096 2097 self.assertEqual(-456, msg2.map_int32_int32[-123]) 2098 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 2099 self.assertEqual(456, msg2.map_uint32_uint32[123]) 2100 self.assertEqual(2**34, 
msg2.map_uint64_uint64[2**33]) 2101 2102 def testMapAssignmentCausesPresence(self): 2103 msg = map_unittest_pb2.TestMapSubmessage() 2104 msg.test_map.map_int32_int32[123] = 456 2105 2106 serialized = msg.SerializeToString() 2107 msg2 = map_unittest_pb2.TestMapSubmessage() 2108 msg2.ParseFromString(serialized) 2109 2110 self.assertEqual(msg, msg2) 2111 2112 # Now test that various mutations of the map properly invalidate the 2113 # cached size of the submessage. 2114 msg.test_map.map_int32_int32[888] = 999 2115 serialized = msg.SerializeToString() 2116 msg2.ParseFromString(serialized) 2117 self.assertEqual(msg, msg2) 2118 2119 msg.test_map.map_int32_int32.clear() 2120 serialized = msg.SerializeToString() 2121 msg2.ParseFromString(serialized) 2122 self.assertEqual(msg, msg2) 2123 2124 def testMapAssignmentCausesPresenceForSubmessages(self): 2125 msg = map_unittest_pb2.TestMapSubmessage() 2126 msg.test_map.map_int32_foreign_message[123].c = 5 2127 2128 serialized = msg.SerializeToString() 2129 msg2 = map_unittest_pb2.TestMapSubmessage() 2130 msg2.ParseFromString(serialized) 2131 2132 self.assertEqual(msg, msg2) 2133 2134 # Now test that various mutations of the map properly invalidate the 2135 # cached size of the submessage. 
2136 msg.test_map.map_int32_foreign_message[888].c = 7 2137 serialized = msg.SerializeToString() 2138 msg2.ParseFromString(serialized) 2139 self.assertEqual(msg, msg2) 2140 2141 msg.test_map.map_int32_foreign_message[888].MergeFrom( 2142 msg.test_map.map_int32_foreign_message[123]) 2143 serialized = msg.SerializeToString() 2144 msg2.ParseFromString(serialized) 2145 self.assertEqual(msg, msg2) 2146 2147 msg.test_map.map_int32_foreign_message.clear() 2148 serialized = msg.SerializeToString() 2149 msg2.ParseFromString(serialized) 2150 self.assertEqual(msg, msg2) 2151 2152 def testModifyMapWhileIterating(self): 2153 msg = map_unittest_pb2.TestMap() 2154 2155 string_string_iter = iter(msg.map_string_string) 2156 int32_foreign_iter = iter(msg.map_int32_foreign_message) 2157 2158 msg.map_string_string['abc'] = '123' 2159 msg.map_int32_foreign_message[5].c = 5 2160 2161 with self.assertRaises(RuntimeError): 2162 for key in string_string_iter: 2163 pass 2164 2165 with self.assertRaises(RuntimeError): 2166 for key in int32_foreign_iter: 2167 pass 2168 2169 def testModifyMapEntryWhileIterating(self): 2170 msg = map_unittest_pb2.TestMap() 2171 2172 msg.map_string_string['abc'] = '123' 2173 msg.map_string_string['def'] = '456' 2174 msg.map_string_string['ghi'] = '789' 2175 2176 msg.map_int32_foreign_message[5].c = 5 2177 msg.map_int32_foreign_message[6].c = 6 2178 msg.map_int32_foreign_message[7].c = 7 2179 2180 string_string_keys = list(msg.map_string_string.keys()) 2181 int32_foreign_keys = list(msg.map_int32_foreign_message.keys()) 2182 2183 keys = [] 2184 for key in msg.map_string_string: 2185 keys.append(key) 2186 msg.map_string_string[key] = '000' 2187 self.assertEqual(keys, string_string_keys) 2188 self.assertEqual(keys, list(msg.map_string_string.keys())) 2189 2190 keys = [] 2191 for key in msg.map_int32_foreign_message: 2192 keys.append(key) 2193 msg.map_int32_foreign_message[key].c = 0 2194 self.assertEqual(keys, int32_foreign_keys) 2195 self.assertEqual(keys, 
list(msg.map_int32_foreign_message.keys())) 2196 2197 def testSubmessageMap(self): 2198 msg = map_unittest_pb2.TestMap() 2199 2200 submsg = msg.map_int32_foreign_message[111] 2201 self.assertIs(submsg, msg.map_int32_foreign_message[111]) 2202 self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) 2203 2204 submsg.c = 5 2205 2206 serialized = msg.SerializeToString() 2207 msg2 = map_unittest_pb2.TestMap() 2208 msg2.ParseFromString(serialized) 2209 2210 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 2211 2212 # Doesn't allow direct submessage assignment. 2213 with self.assertRaises(ValueError): 2214 msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() 2215 2216 def testMapIteration(self): 2217 msg = map_unittest_pb2.TestMap() 2218 2219 for k, v in msg.map_int32_int32.items(): 2220 # Should not be reached. 2221 self.assertTrue(False) 2222 2223 msg.map_int32_int32[2] = 4 2224 msg.map_int32_int32[3] = 6 2225 msg.map_int32_int32[4] = 8 2226 self.assertEqual(3, len(msg.map_int32_int32)) 2227 2228 matching_dict = {2: 4, 3: 6, 4: 8} 2229 self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) 2230 2231 def testMapItems(self): 2232 # Map items used to have strange behaviors when use c extension. Because 2233 # [] may reorder the map and invalidate any existing iterators. 2234 # TODO(jieluo): Check if [] reordering the map is a bug or intended 2235 # behavior. 
2236 msg = map_unittest_pb2.TestMap() 2237 msg.map_string_string['local_init_op'] = '' 2238 msg.map_string_string['trainable_variables'] = '' 2239 msg.map_string_string['variables'] = '' 2240 msg.map_string_string['init_op'] = '' 2241 msg.map_string_string['summaries'] = '' 2242 items1 = msg.map_string_string.items() 2243 items2 = msg.map_string_string.items() 2244 self.assertEqual(items1, items2) 2245 2246 def testMapDeterministicSerialization(self): 2247 golden_data = (b'r\x0c\n\x07init_op\x12\x01d' 2248 b'r\n\n\x05item1\x12\x01e' 2249 b'r\n\n\x05item2\x12\x01f' 2250 b'r\n\n\x05item3\x12\x01g' 2251 b'r\x0b\n\x05item4\x12\x02QQ' 2252 b'r\x12\n\rlocal_init_op\x12\x01a' 2253 b'r\x0e\n\tsummaries\x12\x01e' 2254 b'r\x18\n\x13trainable_variables\x12\x01b' 2255 b'r\x0e\n\tvariables\x12\x01c') 2256 msg = map_unittest_pb2.TestMap() 2257 msg.map_string_string['local_init_op'] = 'a' 2258 msg.map_string_string['trainable_variables'] = 'b' 2259 msg.map_string_string['variables'] = 'c' 2260 msg.map_string_string['init_op'] = 'd' 2261 msg.map_string_string['summaries'] = 'e' 2262 msg.map_string_string['item1'] = 'e' 2263 msg.map_string_string['item2'] = 'f' 2264 msg.map_string_string['item3'] = 'g' 2265 msg.map_string_string['item4'] = 'QQ' 2266 2267 # If deterministic serialization is not working correctly, this will be 2268 # "flaky" depending on the exact python dict hash seed. 2269 # 2270 # Fortunately, there are enough items in this map that it is extremely 2271 # unlikely to ever hit the "right" in-order combination, so the test 2272 # itself should fail reliably. 2273 self.assertEqual(golden_data, msg.SerializeToString(deterministic=True)) 2274 2275 def testMapIterationClearMessage(self): 2276 # Iterator needs to work even if message and map are deleted. 
2277 msg = map_unittest_pb2.TestMap() 2278 2279 msg.map_int32_int32[2] = 4 2280 msg.map_int32_int32[3] = 6 2281 msg.map_int32_int32[4] = 8 2282 2283 it = msg.map_int32_int32.items() 2284 del msg 2285 2286 matching_dict = {2: 4, 3: 6, 4: 8} 2287 self.assertMapIterEquals(it, matching_dict) 2288 2289 def testMapConstruction(self): 2290 msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) 2291 self.assertEqual(2, msg.map_int32_int32[1]) 2292 self.assertEqual(4, msg.map_int32_int32[3]) 2293 2294 msg = map_unittest_pb2.TestMap( 2295 map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) 2296 self.assertEqual(5, msg.map_int32_foreign_message[3].c) 2297 2298 def testMapScalarFieldConstruction(self): 2299 msg1 = map_unittest_pb2.TestMap() 2300 msg1.map_int32_int32[1] = 42 2301 msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32) 2302 self.assertEqual(42, msg2.map_int32_int32[1]) 2303 2304 def testMapMessageFieldConstruction(self): 2305 msg1 = map_unittest_pb2.TestMap() 2306 msg1.map_string_foreign_message['test'].c = 42 2307 msg2 = map_unittest_pb2.TestMap( 2308 map_string_foreign_message=msg1.map_string_foreign_message) 2309 self.assertEqual(42, msg2.map_string_foreign_message['test'].c) 2310 2311 def testMapFieldRaisesCorrectError(self): 2312 # Should raise a TypeError when given a non-iterable. 2313 with self.assertRaises(TypeError): 2314 map_unittest_pb2.TestMap(map_string_foreign_message=1) 2315 2316 def testMapValidAfterFieldCleared(self): 2317 # Map needs to work even if field is cleared. 
2318 # For the C++ implementation this tests the correctness of 2319 # MapContainer::Release() 2320 msg = map_unittest_pb2.TestMap() 2321 int32_map = msg.map_int32_int32 2322 2323 int32_map[2] = 4 2324 int32_map[3] = 6 2325 int32_map[4] = 8 2326 2327 msg.ClearField('map_int32_int32') 2328 self.assertEqual(b'', msg.SerializeToString()) 2329 matching_dict = {2: 4, 3: 6, 4: 8} 2330 self.assertMapIterEquals(int32_map.items(), matching_dict) 2331 2332 def testMessageMapValidAfterFieldCleared(self): 2333 # Map needs to work even if field is cleared. 2334 # For the C++ implementation this tests the correctness of 2335 # MapContainer::Release() 2336 msg = map_unittest_pb2.TestMap() 2337 int32_foreign_message = msg.map_int32_foreign_message 2338 2339 int32_foreign_message[2].c = 5 2340 2341 msg.ClearField('map_int32_foreign_message') 2342 self.assertEqual(b'', msg.SerializeToString()) 2343 self.assertTrue(2 in int32_foreign_message.keys()) 2344 2345 def testMessageMapItemValidAfterTopMessageCleared(self): 2346 # Message map item needs to work even if it is cleared. 2347 # For the C++ implementation this tests the correctness of 2348 # MapContainer::Release() 2349 msg = map_unittest_pb2.TestMap() 2350 msg.map_int32_all_types[2].optional_string = 'bar' 2351 2352 if api_implementation.Type() == 'cpp': 2353 # Need to keep the map reference because of b/27942626. 2354 # TODO(jieluo): Remove it. 2355 unused_map = msg.map_int32_all_types # pylint: disable=unused-variable 2356 msg_value = msg.map_int32_all_types[2] 2357 msg.Clear() 2358 2359 # Reset to trigger sync between repeated field and map in c++. 2360 msg.map_int32_all_types[3].optional_string = 'foo' 2361 self.assertEqual(msg_value.optional_string, 'bar') 2362 2363 def testMapIterInvalidatedByClearField(self): 2364 # Map iterator is invalidated when field is cleared. 2365 # But this case does need to not crash the interpreter. 
2366 # For the C++ implementation this tests the correctness of 2367 # ScalarMapContainer::Release() 2368 msg = map_unittest_pb2.TestMap() 2369 2370 it = iter(msg.map_int32_int32) 2371 2372 msg.ClearField('map_int32_int32') 2373 with self.assertRaises(RuntimeError): 2374 for _ in it: 2375 pass 2376 2377 it = iter(msg.map_int32_foreign_message) 2378 msg.ClearField('map_int32_foreign_message') 2379 with self.assertRaises(RuntimeError): 2380 for _ in it: 2381 pass 2382 2383 def testMapDelete(self): 2384 msg = map_unittest_pb2.TestMap() 2385 2386 self.assertEqual(0, len(msg.map_int32_int32)) 2387 2388 msg.map_int32_int32[4] = 6 2389 self.assertEqual(1, len(msg.map_int32_int32)) 2390 2391 with self.assertRaises(KeyError): 2392 del msg.map_int32_int32[88] 2393 2394 del msg.map_int32_int32[4] 2395 self.assertEqual(0, len(msg.map_int32_int32)) 2396 2397 with self.assertRaises(KeyError): 2398 del msg.map_int32_all_types[32] 2399 2400 def testMapsAreMapping(self): 2401 msg = map_unittest_pb2.TestMap() 2402 self.assertIsInstance(msg.map_int32_int32, collections.abc.Mapping) 2403 self.assertIsInstance(msg.map_int32_int32, collections.abc.MutableMapping) 2404 self.assertIsInstance(msg.map_int32_foreign_message, 2405 collections.abc.Mapping) 2406 self.assertIsInstance(msg.map_int32_foreign_message, 2407 collections.abc.MutableMapping) 2408 2409 def testMapsCompare(self): 2410 msg = map_unittest_pb2.TestMap() 2411 msg.map_int32_int32[-123] = -456 2412 self.assertEqual(msg.map_int32_int32, msg.map_int32_int32) 2413 self.assertEqual(msg.map_int32_foreign_message, 2414 msg.map_int32_foreign_message) 2415 self.assertNotEqual(msg.map_int32_int32, 0) 2416 2417 def testMapFindInitializationErrorsSmokeTest(self): 2418 msg = map_unittest_pb2.TestMap() 2419 msg.map_string_string['abc'] = '123' 2420 msg.map_int32_int32[35] = 64 2421 msg.map_string_foreign_message['foo'].c = 5 2422 self.assertEqual(0, len(msg.FindInitializationErrors())) 2423 2424 @unittest.skipIf(sys.maxunicode == 
UCS2_MAXUNICODE, 'Skip for ucs2') 2425 def testStrictUtf8Check(self): 2426 # Test u'\ud801' is rejected at parser in both python2 and python3. 2427 serialized = (b'r\x03\xed\xa0\x81') 2428 msg = unittest_proto3_arena_pb2.TestAllTypes() 2429 with self.assertRaises(Exception) as context: 2430 msg.MergeFromString(serialized) 2431 if api_implementation.Type() == 'python': 2432 self.assertIn('optional_string', str(context.exception)) 2433 else: 2434 self.assertIn('Error parsing message', str(context.exception)) 2435 2436 # Test optional_string=u'' is accepted. 2437 serialized = unittest_proto3_arena_pb2.TestAllTypes( 2438 optional_string=u'').SerializeToString() 2439 msg2 = unittest_proto3_arena_pb2.TestAllTypes() 2440 msg2.MergeFromString(serialized) 2441 self.assertEqual(msg2.optional_string, u'') 2442 2443 msg = unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud001') 2444 self.assertEqual(msg.optional_string, u'\ud001') 2445 2446 def testSurrogatesInPython3(self): 2447 # Surrogates are rejected at setters in Python3. 2448 with self.assertRaises(ValueError): 2449 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\udc01') 2450 with self.assertRaises(ValueError): 2451 unittest_proto3_arena_pb2.TestAllTypes(optional_string=b'\xed\xa0\x81') 2452 with self.assertRaises(ValueError): 2453 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801') 2454 with self.assertRaises(ValueError): 2455 unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801') 2456 2457 2458 2459 2460@testing_refleaks.TestCase 2461class ValidTypeNamesTest(unittest.TestCase): 2462 2463 def assertImportFromName(self, msg, base_name): 2464 # Parse <type 'module.class_name'> to extra 'some.name' as a string. 
2465 tp_name = str(type(msg)).split("'")[1] 2466 valid_names = ('Repeated%sContainer' % base_name, 2467 'Repeated%sFieldContainer' % base_name) 2468 self.assertTrue( 2469 any(tp_name.endswith(v) for v in valid_names), 2470 '%r does end with any of %r' % (tp_name, valid_names)) 2471 2472 parts = tp_name.split('.') 2473 class_name = parts[-1] 2474 module_name = '.'.join(parts[:-1]) 2475 __import__(module_name, fromlist=[class_name]) 2476 2477 def testTypeNamesCanBeImported(self): 2478 # If import doesn't work, pickling won't work either. 2479 pb = unittest_pb2.TestAllTypes() 2480 self.assertImportFromName(pb.repeated_int32, 'Scalar') 2481 self.assertImportFromName(pb.repeated_nested_message, 'Composite') 2482 2483 2484@testing_refleaks.TestCase 2485class PackedFieldTest(unittest.TestCase): 2486 2487 def setMessage(self, message): 2488 message.repeated_int32.append(1) 2489 message.repeated_int64.append(1) 2490 message.repeated_uint32.append(1) 2491 message.repeated_uint64.append(1) 2492 message.repeated_sint32.append(1) 2493 message.repeated_sint64.append(1) 2494 message.repeated_fixed32.append(1) 2495 message.repeated_fixed64.append(1) 2496 message.repeated_sfixed32.append(1) 2497 message.repeated_sfixed64.append(1) 2498 message.repeated_float.append(1.0) 2499 message.repeated_double.append(1.0) 2500 message.repeated_bool.append(True) 2501 message.repeated_nested_enum.append(1) 2502 2503 def testPackedFields(self): 2504 message = packed_field_test_pb2.TestPackedTypes() 2505 self.setMessage(message) 2506 golden_data = (b'\x0A\x01\x01' 2507 b'\x12\x01\x01' 2508 b'\x1A\x01\x01' 2509 b'\x22\x01\x01' 2510 b'\x2A\x01\x02' 2511 b'\x32\x01\x02' 2512 b'\x3A\x04\x01\x00\x00\x00' 2513 b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2514 b'\x4A\x04\x01\x00\x00\x00' 2515 b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2516 b'\x5A\x04\x00\x00\x80\x3f' 2517 b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' 2518 b'\x6A\x01\x01' 2519 b'\x72\x01\x01') 2520 self.assertEqual(golden_data, 
message.SerializeToString()) 2521 2522 def testUnpackedFields(self): 2523 message = packed_field_test_pb2.TestUnpackedTypes() 2524 self.setMessage(message) 2525 golden_data = (b'\x08\x01' 2526 b'\x10\x01' 2527 b'\x18\x01' 2528 b'\x20\x01' 2529 b'\x28\x02' 2530 b'\x30\x02' 2531 b'\x3D\x01\x00\x00\x00' 2532 b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' 2533 b'\x4D\x01\x00\x00\x00' 2534 b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' 2535 b'\x5D\x00\x00\x80\x3f' 2536 b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' 2537 b'\x68\x01' 2538 b'\x70\x01') 2539 self.assertEqual(golden_data, message.SerializeToString()) 2540 2541 2542@unittest.skipIf(api_implementation.Type() != 'cpp', 2543 'explicit tests of the C++ implementation') 2544@testing_refleaks.TestCase 2545class OversizeProtosTest(unittest.TestCase): 2546 2547 @classmethod 2548 def setUpClass(cls): 2549 # At the moment, reference cycles between DescriptorPool and Message classes 2550 # are not detected and these objects are never freed. 2551 # To avoid errors with ReferenceLeakChecker, we create the class only once. 
2552 file_desc = """ 2553 name: "f/f.msg2" 2554 package: "f" 2555 message_type { 2556 name: "msg1" 2557 field { 2558 name: "payload" 2559 number: 1 2560 label: LABEL_OPTIONAL 2561 type: TYPE_STRING 2562 } 2563 } 2564 message_type { 2565 name: "msg2" 2566 field { 2567 name: "field" 2568 number: 1 2569 label: LABEL_OPTIONAL 2570 type: TYPE_MESSAGE 2571 type_name: "msg1" 2572 } 2573 } 2574 """ 2575 pool = descriptor_pool.DescriptorPool() 2576 desc = descriptor_pb2.FileDescriptorProto() 2577 text_format.Parse(file_desc, desc) 2578 pool.Add(desc) 2579 cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype( 2580 pool.FindMessageTypeByName('f.msg2')) 2581 2582 def setUp(self): 2583 self.p = self.proto_cls() 2584 self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) 2585 self.p_serialized = self.p.SerializeToString() 2586 2587 def testAssertOversizeProto(self): 2588 from google.protobuf.pyext._message import SetAllowOversizeProtos 2589 SetAllowOversizeProtos(False) 2590 q = self.proto_cls() 2591 try: 2592 q.ParseFromString(self.p_serialized) 2593 except message.DecodeError as e: 2594 self.assertEqual(str(e), 'Error parsing message') 2595 2596 def testSucceedOversizeProto(self): 2597 from google.protobuf.pyext._message import SetAllowOversizeProtos 2598 SetAllowOversizeProtos(True) 2599 q = self.proto_cls() 2600 q.ParseFromString(self.p_serialized) 2601 self.assertEqual(self.p.field.payload, q.field.payload) 2602 2603 2604if __name__ == '__main__': 2605 unittest.main() 2606