1#! /usr/bin/env python 2# -*- coding: utf-8 -*- 3# 4# Protocol Buffers - Google's data interchange format 5# Copyright 2008 Google Inc. All rights reserved. 6# https://developers.google.com/protocol-buffers/ 7# 8# Redistribution and use in source and binary forms, with or without 9# modification, are permitted provided that the following conditions are 10# met: 11# 12# * Redistributions of source code must retain the above copyright 13# notice, this list of conditions and the following disclaimer. 14# * Redistributions in binary form must reproduce the above 15# copyright notice, this list of conditions and the following disclaimer 16# in the documentation and/or other materials provided with the 17# distribution. 18# * Neither the name of Google Inc. nor the names of its 19# contributors may be used to endorse or promote products derived from 20# this software without specific prior written permission. 21# 22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 34"""Tests python protocol buffers against the golden message. 35 36Note that the golden messages exercise every known field type, thus this 37test ends up exercising and verifying nearly all of the parsing and 38serialization code in the whole library. 
39 40TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of 41sense to call this a test of the "message" module, which only declares an 42abstract interface. 43""" 44 45__author__ = 'gps@google.com (Gregory P. Smith)' 46 47 48import copy 49import math 50import operator 51import pickle 52import pydoc 53import six 54import sys 55import warnings 56 57try: 58 # Since python 3 59 import collections.abc as collections_abc 60except ImportError: 61 # Won't work after python 3.8 62 import collections as collections_abc 63 64try: 65 import unittest2 as unittest # PY26 66except ImportError: 67 import unittest 68try: 69 cmp # Python 2 70except NameError: 71 cmp = lambda x, y: (x > y) - (x < y) # Python 3 72 73from google.protobuf import map_proto2_unittest_pb2 74from google.protobuf import map_unittest_pb2 75from google.protobuf import unittest_pb2 76from google.protobuf import unittest_proto3_arena_pb2 77from google.protobuf import descriptor_pb2 78from google.protobuf import descriptor_pool 79from google.protobuf import message_factory 80from google.protobuf import text_format 81from google.protobuf.internal import api_implementation 82from google.protobuf.internal import encoder 83from google.protobuf.internal import more_extensions_pb2 84from google.protobuf.internal import packed_field_test_pb2 85from google.protobuf.internal import test_util 86from google.protobuf.internal import test_proto3_optional_pb2 87from google.protobuf.internal import testing_refleaks 88from google.protobuf import message 89from google.protobuf.internal import _parameterized 90 91UCS2_MAXUNICODE = 65535 92if six.PY3: 93 long = int 94 95 96# Python pre-2.6 does not have isinf() or isnan() functions, so we have 97# to provide our own. 98def isnan(val): 99 # NaN is never equal to itself. 100 return val != val 101def isinf(val): 102 # Infinity times zero equals NaN. 
103 return not isnan(val) and isnan(val * 0) 104def IsPosInf(val): 105 return isinf(val) and (val > 0) 106def IsNegInf(val): 107 return isinf(val) and (val < 0) 108 109 110warnings.simplefilter('error', DeprecationWarning) 111 112 113@_parameterized.named_parameters( 114 ('_proto2', unittest_pb2), 115 ('_proto3', unittest_proto3_arena_pb2)) 116@testing_refleaks.TestCase 117class MessageTest(unittest.TestCase): 118 119 def testBadUtf8String(self, message_module): 120 if api_implementation.Type() != 'python': 121 self.skipTest("Skipping testBadUtf8String, currently only the python " 122 "api implementation raises UnicodeDecodeError when a " 123 "string field contains bad utf-8.") 124 bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') 125 with self.assertRaises(UnicodeDecodeError) as context: 126 message_module.TestAllTypes.FromString(bad_utf8_data) 127 self.assertIn('TestAllTypes.optional_string', str(context.exception)) 128 129 def testGoldenMessage(self, message_module): 130 # Proto3 doesn't have the "default_foo" members or foreign enums, 131 # and doesn't preserve unknown fields, so for proto3 we use a golden 132 # message that doesn't have these fields set. 
133 if message_module is unittest_pb2: 134 golden_data = test_util.GoldenFileData( 135 'golden_message_oneof_implemented') 136 else: 137 golden_data = test_util.GoldenFileData('golden_message_proto3') 138 139 golden_message = message_module.TestAllTypes() 140 golden_message.ParseFromString(golden_data) 141 if message_module is unittest_pb2: 142 test_util.ExpectAllFieldsSet(self, golden_message) 143 self.assertEqual(golden_data, golden_message.SerializeToString()) 144 golden_copy = copy.deepcopy(golden_message) 145 self.assertEqual(golden_data, golden_copy.SerializeToString()) 146 147 def testGoldenPackedMessage(self, message_module): 148 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 149 golden_message = message_module.TestPackedTypes() 150 parsed_bytes = golden_message.ParseFromString(golden_data) 151 all_set = message_module.TestPackedTypes() 152 test_util.SetAllPackedFields(all_set) 153 self.assertEqual(parsed_bytes, len(golden_data)) 154 self.assertEqual(all_set, golden_message) 155 self.assertEqual(golden_data, all_set.SerializeToString()) 156 golden_copy = copy.deepcopy(golden_message) 157 self.assertEqual(golden_data, golden_copy.SerializeToString()) 158 159 def testParseErrors(self, message_module): 160 msg = message_module.TestAllTypes() 161 self.assertRaises(TypeError, msg.FromString, 0) 162 self.assertRaises(Exception, msg.FromString, '0') 163 # TODO(jieluo): Fix cpp extension to raise error instead of warning. 164 # b/27494216 165 end_tag = encoder.TagBytes(1, 4) 166 if api_implementation.Type() == 'python': 167 with self.assertRaises(message.DecodeError) as context: 168 msg.FromString(end_tag) 169 self.assertEqual('Unexpected end-group tag.', str(context.exception)) 170 171 # Field number 0 is illegal. 
172 self.assertRaises(message.DecodeError, msg.FromString, b'\3\4') 173 174 def testDeterminismParameters(self, message_module): 175 # This message is always deterministically serialized, even if determinism 176 # is disabled, so we can use it to verify that all the determinism 177 # parameters work correctly. 178 golden_data = (b'\xe2\x02\nOne string' 179 b'\xe2\x02\nTwo string' 180 b'\xe2\x02\nRed string' 181 b'\xe2\x02\x0bBlue string') 182 golden_message = message_module.TestAllTypes() 183 golden_message.repeated_string.extend([ 184 'One string', 185 'Two string', 186 'Red string', 187 'Blue string', 188 ]) 189 self.assertEqual(golden_data, 190 golden_message.SerializeToString(deterministic=None)) 191 self.assertEqual(golden_data, 192 golden_message.SerializeToString(deterministic=False)) 193 self.assertEqual(golden_data, 194 golden_message.SerializeToString(deterministic=True)) 195 196 class BadArgError(Exception): 197 pass 198 199 class BadArg(object): 200 201 def __nonzero__(self): 202 raise BadArgError() 203 204 def __bool__(self): 205 raise BadArgError() 206 207 with self.assertRaises(BadArgError): 208 golden_message.SerializeToString(deterministic=BadArg()) 209 210 def testPickleSupport(self, message_module): 211 golden_data = test_util.GoldenFileData('golden_message') 212 golden_message = message_module.TestAllTypes() 213 golden_message.ParseFromString(golden_data) 214 pickled_message = pickle.dumps(golden_message) 215 216 unpickled_message = pickle.loads(pickled_message) 217 self.assertEqual(unpickled_message, golden_message) 218 219 def testPickleNestedMessage(self, message_module): 220 golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1) 221 pickled_message = pickle.dumps(golden_message) 222 unpickled_message = pickle.loads(pickled_message) 223 self.assertEqual(unpickled_message, golden_message) 224 225 def testPickleNestedNestedMessage(self, message_module): 226 cls = message_module.TestPickleNestedMessage.NestedMessage 227 
golden_message = cls.NestedNestedMessage(cc=1) 228 pickled_message = pickle.dumps(golden_message) 229 unpickled_message = pickle.loads(pickled_message) 230 self.assertEqual(unpickled_message, golden_message) 231 232 def testPositiveInfinity(self, message_module): 233 if message_module is unittest_pb2: 234 golden_data = (b'\x5D\x00\x00\x80\x7F' 235 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 236 b'\xCD\x02\x00\x00\x80\x7F' 237 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') 238 else: 239 golden_data = (b'\x5D\x00\x00\x80\x7F' 240 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 241 b'\xCA\x02\x04\x00\x00\x80\x7F' 242 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 243 244 golden_message = message_module.TestAllTypes() 245 golden_message.ParseFromString(golden_data) 246 self.assertTrue(IsPosInf(golden_message.optional_float)) 247 self.assertTrue(IsPosInf(golden_message.optional_double)) 248 self.assertTrue(IsPosInf(golden_message.repeated_float[0])) 249 self.assertTrue(IsPosInf(golden_message.repeated_double[0])) 250 self.assertEqual(golden_data, golden_message.SerializeToString()) 251 252 def testNegativeInfinity(self, message_module): 253 if message_module is unittest_pb2: 254 golden_data = (b'\x5D\x00\x00\x80\xFF' 255 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 256 b'\xCD\x02\x00\x00\x80\xFF' 257 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') 258 else: 259 golden_data = (b'\x5D\x00\x00\x80\xFF' 260 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 261 b'\xCA\x02\x04\x00\x00\x80\xFF' 262 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 263 264 golden_message = message_module.TestAllTypes() 265 golden_message.ParseFromString(golden_data) 266 self.assertTrue(IsNegInf(golden_message.optional_float)) 267 self.assertTrue(IsNegInf(golden_message.optional_double)) 268 self.assertTrue(IsNegInf(golden_message.repeated_float[0])) 269 self.assertTrue(IsNegInf(golden_message.repeated_double[0])) 270 self.assertEqual(golden_data, golden_message.SerializeToString()) 271 272 def testNotANumber(self, 
message_module): 273 golden_data = (b'\x5D\x00\x00\xC0\x7F' 274 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' 275 b'\xCD\x02\x00\x00\xC0\x7F' 276 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') 277 golden_message = message_module.TestAllTypes() 278 golden_message.ParseFromString(golden_data) 279 self.assertTrue(isnan(golden_message.optional_float)) 280 self.assertTrue(isnan(golden_message.optional_double)) 281 self.assertTrue(isnan(golden_message.repeated_float[0])) 282 self.assertTrue(isnan(golden_message.repeated_double[0])) 283 284 # The protocol buffer may serialize to any one of multiple different 285 # representations of a NaN. Rather than verify a specific representation, 286 # verify the serialized string can be converted into a correctly 287 # behaving protocol buffer. 288 serialized = golden_message.SerializeToString() 289 message = message_module.TestAllTypes() 290 message.ParseFromString(serialized) 291 self.assertTrue(isnan(message.optional_float)) 292 self.assertTrue(isnan(message.optional_double)) 293 self.assertTrue(isnan(message.repeated_float[0])) 294 self.assertTrue(isnan(message.repeated_double[0])) 295 296 def testPositiveInfinityPacked(self, message_module): 297 golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' 298 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 299 golden_message = message_module.TestPackedTypes() 300 golden_message.ParseFromString(golden_data) 301 self.assertTrue(IsPosInf(golden_message.packed_float[0])) 302 self.assertTrue(IsPosInf(golden_message.packed_double[0])) 303 self.assertEqual(golden_data, golden_message.SerializeToString()) 304 305 def testNegativeInfinityPacked(self, message_module): 306 golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' 307 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 308 golden_message = message_module.TestPackedTypes() 309 golden_message.ParseFromString(golden_data) 310 self.assertTrue(IsNegInf(golden_message.packed_float[0])) 311 self.assertTrue(IsNegInf(golden_message.packed_double[0])) 312 
self.assertEqual(golden_data, golden_message.SerializeToString()) 313 314 def testNotANumberPacked(self, message_module): 315 golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' 316 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') 317 golden_message = message_module.TestPackedTypes() 318 golden_message.ParseFromString(golden_data) 319 self.assertTrue(isnan(golden_message.packed_float[0])) 320 self.assertTrue(isnan(golden_message.packed_double[0])) 321 322 serialized = golden_message.SerializeToString() 323 message = message_module.TestPackedTypes() 324 message.ParseFromString(serialized) 325 self.assertTrue(isnan(message.packed_float[0])) 326 self.assertTrue(isnan(message.packed_double[0])) 327 328 def testExtremeFloatValues(self, message_module): 329 message = message_module.TestAllTypes() 330 331 # Most positive exponent, no significand bits set. 332 kMostPosExponentNoSigBits = math.pow(2, 127) 333 message.optional_float = kMostPosExponentNoSigBits 334 message.ParseFromString(message.SerializeToString()) 335 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) 336 337 # Most positive exponent, one significand bit set. 338 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) 339 message.optional_float = kMostPosExponentOneSigBit 340 message.ParseFromString(message.SerializeToString()) 341 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) 342 343 # Repeat last two cases with values of same magnitude, but negative. 344 message.optional_float = -kMostPosExponentNoSigBits 345 message.ParseFromString(message.SerializeToString()) 346 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) 347 348 message.optional_float = -kMostPosExponentOneSigBit 349 message.ParseFromString(message.SerializeToString()) 350 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) 351 352 # Most negative exponent, no significand bits set. 
353 kMostNegExponentNoSigBits = math.pow(2, -127) 354 message.optional_float = kMostNegExponentNoSigBits 355 message.ParseFromString(message.SerializeToString()) 356 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) 357 358 # Most negative exponent, one significand bit set. 359 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) 360 message.optional_float = kMostNegExponentOneSigBit 361 message.ParseFromString(message.SerializeToString()) 362 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) 363 364 # Repeat last two cases with values of the same magnitude, but negative. 365 message.optional_float = -kMostNegExponentNoSigBits 366 message.ParseFromString(message.SerializeToString()) 367 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) 368 369 message.optional_float = -kMostNegExponentOneSigBit 370 message.ParseFromString(message.SerializeToString()) 371 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) 372 373 # Max 4 bytes float value 374 max_float = float.fromhex('0x1.fffffep+127') 375 message.optional_float = max_float 376 self.assertAlmostEqual(message.optional_float, max_float) 377 serialized_data = message.SerializeToString() 378 message.ParseFromString(serialized_data) 379 self.assertAlmostEqual(message.optional_float, max_float) 380 381 # Test set double to float field. 
382 message.optional_float = 3.4028235e+39 383 self.assertEqual(message.optional_float, float('inf')) 384 serialized_data = message.SerializeToString() 385 message.ParseFromString(serialized_data) 386 self.assertEqual(message.optional_float, float('inf')) 387 388 message.optional_float = -3.4028235e+39 389 self.assertEqual(message.optional_float, float('-inf')) 390 391 message.optional_float = 1.4028235e-39 392 self.assertAlmostEqual(message.optional_float, 1.4028235e-39) 393 394 def testExtremeDoubleValues(self, message_module): 395 message = message_module.TestAllTypes() 396 397 # Most positive exponent, no significand bits set. 398 kMostPosExponentNoSigBits = math.pow(2, 1023) 399 message.optional_double = kMostPosExponentNoSigBits 400 message.ParseFromString(message.SerializeToString()) 401 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) 402 403 # Most positive exponent, one significand bit set. 404 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) 405 message.optional_double = kMostPosExponentOneSigBit 406 message.ParseFromString(message.SerializeToString()) 407 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) 408 409 # Repeat last two cases with values of same magnitude, but negative. 410 message.optional_double = -kMostPosExponentNoSigBits 411 message.ParseFromString(message.SerializeToString()) 412 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) 413 414 message.optional_double = -kMostPosExponentOneSigBit 415 message.ParseFromString(message.SerializeToString()) 416 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) 417 418 # Most negative exponent, no significand bits set. 419 kMostNegExponentNoSigBits = math.pow(2, -1023) 420 message.optional_double = kMostNegExponentNoSigBits 421 message.ParseFromString(message.SerializeToString()) 422 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) 423 424 # Most negative exponent, one significand bit set. 
425 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) 426 message.optional_double = kMostNegExponentOneSigBit 427 message.ParseFromString(message.SerializeToString()) 428 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) 429 430 # Repeat last two cases with values of the same magnitude, but negative. 431 message.optional_double = -kMostNegExponentNoSigBits 432 message.ParseFromString(message.SerializeToString()) 433 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) 434 435 message.optional_double = -kMostNegExponentOneSigBit 436 message.ParseFromString(message.SerializeToString()) 437 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) 438 439 def testFloatPrinting(self, message_module): 440 message = message_module.TestAllTypes() 441 message.optional_float = 2.0 442 self.assertEqual(str(message), 'optional_float: 2.0\n') 443 444 def testHighPrecisionFloatPrinting(self, message_module): 445 msg = message_module.TestAllTypes() 446 msg.optional_float = 0.12345678912345678 447 old_float = msg.optional_float 448 msg.ParseFromString(msg.SerializeToString()) 449 self.assertEqual(old_float, msg.optional_float) 450 451 def testHighPrecisionDoublePrinting(self, message_module): 452 msg = message_module.TestAllTypes() 453 msg.optional_double = 0.12345678912345678 454 if sys.version_info >= (3,): 455 self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n') 456 else: 457 self.assertEqual(str(msg), 'optional_double: 0.123456789123\n') 458 459 def testUnknownFieldPrinting(self, message_module): 460 populated = message_module.TestAllTypes() 461 test_util.SetAllNonLazyFields(populated) 462 empty = message_module.TestEmptyMessage() 463 empty.ParseFromString(populated.SerializeToString()) 464 self.assertEqual(str(empty), '') 465 466 def testAppendRepeatedCompositeField(self, message_module): 467 msg = message_module.TestAllTypes() 468 msg.repeated_nested_message.append( 469 
message_module.TestAllTypes.NestedMessage(bb=1)) 470 nested = message_module.TestAllTypes.NestedMessage(bb=2) 471 msg.repeated_nested_message.append(nested) 472 try: 473 msg.repeated_nested_message.append(1) 474 except TypeError: 475 pass 476 self.assertEqual(2, len(msg.repeated_nested_message)) 477 self.assertEqual([1, 2], 478 [m.bb for m in msg.repeated_nested_message]) 479 480 def testInsertRepeatedCompositeField(self, message_module): 481 msg = message_module.TestAllTypes() 482 msg.repeated_nested_message.insert( 483 -1, message_module.TestAllTypes.NestedMessage(bb=1)) 484 sub_msg = msg.repeated_nested_message[0] 485 msg.repeated_nested_message.insert( 486 0, message_module.TestAllTypes.NestedMessage(bb=2)) 487 msg.repeated_nested_message.insert( 488 99, message_module.TestAllTypes.NestedMessage(bb=3)) 489 msg.repeated_nested_message.insert( 490 -2, message_module.TestAllTypes.NestedMessage(bb=-1)) 491 msg.repeated_nested_message.insert( 492 -1000, message_module.TestAllTypes.NestedMessage(bb=-1000)) 493 try: 494 msg.repeated_nested_message.insert(1, 999) 495 except TypeError: 496 pass 497 self.assertEqual(5, len(msg.repeated_nested_message)) 498 self.assertEqual([-1000, 2, -1, 1, 3], 499 [m.bb for m in msg.repeated_nested_message]) 500 self.assertEqual(str(msg), 501 'repeated_nested_message {\n' 502 ' bb: -1000\n' 503 '}\n' 504 'repeated_nested_message {\n' 505 ' bb: 2\n' 506 '}\n' 507 'repeated_nested_message {\n' 508 ' bb: -1\n' 509 '}\n' 510 'repeated_nested_message {\n' 511 ' bb: 1\n' 512 '}\n' 513 'repeated_nested_message {\n' 514 ' bb: 3\n' 515 '}\n') 516 self.assertEqual(sub_msg.bb, 1) 517 518 def testMergeFromRepeatedField(self, message_module): 519 msg = message_module.TestAllTypes() 520 msg.repeated_int32.append(1) 521 msg.repeated_int32.append(3) 522 msg.repeated_nested_message.add(bb=1) 523 msg.repeated_nested_message.add(bb=2) 524 other_msg = message_module.TestAllTypes() 525 other_msg.repeated_nested_message.add(bb=3) 526 
other_msg.repeated_nested_message.add(bb=4) 527 other_msg.repeated_int32.append(5) 528 other_msg.repeated_int32.append(7) 529 530 msg.repeated_int32.MergeFrom(other_msg.repeated_int32) 531 self.assertEqual(4, len(msg.repeated_int32)) 532 533 msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message) 534 self.assertEqual([1, 2, 3, 4], 535 [m.bb for m in msg.repeated_nested_message]) 536 537 def testAddWrongRepeatedNestedField(self, message_module): 538 msg = message_module.TestAllTypes() 539 try: 540 msg.repeated_nested_message.add('wrong') 541 except TypeError: 542 pass 543 try: 544 msg.repeated_nested_message.add(value_field='wrong') 545 except ValueError: 546 pass 547 self.assertEqual(len(msg.repeated_nested_message), 0) 548 549 def testRepeatedContains(self, message_module): 550 msg = message_module.TestAllTypes() 551 msg.repeated_int32.extend([1, 2, 3]) 552 self.assertIn(2, msg.repeated_int32) 553 self.assertNotIn(0, msg.repeated_int32) 554 555 msg.repeated_nested_message.add(bb=1) 556 sub_msg1 = msg.repeated_nested_message[0] 557 sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2) 558 sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3) 559 msg.repeated_nested_message.append(sub_msg2) 560 msg.repeated_nested_message.insert(0, sub_msg3) 561 self.assertIn(sub_msg1, msg.repeated_nested_message) 562 self.assertIn(sub_msg2, msg.repeated_nested_message) 563 self.assertIn(sub_msg3, msg.repeated_nested_message) 564 565 def testRepeatedScalarIterable(self, message_module): 566 msg = message_module.TestAllTypes() 567 msg.repeated_int32.extend([1, 2, 3]) 568 add = 0 569 for item in msg.repeated_int32: 570 add += item 571 self.assertEqual(add, 6) 572 573 def testRepeatedNestedFieldIteration(self, message_module): 574 msg = message_module.TestAllTypes() 575 msg.repeated_nested_message.add(bb=1) 576 msg.repeated_nested_message.add(bb=2) 577 msg.repeated_nested_message.add(bb=3) 578 msg.repeated_nested_message.add(bb=4) 579 580 self.assertEqual([1, 
2, 3, 4], 581 [m.bb for m in msg.repeated_nested_message]) 582 self.assertEqual([4, 3, 2, 1], 583 [m.bb for m in reversed(msg.repeated_nested_message)]) 584 self.assertEqual([4, 3, 2, 1], 585 [m.bb for m in msg.repeated_nested_message[::-1]]) 586 587 def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): 588 """Check some different types with the default comparator.""" 589 message = message_module.TestAllTypes() 590 591 # TODO(mattp): would testing more scalar types strengthen test? 592 message.repeated_int32.append(1) 593 message.repeated_int32.append(3) 594 message.repeated_int32.append(2) 595 message.repeated_int32.sort() 596 self.assertEqual(message.repeated_int32[0], 1) 597 self.assertEqual(message.repeated_int32[1], 2) 598 self.assertEqual(message.repeated_int32[2], 3) 599 self.assertEqual(str(message.repeated_int32), str([1, 2, 3])) 600 601 message.repeated_float.append(1.1) 602 message.repeated_float.append(1.3) 603 message.repeated_float.append(1.2) 604 message.repeated_float.sort() 605 self.assertAlmostEqual(message.repeated_float[0], 1.1) 606 self.assertAlmostEqual(message.repeated_float[1], 1.2) 607 self.assertAlmostEqual(message.repeated_float[2], 1.3) 608 609 message.repeated_string.append('a') 610 message.repeated_string.append('c') 611 message.repeated_string.append('b') 612 message.repeated_string.sort() 613 self.assertEqual(message.repeated_string[0], 'a') 614 self.assertEqual(message.repeated_string[1], 'b') 615 self.assertEqual(message.repeated_string[2], 'c') 616 self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c'])) 617 618 message.repeated_bytes.append(b'a') 619 message.repeated_bytes.append(b'c') 620 message.repeated_bytes.append(b'b') 621 message.repeated_bytes.sort() 622 self.assertEqual(message.repeated_bytes[0], b'a') 623 self.assertEqual(message.repeated_bytes[1], b'b') 624 self.assertEqual(message.repeated_bytes[2], b'c') 625 self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c'])) 
626 627 def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): 628 """Check some different types with custom comparator.""" 629 message = message_module.TestAllTypes() 630 631 message.repeated_int32.append(-3) 632 message.repeated_int32.append(-2) 633 message.repeated_int32.append(-1) 634 message.repeated_int32.sort(key=abs) 635 self.assertEqual(message.repeated_int32[0], -1) 636 self.assertEqual(message.repeated_int32[1], -2) 637 self.assertEqual(message.repeated_int32[2], -3) 638 639 message.repeated_string.append('aaa') 640 message.repeated_string.append('bb') 641 message.repeated_string.append('c') 642 message.repeated_string.sort(key=len) 643 self.assertEqual(message.repeated_string[0], 'c') 644 self.assertEqual(message.repeated_string[1], 'bb') 645 self.assertEqual(message.repeated_string[2], 'aaa') 646 647 def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): 648 """Check passing a custom comparator to sort a repeated composite field.""" 649 message = message_module.TestAllTypes() 650 651 message.repeated_nested_message.add().bb = 1 652 message.repeated_nested_message.add().bb = 3 653 message.repeated_nested_message.add().bb = 2 654 message.repeated_nested_message.add().bb = 6 655 message.repeated_nested_message.add().bb = 5 656 message.repeated_nested_message.add().bb = 4 657 message.repeated_nested_message.sort(key=operator.attrgetter('bb')) 658 self.assertEqual(message.repeated_nested_message[0].bb, 1) 659 self.assertEqual(message.repeated_nested_message[1].bb, 2) 660 self.assertEqual(message.repeated_nested_message[2].bb, 3) 661 self.assertEqual(message.repeated_nested_message[3].bb, 4) 662 self.assertEqual(message.repeated_nested_message[4].bb, 5) 663 self.assertEqual(message.repeated_nested_message[5].bb, 6) 664 self.assertEqual(str(message.repeated_nested_message), 665 '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]') 666 667 def testSortingRepeatedCompositeFieldsStable(self, message_module): 668 
"""Check passing a custom comparator to sort a repeated composite field.""" 669 message = message_module.TestAllTypes() 670 671 message.repeated_nested_message.add().bb = 21 672 message.repeated_nested_message.add().bb = 20 673 message.repeated_nested_message.add().bb = 13 674 message.repeated_nested_message.add().bb = 33 675 message.repeated_nested_message.add().bb = 11 676 message.repeated_nested_message.add().bb = 24 677 message.repeated_nested_message.add().bb = 10 678 message.repeated_nested_message.sort(key=lambda z: z.bb // 10) 679 self.assertEqual( 680 [13, 11, 10, 21, 20, 24, 33], 681 [n.bb for n in message.repeated_nested_message]) 682 683 # Make sure that for the C++ implementation, the underlying fields 684 # are actually reordered. 685 pb = message.SerializeToString() 686 message.Clear() 687 message.MergeFromString(pb) 688 self.assertEqual( 689 [13, 11, 10, 21, 20, 24, 33], 690 [n.bb for n in message.repeated_nested_message]) 691 692 def testRepeatedCompositeFieldSortArguments(self, message_module): 693 """Check sorting a repeated composite field using list.sort() arguments.""" 694 message = message_module.TestAllTypes() 695 696 get_bb = operator.attrgetter('bb') 697 cmp_bb = lambda a, b: cmp(a.bb, b.bb) 698 message.repeated_nested_message.add().bb = 1 699 message.repeated_nested_message.add().bb = 3 700 message.repeated_nested_message.add().bb = 2 701 message.repeated_nested_message.add().bb = 6 702 message.repeated_nested_message.add().bb = 5 703 message.repeated_nested_message.add().bb = 4 704 message.repeated_nested_message.sort(key=get_bb) 705 self.assertEqual([k.bb for k in message.repeated_nested_message], 706 [1, 2, 3, 4, 5, 6]) 707 message.repeated_nested_message.sort(key=get_bb, reverse=True) 708 self.assertEqual([k.bb for k in message.repeated_nested_message], 709 [6, 5, 4, 3, 2, 1]) 710 if sys.version_info >= (3,): return # No cmp sorting in PY3. 
711 message.repeated_nested_message.sort(sort_function=cmp_bb) 712 self.assertEqual([k.bb for k in message.repeated_nested_message], 713 [1, 2, 3, 4, 5, 6]) 714 message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) 715 self.assertEqual([k.bb for k in message.repeated_nested_message], 716 [6, 5, 4, 3, 2, 1]) 717 718 def testRepeatedScalarFieldSortArguments(self, message_module): 719 """Check sorting a scalar field using list.sort() arguments.""" 720 message = message_module.TestAllTypes() 721 722 message.repeated_int32.append(-3) 723 message.repeated_int32.append(-2) 724 message.repeated_int32.append(-1) 725 message.repeated_int32.sort(key=abs) 726 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 727 message.repeated_int32.sort(key=abs, reverse=True) 728 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 729 if sys.version_info < (3,): # No cmp sorting in PY3. 730 abs_cmp = lambda a, b: cmp(abs(a), abs(b)) 731 message.repeated_int32.sort(sort_function=abs_cmp) 732 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 733 message.repeated_int32.sort(cmp=abs_cmp, reverse=True) 734 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 735 736 message.repeated_string.append('aaa') 737 message.repeated_string.append('bb') 738 message.repeated_string.append('c') 739 message.repeated_string.sort(key=len) 740 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 741 message.repeated_string.sort(key=len, reverse=True) 742 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 743 if sys.version_info < (3,): # No cmp sorting in PY3. 
744 len_cmp = lambda a, b: cmp(len(a), len(b)) 745 message.repeated_string.sort(sort_function=len_cmp) 746 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 747 message.repeated_string.sort(cmp=len_cmp, reverse=True) 748 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 749 750 def testRepeatedFieldsComparable(self, message_module): 751 m1 = message_module.TestAllTypes() 752 m2 = message_module.TestAllTypes() 753 m1.repeated_int32.append(0) 754 m1.repeated_int32.append(1) 755 m1.repeated_int32.append(2) 756 m2.repeated_int32.append(0) 757 m2.repeated_int32.append(1) 758 m2.repeated_int32.append(2) 759 m1.repeated_nested_message.add().bb = 1 760 m1.repeated_nested_message.add().bb = 2 761 m1.repeated_nested_message.add().bb = 3 762 m2.repeated_nested_message.add().bb = 1 763 m2.repeated_nested_message.add().bb = 2 764 m2.repeated_nested_message.add().bb = 3 765 766 if sys.version_info >= (3,): return # No cmp() in PY3. 767 768 # These comparisons should not raise errors. 769 _ = m1 < m2 770 _ = m1.repeated_nested_message < m2.repeated_nested_message 771 772 # Make sure cmp always works. If it wasn't defined, these would be 773 # id() comparisons and would all fail. 774 self.assertEqual(cmp(m1, m2), 0) 775 self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) 776 self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) 777 self.assertEqual(cmp(m1.repeated_nested_message, 778 m2.repeated_nested_message), 0) 779 with self.assertRaises(TypeError): 780 # Can't compare repeated composite containers to lists. 
781 cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) 782 783 # TODO(anuraag): Implement extensiondict comparison in C++ and then add test 784 785 def testRepeatedFieldsAreSequences(self, message_module): 786 m = message_module.TestAllTypes() 787 self.assertIsInstance(m.repeated_int32, collections_abc.MutableSequence) 788 self.assertIsInstance(m.repeated_nested_message, 789 collections_abc.MutableSequence) 790 791 def testRepeatedFieldsNotHashable(self, message_module): 792 m = message_module.TestAllTypes() 793 with self.assertRaises(TypeError): 794 hash(m.repeated_int32) 795 with self.assertRaises(TypeError): 796 hash(m.repeated_nested_message) 797 798 def testRepeatedFieldInsideNestedMessage(self, message_module): 799 m = message_module.NestedTestAllTypes() 800 m.payload.repeated_int32.extend([]) 801 self.assertTrue(m.HasField('payload')) 802 803 def testMergeFrom(self, message_module): 804 m1 = message_module.TestAllTypes() 805 m2 = message_module.TestAllTypes() 806 # Cpp extension will lazily create a sub message which is immutable. 807 nested = m1.optional_nested_message 808 self.assertEqual(0, nested.bb) 809 m2.optional_nested_message.bb = 1 810 # Make sure cmessage pointing to a mutable message after merge instead of 811 # the lazily created message. 812 m1.MergeFrom(m2) 813 self.assertEqual(1, nested.bb) 814 815 # Test more nested sub message. 816 msg1 = message_module.NestedTestAllTypes() 817 msg2 = message_module.NestedTestAllTypes() 818 nested = msg1.child.payload.optional_nested_message 819 self.assertEqual(0, nested.bb) 820 msg2.child.payload.optional_nested_message.bb = 1 821 msg1.MergeFrom(msg2) 822 self.assertEqual(1, nested.bb) 823 824 # Test repeated field. 
825 self.assertEqual(msg1.payload.repeated_nested_message, 826 msg1.payload.repeated_nested_message) 827 nested = msg2.payload.repeated_nested_message.add() 828 nested.bb = 1 829 msg1.MergeFrom(msg2) 830 self.assertEqual(1, len(msg1.payload.repeated_nested_message)) 831 self.assertEqual(1, nested.bb) 832 833 def testMergeFromString(self, message_module): 834 m1 = message_module.TestAllTypes() 835 m2 = message_module.TestAllTypes() 836 # Cpp extension will lazily create a sub message which is immutable. 837 self.assertEqual(0, m1.optional_nested_message.bb) 838 m2.optional_nested_message.bb = 1 839 # Make sure cmessage pointing to a mutable message after merge instead of 840 # the lazily created message. 841 m1.MergeFromString(m2.SerializeToString()) 842 self.assertEqual(1, m1.optional_nested_message.bb) 843 844 @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2') 845 def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module): 846 m2 = message_module.TestAllTypes() 847 m2.optional_string = 'scalar string' 848 m2.repeated_string.append('repeated string') 849 m2.optional_bytes = b'scalar bytes' 850 m2.repeated_bytes.append(b'repeated bytes') 851 852 serialized = m2.SerializeToString() 853 memview = memoryview(serialized) 854 m1 = message_module.TestAllTypes.FromString(memview) 855 856 self.assertEqual(m1.optional_bytes, b'scalar bytes') 857 self.assertEqual(m1.repeated_bytes, [b'repeated bytes']) 858 self.assertEqual(m1.optional_string, 'scalar string') 859 self.assertEqual(m1.repeated_string, ['repeated string']) 860 # Make sure that the memoryview was correctly converted to bytes, and 861 # that a sub-sliced memoryview is not being used. 
862 self.assertIsInstance(m1.optional_bytes, bytes) 863 self.assertIsInstance(m1.repeated_bytes[0], bytes) 864 self.assertIsInstance(m1.optional_string, six.text_type) 865 self.assertIsInstance(m1.repeated_string[0], six.text_type) 866 867 @unittest.skipIf(six.PY3, 'memoryview is supported by py3') 868 def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module): 869 memview = memoryview(b'') 870 with self.assertRaises(TypeError): 871 message_module.TestAllTypes.FromString(memview) 872 873 def testMergeFromEmpty(self, message_module): 874 m1 = message_module.TestAllTypes() 875 # Cpp extension will lazily create a sub message which is immutable. 876 self.assertEqual(0, m1.optional_nested_message.bb) 877 self.assertFalse(m1.HasField('optional_nested_message')) 878 # Make sure the sub message is still immutable after merge from empty. 879 m1.MergeFromString(b'') # field state should not change 880 self.assertFalse(m1.HasField('optional_nested_message')) 881 882 def ensureNestedMessageExists(self, msg, attribute): 883 """Make sure that a nested message object exists. 884 885 As soon as a nested message attribute is accessed, it will be present in the 886 _fields dict, without being marked as actually being set. 887 """ 888 getattr(msg, attribute) 889 self.assertFalse(msg.HasField(attribute)) 890 891 def testOneofGetCaseNonexistingField(self, message_module): 892 m = message_module.TestAllTypes() 893 self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') 894 self.assertRaises(Exception, m.WhichOneof, 0) 895 896 def testOneofDefaultValues(self, message_module): 897 m = message_module.TestAllTypes() 898 self.assertIs(None, m.WhichOneof('oneof_field')) 899 self.assertFalse(m.HasField('oneof_field')) 900 self.assertFalse(m.HasField('oneof_uint32')) 901 902 # Oneof is set even when setting it to a default value. 
903 m.oneof_uint32 = 0 904 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 905 self.assertTrue(m.HasField('oneof_field')) 906 self.assertTrue(m.HasField('oneof_uint32')) 907 self.assertFalse(m.HasField('oneof_string')) 908 909 m.oneof_string = "" 910 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 911 self.assertTrue(m.HasField('oneof_string')) 912 self.assertFalse(m.HasField('oneof_uint32')) 913 914 def testOneofSemantics(self, message_module): 915 m = message_module.TestAllTypes() 916 self.assertIs(None, m.WhichOneof('oneof_field')) 917 918 m.oneof_uint32 = 11 919 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 920 self.assertTrue(m.HasField('oneof_uint32')) 921 922 m.oneof_string = u'foo' 923 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 924 self.assertFalse(m.HasField('oneof_uint32')) 925 self.assertTrue(m.HasField('oneof_string')) 926 927 # Read nested message accessor without accessing submessage. 928 m.oneof_nested_message 929 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 930 self.assertTrue(m.HasField('oneof_string')) 931 self.assertFalse(m.HasField('oneof_nested_message')) 932 933 # Read accessor of nested message without accessing submessage. 
934 m.oneof_nested_message.bb 935 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 936 self.assertTrue(m.HasField('oneof_string')) 937 self.assertFalse(m.HasField('oneof_nested_message')) 938 939 m.oneof_nested_message.bb = 11 940 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 941 self.assertFalse(m.HasField('oneof_string')) 942 self.assertTrue(m.HasField('oneof_nested_message')) 943 944 m.oneof_bytes = b'bb' 945 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 946 self.assertFalse(m.HasField('oneof_nested_message')) 947 self.assertTrue(m.HasField('oneof_bytes')) 948 949 def testOneofCompositeFieldReadAccess(self, message_module): 950 m = message_module.TestAllTypes() 951 m.oneof_uint32 = 11 952 953 self.ensureNestedMessageExists(m, 'oneof_nested_message') 954 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 955 self.assertEqual(11, m.oneof_uint32) 956 957 def testOneofWhichOneof(self, message_module): 958 m = message_module.TestAllTypes() 959 self.assertIs(None, m.WhichOneof('oneof_field')) 960 if message_module is unittest_pb2: 961 self.assertFalse(m.HasField('oneof_field')) 962 963 m.oneof_uint32 = 11 964 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 965 if message_module is unittest_pb2: 966 self.assertTrue(m.HasField('oneof_field')) 967 968 m.oneof_bytes = b'bb' 969 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 970 971 m.ClearField('oneof_bytes') 972 self.assertIs(None, m.WhichOneof('oneof_field')) 973 if message_module is unittest_pb2: 974 self.assertFalse(m.HasField('oneof_field')) 975 976 def testOneofClearField(self, message_module): 977 m = message_module.TestAllTypes() 978 m.oneof_uint32 = 11 979 m.ClearField('oneof_field') 980 if message_module is unittest_pb2: 981 self.assertFalse(m.HasField('oneof_field')) 982 self.assertFalse(m.HasField('oneof_uint32')) 983 self.assertIs(None, m.WhichOneof('oneof_field')) 984 985 def testOneofClearSetField(self, 
message_module): 986 m = message_module.TestAllTypes() 987 m.oneof_uint32 = 11 988 m.ClearField('oneof_uint32') 989 if message_module is unittest_pb2: 990 self.assertFalse(m.HasField('oneof_field')) 991 self.assertFalse(m.HasField('oneof_uint32')) 992 self.assertIs(None, m.WhichOneof('oneof_field')) 993 994 def testOneofClearUnsetField(self, message_module): 995 m = message_module.TestAllTypes() 996 m.oneof_uint32 = 11 997 self.ensureNestedMessageExists(m, 'oneof_nested_message') 998 m.ClearField('oneof_nested_message') 999 self.assertEqual(11, m.oneof_uint32) 1000 if message_module is unittest_pb2: 1001 self.assertTrue(m.HasField('oneof_field')) 1002 self.assertTrue(m.HasField('oneof_uint32')) 1003 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 1004 1005 def testOneofDeserialize(self, message_module): 1006 m = message_module.TestAllTypes() 1007 m.oneof_uint32 = 11 1008 m2 = message_module.TestAllTypes() 1009 m2.ParseFromString(m.SerializeToString()) 1010 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 1011 1012 def testOneofCopyFrom(self, message_module): 1013 m = message_module.TestAllTypes() 1014 m.oneof_uint32 = 11 1015 m2 = message_module.TestAllTypes() 1016 m2.CopyFrom(m) 1017 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 1018 1019 def testOneofNestedMergeFrom(self, message_module): 1020 m = message_module.NestedTestAllTypes() 1021 m.payload.oneof_uint32 = 11 1022 m2 = message_module.NestedTestAllTypes() 1023 m2.payload.oneof_bytes = b'bb' 1024 m2.child.payload.oneof_bytes = b'bb' 1025 m2.MergeFrom(m) 1026 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) 1027 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) 1028 1029 def testOneofMessageMergeFrom(self, message_module): 1030 m = message_module.NestedTestAllTypes() 1031 m.payload.oneof_nested_message.bb = 11 1032 m.child.payload.oneof_nested_message.bb = 12 1033 m2 = message_module.NestedTestAllTypes() 1034 
m2.payload.oneof_uint32 = 13 1035 m2.MergeFrom(m) 1036 self.assertEqual('oneof_nested_message', 1037 m2.payload.WhichOneof('oneof_field')) 1038 self.assertEqual('oneof_nested_message', 1039 m2.child.payload.WhichOneof('oneof_field')) 1040 1041 def testOneofNestedMessageInit(self, message_module): 1042 m = message_module.TestAllTypes( 1043 oneof_nested_message=message_module.TestAllTypes.NestedMessage()) 1044 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 1045 1046 def testOneofClear(self, message_module): 1047 m = message_module.TestAllTypes() 1048 m.oneof_uint32 = 11 1049 m.Clear() 1050 self.assertIsNone(m.WhichOneof('oneof_field')) 1051 m.oneof_bytes = b'bb' 1052 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 1053 1054 def testAssignByteStringToUnicodeField(self, message_module): 1055 """Assigning a byte string to a string field should result 1056 in the value being converted to a Unicode string.""" 1057 m = message_module.TestAllTypes() 1058 m.optional_string = str('') 1059 self.assertIsInstance(m.optional_string, six.text_type) 1060 1061 def testLongValuedSlice(self, message_module): 1062 """It should be possible to use long-valued indices in slices. 1063 1064 This didn't used to work in the v2 C++ implementation. 
1065 """ 1066 m = message_module.TestAllTypes() 1067 1068 # Repeated scalar 1069 m.repeated_int32.append(1) 1070 sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] 1071 self.assertEqual(len(m.repeated_int32), len(sl)) 1072 1073 # Repeated composite 1074 m.repeated_nested_message.add().bb = 3 1075 sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] 1076 self.assertEqual(len(m.repeated_nested_message), len(sl)) 1077 1078 def testExtendShouldNotSwallowExceptions(self, message_module): 1079 """This didn't use to work in the v2 C++ implementation.""" 1080 m = message_module.TestAllTypes() 1081 with self.assertRaises(NameError) as _: 1082 m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable 1083 with self.assertRaises(NameError) as _: 1084 m.repeated_nested_enum.extend( 1085 a for i in range(10)) # pylint: disable=undefined-variable 1086 1087 FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] 1088 1089 def testExtendInt32WithNothing(self, message_module): 1090 """Test no-ops extending repeated int32 fields.""" 1091 m = message_module.TestAllTypes() 1092 self.assertSequenceEqual([], m.repeated_int32) 1093 1094 # TODO(ptucker): Deprecate this behavior. b/18413862 1095 for falsy_value in MessageTest.FALSY_VALUES: 1096 m.repeated_int32.extend(falsy_value) 1097 self.assertSequenceEqual([], m.repeated_int32) 1098 1099 m.repeated_int32.extend([]) 1100 self.assertSequenceEqual([], m.repeated_int32) 1101 1102 def testExtendFloatWithNothing(self, message_module): 1103 """Test no-ops extending repeated float fields.""" 1104 m = message_module.TestAllTypes() 1105 self.assertSequenceEqual([], m.repeated_float) 1106 1107 # TODO(ptucker): Deprecate this behavior. 
b/18413862 1108 for falsy_value in MessageTest.FALSY_VALUES: 1109 m.repeated_float.extend(falsy_value) 1110 self.assertSequenceEqual([], m.repeated_float) 1111 1112 m.repeated_float.extend([]) 1113 self.assertSequenceEqual([], m.repeated_float) 1114 1115 def testExtendStringWithNothing(self, message_module): 1116 """Test no-ops extending repeated string fields.""" 1117 m = message_module.TestAllTypes() 1118 self.assertSequenceEqual([], m.repeated_string) 1119 1120 # TODO(ptucker): Deprecate this behavior. b/18413862 1121 for falsy_value in MessageTest.FALSY_VALUES: 1122 m.repeated_string.extend(falsy_value) 1123 self.assertSequenceEqual([], m.repeated_string) 1124 1125 m.repeated_string.extend([]) 1126 self.assertSequenceEqual([], m.repeated_string) 1127 1128 def testExtendInt32WithPythonList(self, message_module): 1129 """Test extending repeated int32 fields with python lists.""" 1130 m = message_module.TestAllTypes() 1131 self.assertSequenceEqual([], m.repeated_int32) 1132 m.repeated_int32.extend([0]) 1133 self.assertSequenceEqual([0], m.repeated_int32) 1134 m.repeated_int32.extend([1, 2]) 1135 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1136 m.repeated_int32.extend([3, 4]) 1137 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 1138 1139 def testExtendFloatWithPythonList(self, message_module): 1140 """Test extending repeated float fields with python lists.""" 1141 m = message_module.TestAllTypes() 1142 self.assertSequenceEqual([], m.repeated_float) 1143 m.repeated_float.extend([0.0]) 1144 self.assertSequenceEqual([0.0], m.repeated_float) 1145 m.repeated_float.extend([1.0, 2.0]) 1146 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1147 m.repeated_float.extend([3.0, 4.0]) 1148 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1149 1150 def testExtendStringWithPythonList(self, message_module): 1151 """Test extending repeated string fields with python lists.""" 1152 m = message_module.TestAllTypes() 1153 
self.assertSequenceEqual([], m.repeated_string) 1154 m.repeated_string.extend(['']) 1155 self.assertSequenceEqual([''], m.repeated_string) 1156 m.repeated_string.extend(['11', '22']) 1157 self.assertSequenceEqual(['', '11', '22'], m.repeated_string) 1158 m.repeated_string.extend(['33', '44']) 1159 self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) 1160 1161 def testExtendStringWithString(self, message_module): 1162 """Test extending repeated string fields with characters from a string.""" 1163 m = message_module.TestAllTypes() 1164 self.assertSequenceEqual([], m.repeated_string) 1165 m.repeated_string.extend('abc') 1166 self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) 1167 1168 class TestIterable(object): 1169 """This iterable object mimics the behavior of numpy.array. 1170 1171 __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. 1172 1173 """ 1174 1175 def __init__(self, values=None): 1176 self._list = values or [] 1177 1178 def __nonzero__(self): 1179 size = len(self._list) 1180 if size == 0: 1181 return False 1182 if size == 1: 1183 return bool(self._list[0]) 1184 raise ValueError('Truth value is ambiguous.') 1185 1186 def __len__(self): 1187 return len(self._list) 1188 1189 def __iter__(self): 1190 return self._list.__iter__() 1191 1192 def testExtendInt32WithIterable(self, message_module): 1193 """Test extending repeated int32 fields with iterable.""" 1194 m = message_module.TestAllTypes() 1195 self.assertSequenceEqual([], m.repeated_int32) 1196 m.repeated_int32.extend(MessageTest.TestIterable([])) 1197 self.assertSequenceEqual([], m.repeated_int32) 1198 m.repeated_int32.extend(MessageTest.TestIterable([0])) 1199 self.assertSequenceEqual([0], m.repeated_int32) 1200 m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) 1201 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 1202 m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) 1203 self.assertSequenceEqual([0, 1, 2, 3, 4], 
m.repeated_int32) 1204 1205 def testExtendFloatWithIterable(self, message_module): 1206 """Test extending repeated float fields with iterable.""" 1207 m = message_module.TestAllTypes() 1208 self.assertSequenceEqual([], m.repeated_float) 1209 m.repeated_float.extend(MessageTest.TestIterable([])) 1210 self.assertSequenceEqual([], m.repeated_float) 1211 m.repeated_float.extend(MessageTest.TestIterable([0.0])) 1212 self.assertSequenceEqual([0.0], m.repeated_float) 1213 m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) 1214 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 1215 m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) 1216 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 1217 1218 def testExtendStringWithIterable(self, message_module): 1219 """Test extending repeated string fields with iterable.""" 1220 m = message_module.TestAllTypes() 1221 self.assertSequenceEqual([], m.repeated_string) 1222 m.repeated_string.extend(MessageTest.TestIterable([])) 1223 self.assertSequenceEqual([], m.repeated_string) 1224 m.repeated_string.extend(MessageTest.TestIterable([''])) 1225 self.assertSequenceEqual([''], m.repeated_string) 1226 m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) 1227 self.assertSequenceEqual(['', '1', '2'], m.repeated_string) 1228 m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) 1229 self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) 1230 1231 def testPickleRepeatedScalarContainer(self, message_module): 1232 # TODO(tibell): The pure-Python implementation support pickling of 1233 # scalar containers in *some* cases. For now the cpp2 version 1234 # throws an exception to avoid a segfault. Investigate if we 1235 # want to support pickling of these fields. 
1236 # 1237 # For more information see: https://b2.corp.google.com/u/0/issues/18677897 1238 if (api_implementation.Type() != 'cpp' or 1239 api_implementation.Version() == 2): 1240 return 1241 m = message_module.TestAllTypes() 1242 with self.assertRaises(pickle.PickleError) as _: 1243 pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) 1244 1245 def testSortEmptyRepeatedCompositeContainer(self, message_module): 1246 """Exercise a scenario that has led to segfaults in the past. 1247 """ 1248 m = message_module.TestAllTypes() 1249 m.repeated_nested_message.sort() 1250 1251 def testHasFieldOnRepeatedField(self, message_module): 1252 """Using HasField on a repeated field should raise an exception. 1253 """ 1254 m = message_module.TestAllTypes() 1255 with self.assertRaises(ValueError) as _: 1256 m.HasField('repeated_int32') 1257 1258 def testRepeatedScalarFieldPop(self, message_module): 1259 m = message_module.TestAllTypes() 1260 with self.assertRaises(IndexError) as _: 1261 m.repeated_int32.pop() 1262 m.repeated_int32.extend(range(5)) 1263 self.assertEqual(4, m.repeated_int32.pop()) 1264 self.assertEqual(0, m.repeated_int32.pop(0)) 1265 self.assertEqual(2, m.repeated_int32.pop(1)) 1266 self.assertEqual([1, 3], m.repeated_int32) 1267 1268 def testRepeatedCompositeFieldPop(self, message_module): 1269 m = message_module.TestAllTypes() 1270 with self.assertRaises(IndexError) as _: 1271 m.repeated_nested_message.pop() 1272 with self.assertRaises(TypeError) as _: 1273 m.repeated_nested_message.pop('0') 1274 for i in range(5): 1275 n = m.repeated_nested_message.add() 1276 n.bb = i 1277 self.assertEqual(4, m.repeated_nested_message.pop().bb) 1278 self.assertEqual(0, m.repeated_nested_message.pop(0).bb) 1279 self.assertEqual(2, m.repeated_nested_message.pop(1).bb) 1280 self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) 1281 1282 def testRepeatedCompareWithSelf(self, message_module): 1283 m = message_module.TestAllTypes() 1284 for i in range(5): 1285 
m.repeated_int32.insert(i, i) 1286 n = m.repeated_nested_message.add() 1287 n.bb = i 1288 self.assertSequenceEqual(m.repeated_int32, m.repeated_int32) 1289 self.assertEqual(m.repeated_nested_message, m.repeated_nested_message) 1290 1291 def testReleasedNestedMessages(self, message_module): 1292 """A case that lead to a segfault when a message detached from its parent 1293 container has itself a child container. 1294 """ 1295 m = message_module.NestedTestAllTypes() 1296 m = m.repeated_child.add() 1297 m = m.child 1298 m = m.repeated_child.add() 1299 self.assertEqual(m.payload.optional_int32, 0) 1300 1301 def testSetRepeatedComposite(self, message_module): 1302 m = message_module.TestAllTypes() 1303 with self.assertRaises(AttributeError): 1304 m.repeated_int32 = [] 1305 m.repeated_int32.append(1) 1306 with self.assertRaises(AttributeError): 1307 m.repeated_int32 = [] 1308 1309 def testReturningType(self, message_module): 1310 m = message_module.TestAllTypes() 1311 self.assertEqual(float, type(m.optional_float)) 1312 self.assertEqual(float, type(m.optional_double)) 1313 self.assertEqual(bool, type(m.optional_bool)) 1314 m.optional_float = 1 1315 m.optional_double = 1 1316 m.optional_bool = 1 1317 m.repeated_float.append(1) 1318 m.repeated_double.append(1) 1319 m.repeated_bool.append(1) 1320 m.ParseFromString(m.SerializeToString()) 1321 self.assertEqual(float, type(m.optional_float)) 1322 self.assertEqual(float, type(m.optional_double)) 1323 self.assertEqual('1.0', str(m.optional_double)) 1324 self.assertEqual(bool, type(m.optional_bool)) 1325 self.assertEqual(float, type(m.repeated_float[0])) 1326 self.assertEqual(float, type(m.repeated_double[0])) 1327 self.assertEqual(bool, type(m.repeated_bool[0])) 1328 self.assertEqual(True, m.repeated_bool[0]) 1329 1330 1331# Class to test proto2-only features (required, extensions, etc.) 
@testing_refleaks.TestCase
class Proto2Test(unittest.TestCase):
  """Tests for proto2-only behavior: field presence, enums, extensions, required.

  Unlike the parameterized message tests, every test here uses the proto2
  generated modules (unittest_pb2, more_extensions_pb2,
  map_proto2_unittest_pb2) directly.
  """

  def testFieldPresence(self):
    """HasField reports explicit presence for singular proto2 fields."""
    message = unittest_pb2.TestAllTypes()

    self.assertFalse(message.HasField("optional_int32"))
    self.assertFalse(message.HasField("optional_bool"))
    self.assertFalse(message.HasField("optional_nested_message"))

    with self.assertRaises(ValueError):
      message.HasField("field_doesnt_exist")

    # HasField is rejected for repeated fields.
    with self.assertRaises(ValueError):
      message.HasField("repeated_int32")
    with self.assertRaises(ValueError):
      message.HasField("repeated_nested_message")

    # Unset fields still read back as their defaults.
    self.assertEqual(0, message.optional_int32)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

    # Fields are set even when setting the values to default values.
    message.optional_int32 = 0
    message.optional_bool = False
    message.optional_nested_message.bb = 0
    self.assertTrue(message.HasField("optional_int32"))
    self.assertTrue(message.HasField("optional_bool"))
    self.assertTrue(message.HasField("optional_nested_message"))

    # Set the fields to non-default values.
    message.optional_int32 = 5
    message.optional_bool = True
    message.optional_nested_message.bb = 15

    # Field names are also accepted as unicode strings.
    self.assertTrue(message.HasField(u"optional_int32"))
    self.assertTrue(message.HasField("optional_bool"))
    self.assertTrue(message.HasField("optional_nested_message"))

    # Clearing the fields unsets them and resets their value to default.
    message.ClearField("optional_int32")
    message.ClearField(u"optional_bool")
    message.ClearField("optional_nested_message")

    self.assertFalse(message.HasField("optional_int32"))
    self.assertFalse(message.HasField("optional_bool"))
    self.assertFalse(message.HasField("optional_nested_message"))
    self.assertEqual(0, message.optional_int32)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

  def testAssignInvalidEnum(self):
    """Assigning an invalid enum number is not allowed in proto2."""
    m = unittest_pb2.TestAllTypes()

    # Proto2 can not assign unknown enum.
    with self.assertRaises(ValueError) as _:
      m.optional_nested_enum = 1234567
    self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
    # Assignment is a different code path than append for the C++ impl.
    m.repeated_nested_enum.append(2)
    m.repeated_nested_enum[0] = 2
    with self.assertRaises(ValueError):
      m.repeated_nested_enum[0] = 123456

    # Unknown enum value can be parsed but is ignored.
    # The bytes are produced by the proto3 arena message, which (as shown by
    # the assignments below succeeding) accepts arbitrary enum numbers.
    m2 = unittest_proto3_arena_pb2.TestAllTypes()
    m2.optional_nested_enum = 1234567
    m2.repeated_nested_enum.append(7654321)
    serialized = m2.SerializeToString()

    m3 = unittest_pb2.TestAllTypes()
    m3.ParseFromString(serialized)
    self.assertFalse(m3.HasField('optional_nested_enum'))
    # 1 is the default value for optional_nested_enum.
    self.assertEqual(1, m3.optional_nested_enum)
    self.assertEqual(0, len(m3.repeated_nested_enum))
    # Although hidden from the proto2 accessors, the unknown values must
    # survive re-serialization of m3 and reappear when parsed as proto3.
    m2.Clear()
    m2.ParseFromString(m3.SerializeToString())
    self.assertEqual(1234567, m2.optional_nested_enum)
    self.assertEqual(7654321, m2.repeated_nested_enum[0])

  def testUnknownEnumMap(self):
    """Storing a value outside the enum in a proto2 map field raises.

    NOTE(review): assumes unknown_map_field is enum-valued, as the
    TestEnumMap message name suggests; verify against the .proto file.
    """
    m = map_proto2_unittest_pb2.TestEnumMap()
    m.known_map_field[123] = 0
    with self.assertRaises(ValueError):
      m.unknown_map_field[1] = 123

  def testExtensionsErrors(self):
    """Reading `Extensions` on TestAllTypes raises AttributeError.

    Presumably because TestAllTypes declares no extension ranges; confirm
    against unittest.proto.
    """
    msg = unittest_pb2.TestAllTypes()
    self.assertRaises(AttributeError, getattr, msg, 'Extensions')

  def testMergeFromExtensions(self):
    """MergeFrom carries a submessage's extensions into a lazily read one."""
    msg1 = more_extensions_pb2.TopLevelMessage()
    msg2 = more_extensions_pb2.TopLevelMessage()
    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, msg1.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(msg1.HasField('submessage'))
    msg2.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 123
    # Make sure cmessage and extensions point to a mutable message after the
    # merge, instead of to the lazily created message.
    msg1.MergeFrom(msg2)
    self.assertEqual(123, msg1.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])

  def testGoldenExtensions(self):
    """The golden extensions file parses, compares equal and round-trips."""
    golden_data = test_util.GoldenFileData('golden_message')
    golden_message = unittest_pb2.TestAllExtensions()
    golden_message.ParseFromString(golden_data)
    all_set = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(all_set)
    self.assertEqual(all_set, golden_message)
    self.assertEqual(golden_data, golden_message.SerializeToString())
    # A deep copy must serialize to the same bytes as the original.
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testGoldenPackedExtensions(self):
    """The golden packed-extensions file parses, compares equal and round-trips."""
    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    golden_message = unittest_pb2.TestPackedExtensions()
    golden_message.ParseFromString(golden_data)
    all_set = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(all_set)
    self.assertEqual(all_set, golden_message)
    # Note: this checks the serialization of the hand-built all_set, whereas
    # testGoldenExtensions re-serializes the parsed golden_message.
    self.assertEqual(golden_data, all_set.SerializeToString())
    golden_copy = copy.deepcopy(golden_message)
    self.assertEqual(golden_data, golden_copy.SerializeToString())

  def testPickleIncompleteProto(self):
    """Pickling round-trips a message whose required fields are incomplete."""
    golden_message = unittest_pb2.TestRequired(a=1)
    pickled_message = pickle.dumps(golden_message)

    unpickled_message = pickle.loads(pickled_message)
    self.assertEqual(unpickled_message, golden_message)
    self.assertEqual(unpickled_message.a, 1)
    # This is still an incomplete proto - so serializing should fail.
    # (`message` here is the google.protobuf.message module, not a local.)
    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)

  # TODO(haberman): this isn't really a proto2-specific test except that this
  # message has a required field in it.  Should probably be factored out so
  # that we can test the other parts with proto3.
  def testParsingMerge(self):
    """Check the merge behavior when a required or optional field appears
    multiple times in the input.

    Singular (required/optional/group/extension) fields must merge every
    occurrence, while repeated fields must keep one element per occurrence.
    """
    messages = [
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes(),
        unittest_pb2.TestAllTypes() ]
    messages[0].optional_int32 = 1
    messages[1].optional_int64 = 2
    messages[2].optional_int32 = 3
    messages[2].optional_string = 'hello'

    # Expected result of merging messages[0], [1], [2] in order: the later
    # optional_int32 (3) wins and the disjoint fields are combined.
    merged_message = unittest_pb2.TestAllTypes()
    merged_message.optional_int32 = 3
    merged_message.optional_int64 = 2
    merged_message.optional_string = 'hello'

    # Write all three messages into each generator field, so every field
    # occurs three times in the serialized bytes. NOTE(review): relies on
    # RepeatedFieldsGenerator's field numbers lining up with TestParsingMerge's
    # fields, as the assertions below expect.
    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    generator.field1.extend(messages)
    generator.field2.extend(messages)
    generator.field3.extend(messages)
    generator.ext1.extend(messages)
    generator.ext2.extend(messages)
    generator.group1.add().field1.MergeFrom(messages[0])
    generator.group1.add().field1.MergeFrom(messages[1])
    generator.group1.add().field1.MergeFrom(messages[2])
    generator.group2.add().field1.MergeFrom(messages[0])
    generator.group2.add().field1.MergeFrom(messages[1])
    generator.group2.add().field1.MergeFrom(messages[2])

    data = generator.SerializeToString()
    parsing_merge = unittest_pb2.TestParsingMerge()
    parsing_merge.ParseFromString(data)

    # Required and optional fields should be merged.
    self.assertEqual(parsing_merge.required_all_types, merged_message)
    self.assertEqual(parsing_merge.optional_all_types, merged_message)
    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
                     merged_message)
    self.assertEqual(parsing_merge.Extensions[
                     unittest_pb2.TestParsingMerge.optional_ext],
                     merged_message)

    # Repeated fields should not be merged.
    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
    self.assertEqual(len(parsing_merge.Extensions[
        unittest_pb2.TestParsingMerge.repeated_ext]), 3)

  def testPythonicInit(self):
    """Keyword-argument construction of a proto2 message.

    Covers scalars, dicts for submessages and groups, enum labels or numbers,
    lists for repeated fields, and the errors raised for bad input.
    """
    message = unittest_pb2.TestAllTypes(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_foreign_message={},
        optional_nested_enum='BAZ',
        repeatedgroup=[{'a': 600},
                       {'a': 700}],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
    self.assertEqual(100, message.optional_int32)
    self.assertEqual(200, message.optional_fixed32)
    self.assertEqual(300.5, message.optional_float)
    self.assertEqual(b'x', message.optional_bytes)
    self.assertEqual(400, message.optionalgroup.a)
    self.assertIsInstance(message.optional_nested_message,
                          unittest_pb2.TestAllTypes.NestedMessage)
    self.assertEqual(500, message.optional_nested_message.bb)
    # An empty dict still marks the submessage as present.
    self.assertTrue(message.HasField('optional_foreign_message'))
    self.assertEqual(message.optional_foreign_message,
                     unittest_pb2.ForeignMessage())
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     message.optional_nested_enum)
    self.assertEqual(2, len(message.repeatedgroup))
    self.assertEqual(600, message.repeatedgroup[0].a)
    self.assertEqual(700, message.repeatedgroup[1].a)
    self.assertEqual(2, len(message.repeated_nested_enum))
    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
                     message.repeated_nested_enum[0])
    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
                     message.repeated_nested_enum[1])
    self.assertEqual(800, message.default_int32)
    self.assertEqual('y', message.oneof_string)
    # Fields not passed to the constructor stay unset / at their defaults.
    self.assertFalse(message.HasField('optional_int64'))
    self.assertEqual(0, len(message.repeated_float))
    self.assertEqual(42, message.default_int64)

    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
                     message.optional_nested_enum)

    # Unknown field name inside a submessage dict.
    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'INVALID_NESTED_FIELD': 17})

    # Wrong value type inside a submessage dict.
    with self.assertRaises(TypeError):
      unittest_pb2.TestAllTypes(
          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})

    # Unknown enum label.
    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')

    # A bare label string is not accepted for a repeated enum field.
    with self.assertRaises(ValueError):
      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')

  def testPythonicInitWithDict(self):
    """Constructor kwargs may be supplied through ** unpacking of a dict."""
    # Both string/unicode field name keys should work.
    kwargs = {
        'optional_int32': 100,
        u'optional_fixed32': 200,
    }
    msg = unittest_pb2.TestAllTypes(**kwargs)
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)

  def test_documentation(self):
    """pydoc can render documentation for a generated message class."""
    # Also used by the interactive help() function.
    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
    self.assertIn('class TestAllTypes', doc)
    self.assertIn('SerializePartialToString', doc)
    self.assertIn('repeated_float', doc)
    # The generated class's first base must not expose _extensions_by_name.
    # NOTE(review): presumably that attribute is set only on concrete
    # generated classes; confirm against the python_message implementation.
    base = unittest_pb2.TestAllTypes.__bases__[0]
    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')


# Class to test proto3-only features/behavior (updated field presence & enums)
@testing_refleaks.TestCase
class Proto3Test(unittest.TestCase):

  # Utility method for comparing equality with a map.
  def assertMapIterEquals(self, map_iter, dict_value):
    """Assert map_iter yields exactly the (key, value) pairs in dict_value.

    An unexpected or duplicated key surfaces as a KeyError from the lookup;
    a missing key fails the final empty-dict assertion.
    """
    # Avoid mutating caller's copy.
    dict_value = dict(dict_value)

    for k, v in map_iter:
      self.assertEqual(v, dict_value[k])
      del dict_value[k]

    self.assertEqual({}, dict_value)

  def testFieldPresence(self):
    """Proto3 singular scalars reject HasField; submessages still support it."""
    message = unittest_proto3_arena_pb2.TestAllTypes()

    # We can't test presence of non-repeated, non-submessage fields.
    with self.assertRaises(ValueError):
      message.HasField('optional_int32')
    with self.assertRaises(ValueError):
      message.HasField('optional_float')
    with self.assertRaises(ValueError):
      message.HasField('optional_string')
    with self.assertRaises(ValueError):
      message.HasField('optional_bool')

    # But we can still test presence of submessage fields.
    self.assertFalse(message.HasField('optional_nested_message'))

    # As with proto2, we can't test presence of fields that don't exist, or
    # repeated fields.
    with self.assertRaises(ValueError):
      message.HasField('field_doesnt_exist')

    with self.assertRaises(ValueError):
      message.HasField('repeated_int32')
    with self.assertRaises(ValueError):
      message.HasField('repeated_nested_message')

    # Fields should default to their type-specific default.
    self.assertEqual(0, message.optional_int32)
    self.assertEqual(0, message.optional_float)
    self.assertEqual('', message.optional_string)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

    # Setting a submessage should still return proper presence information.
    message.optional_nested_message.bb = 0
    self.assertTrue(message.HasField('optional_nested_message'))

    # Set the fields to non-default values.
    message.optional_int32 = 5
    message.optional_float = 1.1
    message.optional_string = 'abc'
    message.optional_bool = True
    message.optional_nested_message.bb = 15

    # Clearing the fields unsets them and resets their value to default.
1666 message.ClearField('optional_int32') 1667 message.ClearField('optional_float') 1668 message.ClearField('optional_string') 1669 message.ClearField('optional_bool') 1670 message.ClearField('optional_nested_message') 1671 1672 self.assertEqual(0, message.optional_int32) 1673 self.assertEqual(0, message.optional_float) 1674 self.assertEqual('', message.optional_string) 1675 self.assertEqual(False, message.optional_bool) 1676 self.assertEqual(0, message.optional_nested_message.bb) 1677 1678 def testProto3ParserDropDefaultScalar(self): 1679 message_proto2 = unittest_pb2.TestAllTypes() 1680 message_proto2.optional_int32 = 0 1681 message_proto2.optional_string = '' 1682 message_proto2.optional_bytes = b'' 1683 self.assertEqual(len(message_proto2.ListFields()), 3) 1684 1685 message_proto3 = unittest_proto3_arena_pb2.TestAllTypes() 1686 message_proto3.ParseFromString(message_proto2.SerializeToString()) 1687 self.assertEqual(len(message_proto3.ListFields()), 0) 1688 1689 def testProto3Optional(self): 1690 msg = test_proto3_optional_pb2.TestProto3Optional() 1691 self.assertFalse(msg.HasField('optional_int32')) 1692 self.assertFalse(msg.HasField('optional_float')) 1693 self.assertFalse(msg.HasField('optional_string')) 1694 self.assertFalse(msg.HasField('optional_nested_message')) 1695 self.assertFalse(msg.optional_nested_message.HasField('bb')) 1696 1697 # Set fields. 
1698 msg.optional_int32 = 1 1699 msg.optional_float = 1.0 1700 msg.optional_string = '123' 1701 msg.optional_nested_message.bb = 1 1702 self.assertTrue(msg.HasField('optional_int32')) 1703 self.assertTrue(msg.HasField('optional_float')) 1704 self.assertTrue(msg.HasField('optional_string')) 1705 self.assertTrue(msg.HasField('optional_nested_message')) 1706 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1707 # Set to default value does not clear the fields 1708 msg.optional_int32 = 0 1709 msg.optional_float = 0.0 1710 msg.optional_string = '' 1711 msg.optional_nested_message.bb = 0 1712 self.assertTrue(msg.HasField('optional_int32')) 1713 self.assertTrue(msg.HasField('optional_float')) 1714 self.assertTrue(msg.HasField('optional_string')) 1715 self.assertTrue(msg.HasField('optional_nested_message')) 1716 self.assertTrue(msg.optional_nested_message.HasField('bb')) 1717 1718 # Test serialize 1719 msg2 = test_proto3_optional_pb2.TestProto3Optional() 1720 msg2.ParseFromString(msg.SerializeToString()) 1721 self.assertTrue(msg2.HasField('optional_int32')) 1722 self.assertTrue(msg2.HasField('optional_float')) 1723 self.assertTrue(msg2.HasField('optional_string')) 1724 self.assertTrue(msg2.HasField('optional_nested_message')) 1725 self.assertTrue(msg2.optional_nested_message.HasField('bb')) 1726 1727 self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32') 1728 1729 # Clear these fields. 
1730 msg.ClearField('optional_int32') 1731 msg.ClearField('optional_float') 1732 msg.ClearField('optional_string') 1733 msg.ClearField('optional_nested_message') 1734 self.assertFalse(msg.HasField('optional_int32')) 1735 self.assertFalse(msg.HasField('optional_float')) 1736 self.assertFalse(msg.HasField('optional_string')) 1737 self.assertFalse(msg.HasField('optional_nested_message')) 1738 self.assertFalse(msg.optional_nested_message.HasField('bb')) 1739 1740 self.assertEqual(msg.WhichOneof('_optional_int32'), None) 1741 1742 def testAssignUnknownEnum(self): 1743 """Assigning an unknown enum value is allowed and preserves the value.""" 1744 m = unittest_proto3_arena_pb2.TestAllTypes() 1745 1746 # Proto3 can assign unknown enums. 1747 m.optional_nested_enum = 1234567 1748 self.assertEqual(1234567, m.optional_nested_enum) 1749 m.repeated_nested_enum.append(22334455) 1750 self.assertEqual(22334455, m.repeated_nested_enum[0]) 1751 # Assignment is a different code path than append for the C++ impl. 1752 m.repeated_nested_enum[0] = 7654321 1753 self.assertEqual(7654321, m.repeated_nested_enum[0]) 1754 serialized = m.SerializeToString() 1755 1756 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1757 m2.ParseFromString(serialized) 1758 self.assertEqual(1234567, m2.optional_nested_enum) 1759 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1760 1761 # Map isn't really a proto3-only feature. But there is no proto2 equivalent 1762 # of google/protobuf/map_unittest.proto right now, so it's not easy to 1763 # test both with the same test like we do for the other proto2/proto3 tests. 1764 # (google/protobuf/map_proto2_unittest.proto is very different in the set 1765 # of messages and fields it contains). 1766 def testScalarMapDefaults(self): 1767 msg = map_unittest_pb2.TestMap() 1768 1769 # Scalars start out unset. 
1770 self.assertFalse(-123 in msg.map_int32_int32) 1771 self.assertFalse(-2**33 in msg.map_int64_int64) 1772 self.assertFalse(123 in msg.map_uint32_uint32) 1773 self.assertFalse(2**33 in msg.map_uint64_uint64) 1774 self.assertFalse(123 in msg.map_int32_double) 1775 self.assertFalse(False in msg.map_bool_bool) 1776 self.assertFalse('abc' in msg.map_string_string) 1777 self.assertFalse(111 in msg.map_int32_bytes) 1778 self.assertFalse(888 in msg.map_int32_enum) 1779 1780 # Accessing an unset key returns the default. 1781 self.assertEqual(0, msg.map_int32_int32[-123]) 1782 self.assertEqual(0, msg.map_int64_int64[-2**33]) 1783 self.assertEqual(0, msg.map_uint32_uint32[123]) 1784 self.assertEqual(0, msg.map_uint64_uint64[2**33]) 1785 self.assertEqual(0.0, msg.map_int32_double[123]) 1786 self.assertTrue(isinstance(msg.map_int32_double[123], float)) 1787 self.assertEqual(False, msg.map_bool_bool[False]) 1788 self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) 1789 self.assertEqual('', msg.map_string_string['abc']) 1790 self.assertEqual(b'', msg.map_int32_bytes[111]) 1791 self.assertEqual(0, msg.map_int32_enum[888]) 1792 1793 # It also sets the value in the map 1794 self.assertTrue(-123 in msg.map_int32_int32) 1795 self.assertTrue(-2**33 in msg.map_int64_int64) 1796 self.assertTrue(123 in msg.map_uint32_uint32) 1797 self.assertTrue(2**33 in msg.map_uint64_uint64) 1798 self.assertTrue(123 in msg.map_int32_double) 1799 self.assertTrue(False in msg.map_bool_bool) 1800 self.assertTrue('abc' in msg.map_string_string) 1801 self.assertTrue(111 in msg.map_int32_bytes) 1802 self.assertTrue(888 in msg.map_int32_enum) 1803 1804 self.assertIsInstance(msg.map_string_string['abc'], six.text_type) 1805 1806 # Accessing an unset key still throws TypeError if the type of the key 1807 # is incorrect. 
1808 with self.assertRaises(TypeError): 1809 msg.map_string_string[123] 1810 1811 with self.assertRaises(TypeError): 1812 123 in msg.map_string_string 1813 1814 def testMapGet(self): 1815 # Need to test that get() properly returns the default, even though the dict 1816 # has defaultdict-like semantics. 1817 msg = map_unittest_pb2.TestMap() 1818 1819 self.assertIsNone(msg.map_int32_int32.get(5)) 1820 self.assertEqual(10, msg.map_int32_int32.get(5, 10)) 1821 self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10)) 1822 self.assertIsNone(msg.map_int32_int32.get(5)) 1823 1824 msg.map_int32_int32[5] = 15 1825 self.assertEqual(15, msg.map_int32_int32.get(5)) 1826 self.assertEqual(15, msg.map_int32_int32.get(5)) 1827 with self.assertRaises(TypeError): 1828 msg.map_int32_int32.get('') 1829 1830 self.assertIsNone(msg.map_int32_foreign_message.get(5)) 1831 self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) 1832 self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10)) 1833 1834 submsg = msg.map_int32_foreign_message[5] 1835 self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) 1836 with self.assertRaises(TypeError): 1837 msg.map_int32_foreign_message.get('') 1838 1839 def testScalarMap(self): 1840 msg = map_unittest_pb2.TestMap() 1841 1842 self.assertEqual(0, len(msg.map_int32_int32)) 1843 self.assertFalse(5 in msg.map_int32_int32) 1844 1845 msg.map_int32_int32[-123] = -456 1846 msg.map_int64_int64[-2**33] = -2**34 1847 msg.map_uint32_uint32[123] = 456 1848 msg.map_uint64_uint64[2**33] = 2**34 1849 msg.map_int32_float[2] = 1.2 1850 msg.map_int32_double[1] = 3.3 1851 msg.map_string_string['abc'] = '123' 1852 msg.map_bool_bool[True] = True 1853 msg.map_int32_enum[888] = 2 1854 # Unknown numeric enum is supported in proto3. 1855 msg.map_int32_enum[123] = 456 1856 1857 self.assertEqual([], msg.FindInitializationErrors()) 1858 1859 self.assertEqual(1, len(msg.map_string_string)) 1860 1861 # Bad key. 
1862 with self.assertRaises(TypeError): 1863 msg.map_string_string[123] = '123' 1864 1865 # Verify that trying to assign a bad key doesn't actually add a member to 1866 # the map. 1867 self.assertEqual(1, len(msg.map_string_string)) 1868 1869 # Bad value. 1870 with self.assertRaises(TypeError): 1871 msg.map_string_string['123'] = 123 1872 1873 serialized = msg.SerializeToString() 1874 msg2 = map_unittest_pb2.TestMap() 1875 msg2.ParseFromString(serialized) 1876 1877 # Bad key. 1878 with self.assertRaises(TypeError): 1879 msg2.map_string_string[123] = '123' 1880 1881 # Bad value. 1882 with self.assertRaises(TypeError): 1883 msg2.map_string_string['123'] = 123 1884 1885 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1886 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1887 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1888 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1889 self.assertAlmostEqual(1.2, msg.map_int32_float[2]) 1890 self.assertEqual(3.3, msg.map_int32_double[1]) 1891 self.assertEqual('123', msg2.map_string_string['abc']) 1892 self.assertEqual(True, msg2.map_bool_bool[True]) 1893 self.assertEqual(2, msg2.map_int32_enum[888]) 1894 self.assertEqual(456, msg2.map_int32_enum[123]) 1895 self.assertEqual('{-123: -456}', 1896 str(msg2.map_int32_int32)) 1897 1898 def testMapEntryAlwaysSerialized(self): 1899 msg = map_unittest_pb2.TestMap() 1900 msg.map_int32_int32[0] = 0 1901 msg.map_string_string[''] = '' 1902 self.assertEqual(msg.ByteSize(), 12) 1903 self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00', 1904 msg.SerializeToString()) 1905 1906 def testStringUnicodeConversionInMap(self): 1907 msg = map_unittest_pb2.TestMap() 1908 1909 unicode_obj = u'\u1234' 1910 bytes_obj = unicode_obj.encode('utf8') 1911 1912 msg.map_string_string[bytes_obj] = bytes_obj 1913 1914 (key, value) = list(msg.map_string_string.items())[0] 1915 1916 self.assertEqual(key, unicode_obj) 1917 self.assertEqual(value, unicode_obj) 1918 1919 
self.assertIsInstance(key, six.text_type) 1920 self.assertIsInstance(value, six.text_type) 1921 1922 def testMessageMap(self): 1923 msg = map_unittest_pb2.TestMap() 1924 1925 self.assertEqual(0, len(msg.map_int32_foreign_message)) 1926 self.assertFalse(5 in msg.map_int32_foreign_message) 1927 1928 msg.map_int32_foreign_message[123] 1929 # get_or_create() is an alias for getitem. 1930 msg.map_int32_foreign_message.get_or_create(-456) 1931 1932 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1933 self.assertIn(123, msg.map_int32_foreign_message) 1934 self.assertIn(-456, msg.map_int32_foreign_message) 1935 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1936 1937 # Bad key. 1938 with self.assertRaises(TypeError): 1939 msg.map_int32_foreign_message['123'] 1940 1941 # Can't assign directly to submessage. 1942 with self.assertRaises(ValueError): 1943 msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] 1944 1945 # Verify that trying to assign a bad key doesn't actually add a member to 1946 # the map. 1947 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1948 1949 serialized = msg.SerializeToString() 1950 msg2 = map_unittest_pb2.TestMap() 1951 msg2.ParseFromString(serialized) 1952 1953 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1954 self.assertIn(123, msg2.map_int32_foreign_message) 1955 self.assertIn(-456, msg2.map_int32_foreign_message) 1956 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1957 msg2.map_int32_foreign_message[123].c = 1 1958 # TODO(jieluo): Fix text format for message map. 
1959 self.assertIn(str(msg2.map_int32_foreign_message), 1960 ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) 1961 1962 def testNestedMessageMapItemDelete(self): 1963 msg = map_unittest_pb2.TestMap() 1964 msg.map_int32_all_types[1].optional_nested_message.bb = 1 1965 del msg.map_int32_all_types[1] 1966 msg.map_int32_all_types[2].optional_nested_message.bb = 2 1967 self.assertEqual(1, len(msg.map_int32_all_types)) 1968 msg.map_int32_all_types[1].optional_nested_message.bb = 1 1969 self.assertEqual(2, len(msg.map_int32_all_types)) 1970 1971 serialized = msg.SerializeToString() 1972 msg2 = map_unittest_pb2.TestMap() 1973 msg2.ParseFromString(serialized) 1974 keys = [1, 2] 1975 # The loop triggers PyErr_Occurred() in c extension. 1976 for key in keys: 1977 del msg2.map_int32_all_types[key] 1978 1979 def testMapByteSize(self): 1980 msg = map_unittest_pb2.TestMap() 1981 msg.map_int32_int32[1] = 1 1982 size = msg.ByteSize() 1983 msg.map_int32_int32[1] = 128 1984 self.assertEqual(msg.ByteSize(), size + 1) 1985 1986 msg.map_int32_foreign_message[19].c = 1 1987 size = msg.ByteSize() 1988 msg.map_int32_foreign_message[19].c = 128 1989 self.assertEqual(msg.ByteSize(), size + 1) 1990 1991 def testMergeFrom(self): 1992 msg = map_unittest_pb2.TestMap() 1993 msg.map_int32_int32[12] = 34 1994 msg.map_int32_int32[56] = 78 1995 msg.map_int64_int64[22] = 33 1996 msg.map_int32_foreign_message[111].c = 5 1997 msg.map_int32_foreign_message[222].c = 10 1998 1999 msg2 = map_unittest_pb2.TestMap() 2000 msg2.map_int32_int32[12] = 55 2001 msg2.map_int64_int64[88] = 99 2002 msg2.map_int32_foreign_message[222].c = 15 2003 msg2.map_int32_foreign_message[222].d = 20 2004 old_map_value = msg2.map_int32_foreign_message[222] 2005 2006 msg2.MergeFrom(msg) 2007 # Compare with expected message instead of call 2008 # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not 2009 # sync with repeated field and there is no duplicated keys. 
2010 expected_msg = map_unittest_pb2.TestMap() 2011 expected_msg.CopyFrom(msg) 2012 expected_msg.map_int64_int64[88] = 99 2013 self.assertEqual(msg2, expected_msg) 2014 2015 self.assertEqual(34, msg2.map_int32_int32[12]) 2016 self.assertEqual(78, msg2.map_int32_int32[56]) 2017 self.assertEqual(33, msg2.map_int64_int64[22]) 2018 self.assertEqual(99, msg2.map_int64_int64[88]) 2019 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 2020 self.assertEqual(10, msg2.map_int32_foreign_message[222].c) 2021 self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) 2022 if api_implementation.Type() != 'cpp': 2023 # During the call to MergeFrom(), the C++ implementation will have 2024 # deallocated the underlying message, but this is very difficult to detect 2025 # properly. The line below is likely to cause a segmentation fault. 2026 # With the Python implementation, old_map_value is just 'detached' from 2027 # the main message. Using it will not crash of course, but since it still 2028 # have a reference to the parent message I'm sure we can find interesting 2029 # ways to cause inconsistencies. 2030 self.assertEqual(15, old_map_value.c) 2031 2032 # Verify that there is only one entry per key, even though the MergeFrom 2033 # may have internally created multiple entries for a single key in the 2034 # list representation. 2035 as_dict = {} 2036 for key in msg2.map_int32_foreign_message: 2037 self.assertFalse(key in as_dict) 2038 as_dict[key] = msg2.map_int32_foreign_message[key].c 2039 2040 self.assertEqual({111: 5, 222: 10}, as_dict) 2041 2042 # Special case: test that delete of item really removes the item, even if 2043 # there might have physically been duplicate keys due to the previous merge. 2044 # This is only a special case for the C++ implementation which stores the 2045 # map as an array. 
2046 del msg2.map_int32_int32[12] 2047 self.assertFalse(12 in msg2.map_int32_int32) 2048 2049 del msg2.map_int32_foreign_message[222] 2050 self.assertFalse(222 in msg2.map_int32_foreign_message) 2051 with self.assertRaises(TypeError): 2052 del msg2.map_int32_foreign_message[''] 2053 2054 def testMapMergeFrom(self): 2055 msg = map_unittest_pb2.TestMap() 2056 msg.map_int32_int32[12] = 34 2057 msg.map_int32_int32[56] = 78 2058 msg.map_int64_int64[22] = 33 2059 msg.map_int32_foreign_message[111].c = 5 2060 msg.map_int32_foreign_message[222].c = 10 2061 2062 msg2 = map_unittest_pb2.TestMap() 2063 msg2.map_int32_int32[12] = 55 2064 msg2.map_int64_int64[88] = 99 2065 msg2.map_int32_foreign_message[222].c = 15 2066 msg2.map_int32_foreign_message[222].d = 20 2067 2068 msg2.map_int32_int32.MergeFrom(msg.map_int32_int32) 2069 self.assertEqual(34, msg2.map_int32_int32[12]) 2070 self.assertEqual(78, msg2.map_int32_int32[56]) 2071 2072 msg2.map_int64_int64.MergeFrom(msg.map_int64_int64) 2073 self.assertEqual(33, msg2.map_int64_int64[22]) 2074 self.assertEqual(99, msg2.map_int64_int64[88]) 2075 2076 msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message) 2077 # Compare with expected message instead of call 2078 # msg.map_int32_foreign_message[222] to make sure MergeFrom does not 2079 # sync with repeated field and no duplicated keys. 2080 expected_msg = map_unittest_pb2.TestMap() 2081 expected_msg.CopyFrom(msg) 2082 expected_msg.map_int64_int64[88] = 99 2083 self.assertEqual(msg2, expected_msg) 2084 2085 # Test when cpp extension cache a map. 2086 m1 = map_unittest_pb2.TestMap() 2087 m2 = map_unittest_pb2.TestMap() 2088 self.assertEqual(m1.map_int32_foreign_message, 2089 m1.map_int32_foreign_message) 2090 m2.map_int32_foreign_message[123].c = 10 2091 m1.MergeFrom(m2) 2092 self.assertEqual(10, m2.map_int32_foreign_message[123].c) 2093 2094 # Test merge maps within different message types. 
2095 m1 = map_unittest_pb2.TestMap() 2096 m2 = map_unittest_pb2.TestMessageMap() 2097 m2.map_int32_message[123].optional_int32 = 10 2098 m1.map_int32_all_types.MergeFrom(m2.map_int32_message) 2099 self.assertEqual(10, m1.map_int32_all_types[123].optional_int32) 2100 2101 # Test overwrite message value map 2102 msg = map_unittest_pb2.TestMap() 2103 msg.map_int32_foreign_message[222].c = 123 2104 msg2 = map_unittest_pb2.TestMap() 2105 msg2.map_int32_foreign_message[222].d = 20 2106 msg.MergeFromString(msg2.SerializeToString()) 2107 self.assertEqual(msg.map_int32_foreign_message[222].d, 20) 2108 self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123) 2109 2110 # Merge a dict to map field is not accepted 2111 with self.assertRaises(AttributeError): 2112 m1.map_int32_all_types.MergeFrom( 2113 {1: unittest_proto3_arena_pb2.TestAllTypes()}) 2114 2115 def testMergeFromBadType(self): 2116 msg = map_unittest_pb2.TestMap() 2117 with self.assertRaisesRegexp( 2118 TypeError, 2119 r'Parameter to MergeFrom\(\) must be instance of same class: expected ' 2120 r'.*TestMap got int\.'): 2121 msg.MergeFrom(1) 2122 2123 def testCopyFromBadType(self): 2124 msg = map_unittest_pb2.TestMap() 2125 with self.assertRaisesRegexp( 2126 TypeError, 2127 r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' 2128 r'expected .*TestMap got int\.'): 2129 msg.CopyFrom(1) 2130 2131 def testIntegerMapWithLongs(self): 2132 msg = map_unittest_pb2.TestMap() 2133 msg.map_int32_int32[long(-123)] = long(-456) 2134 msg.map_int64_int64[long(-2**33)] = long(-2**34) 2135 msg.map_uint32_uint32[long(123)] = long(456) 2136 msg.map_uint64_uint64[long(2**33)] = long(2**34) 2137 2138 serialized = msg.SerializeToString() 2139 msg2 = map_unittest_pb2.TestMap() 2140 msg2.ParseFromString(serialized) 2141 2142 self.assertEqual(-456, msg2.map_int32_int32[-123]) 2143 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 2144 self.assertEqual(456, msg2.map_uint32_uint32[123]) 2145 self.assertEqual(2**34, 
msg2.map_uint64_uint64[2**33]) 2146 2147 def testMapAssignmentCausesPresence(self): 2148 msg = map_unittest_pb2.TestMapSubmessage() 2149 msg.test_map.map_int32_int32[123] = 456 2150 2151 serialized = msg.SerializeToString() 2152 msg2 = map_unittest_pb2.TestMapSubmessage() 2153 msg2.ParseFromString(serialized) 2154 2155 self.assertEqual(msg, msg2) 2156 2157 # Now test that various mutations of the map properly invalidate the 2158 # cached size of the submessage. 2159 msg.test_map.map_int32_int32[888] = 999 2160 serialized = msg.SerializeToString() 2161 msg2.ParseFromString(serialized) 2162 self.assertEqual(msg, msg2) 2163 2164 msg.test_map.map_int32_int32.clear() 2165 serialized = msg.SerializeToString() 2166 msg2.ParseFromString(serialized) 2167 self.assertEqual(msg, msg2) 2168 2169 def testMapAssignmentCausesPresenceForSubmessages(self): 2170 msg = map_unittest_pb2.TestMapSubmessage() 2171 msg.test_map.map_int32_foreign_message[123].c = 5 2172 2173 serialized = msg.SerializeToString() 2174 msg2 = map_unittest_pb2.TestMapSubmessage() 2175 msg2.ParseFromString(serialized) 2176 2177 self.assertEqual(msg, msg2) 2178 2179 # Now test that various mutations of the map properly invalidate the 2180 # cached size of the submessage. 
2181 msg.test_map.map_int32_foreign_message[888].c = 7 2182 serialized = msg.SerializeToString() 2183 msg2.ParseFromString(serialized) 2184 self.assertEqual(msg, msg2) 2185 2186 msg.test_map.map_int32_foreign_message[888].MergeFrom( 2187 msg.test_map.map_int32_foreign_message[123]) 2188 serialized = msg.SerializeToString() 2189 msg2.ParseFromString(serialized) 2190 self.assertEqual(msg, msg2) 2191 2192 msg.test_map.map_int32_foreign_message.clear() 2193 serialized = msg.SerializeToString() 2194 msg2.ParseFromString(serialized) 2195 self.assertEqual(msg, msg2) 2196 2197 def testModifyMapWhileIterating(self): 2198 msg = map_unittest_pb2.TestMap() 2199 2200 string_string_iter = iter(msg.map_string_string) 2201 int32_foreign_iter = iter(msg.map_int32_foreign_message) 2202 2203 msg.map_string_string['abc'] = '123' 2204 msg.map_int32_foreign_message[5].c = 5 2205 2206 with self.assertRaises(RuntimeError): 2207 for key in string_string_iter: 2208 pass 2209 2210 with self.assertRaises(RuntimeError): 2211 for key in int32_foreign_iter: 2212 pass 2213 2214 def testSubmessageMap(self): 2215 msg = map_unittest_pb2.TestMap() 2216 2217 submsg = msg.map_int32_foreign_message[111] 2218 self.assertIs(submsg, msg.map_int32_foreign_message[111]) 2219 self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) 2220 2221 submsg.c = 5 2222 2223 serialized = msg.SerializeToString() 2224 msg2 = map_unittest_pb2.TestMap() 2225 msg2.ParseFromString(serialized) 2226 2227 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 2228 2229 # Doesn't allow direct submessage assignment. 2230 with self.assertRaises(ValueError): 2231 msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() 2232 2233 def testMapIteration(self): 2234 msg = map_unittest_pb2.TestMap() 2235 2236 for k, v in msg.map_int32_int32.items(): 2237 # Should not be reached. 
2238 self.assertTrue(False) 2239 2240 msg.map_int32_int32[2] = 4 2241 msg.map_int32_int32[3] = 6 2242 msg.map_int32_int32[4] = 8 2243 self.assertEqual(3, len(msg.map_int32_int32)) 2244 2245 matching_dict = {2: 4, 3: 6, 4: 8} 2246 self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) 2247 2248 def testPython2Map(self): 2249 if sys.version_info < (3,): 2250 msg = map_unittest_pb2.TestMap() 2251 msg.map_int32_int32[2] = 4 2252 msg.map_int32_int32[3] = 6 2253 msg.map_int32_int32[4] = 8 2254 msg.map_int32_int32[5] = 10 2255 map_int32 = msg.map_int32_int32 2256 self.assertEqual(4, len(map_int32)) 2257 msg2 = map_unittest_pb2.TestMap() 2258 msg2.ParseFromString(msg.SerializeToString()) 2259 2260 def CheckItems(seq, iterator): 2261 self.assertEqual(next(iterator), seq[0]) 2262 self.assertEqual(list(iterator), seq[1:]) 2263 2264 CheckItems(map_int32.items(), map_int32.iteritems()) 2265 CheckItems(map_int32.keys(), map_int32.iterkeys()) 2266 CheckItems(map_int32.values(), map_int32.itervalues()) 2267 2268 self.assertEqual(6, map_int32.get(3)) 2269 self.assertEqual(None, map_int32.get(999)) 2270 self.assertEqual(6, map_int32.pop(3)) 2271 self.assertEqual(0, map_int32.pop(3)) 2272 self.assertEqual(3, len(map_int32)) 2273 key, value = map_int32.popitem() 2274 self.assertEqual(2 * key, value) 2275 self.assertEqual(2, len(map_int32)) 2276 map_int32.clear() 2277 self.assertEqual(0, len(map_int32)) 2278 2279 with self.assertRaises(KeyError): 2280 map_int32.popitem() 2281 2282 self.assertEqual(0, map_int32.setdefault(2)) 2283 self.assertEqual(1, len(map_int32)) 2284 2285 map_int32.update(msg2.map_int32_int32) 2286 self.assertEqual(4, len(map_int32)) 2287 2288 with self.assertRaises(TypeError): 2289 map_int32.update(msg2.map_int32_int32, 2290 msg2.map_int32_int32) 2291 with self.assertRaises(TypeError): 2292 map_int32.update(0) 2293 with self.assertRaises(TypeError): 2294 map_int32.update(value=12) 2295 2296 def testMapItems(self): 2297 # Map items used to have strange 
behaviors when use c extension. Because 2298 # [] may reorder the map and invalidate any exsting iterators. 2299 # TODO(jieluo): Check if [] reordering the map is a bug or intended 2300 # behavior. 2301 msg = map_unittest_pb2.TestMap() 2302 msg.map_string_string['local_init_op'] = '' 2303 msg.map_string_string['trainable_variables'] = '' 2304 msg.map_string_string['variables'] = '' 2305 msg.map_string_string['init_op'] = '' 2306 msg.map_string_string['summaries'] = '' 2307 items1 = msg.map_string_string.items() 2308 items2 = msg.map_string_string.items() 2309 self.assertEqual(items1, items2) 2310 2311 def testMapDeterministicSerialization(self): 2312 golden_data = (b'r\x0c\n\x07init_op\x12\x01d' 2313 b'r\n\n\x05item1\x12\x01e' 2314 b'r\n\n\x05item2\x12\x01f' 2315 b'r\n\n\x05item3\x12\x01g' 2316 b'r\x0b\n\x05item4\x12\x02QQ' 2317 b'r\x12\n\rlocal_init_op\x12\x01a' 2318 b'r\x0e\n\tsummaries\x12\x01e' 2319 b'r\x18\n\x13trainable_variables\x12\x01b' 2320 b'r\x0e\n\tvariables\x12\x01c') 2321 msg = map_unittest_pb2.TestMap() 2322 msg.map_string_string['local_init_op'] = 'a' 2323 msg.map_string_string['trainable_variables'] = 'b' 2324 msg.map_string_string['variables'] = 'c' 2325 msg.map_string_string['init_op'] = 'd' 2326 msg.map_string_string['summaries'] = 'e' 2327 msg.map_string_string['item1'] = 'e' 2328 msg.map_string_string['item2'] = 'f' 2329 msg.map_string_string['item3'] = 'g' 2330 msg.map_string_string['item4'] = 'QQ' 2331 2332 # If deterministic serialization is not working correctly, this will be 2333 # "flaky" depending on the exact python dict hash seed. 2334 # 2335 # Fortunately, there are enough items in this map that it is extremely 2336 # unlikely to ever hit the "right" in-order combination, so the test 2337 # itself should fail reliably. 2338 self.assertEqual(golden_data, msg.SerializeToString(deterministic=True)) 2339 2340 def testMapIterationClearMessage(self): 2341 # Iterator needs to work even if message and map are deleted. 
2342 msg = map_unittest_pb2.TestMap() 2343 2344 msg.map_int32_int32[2] = 4 2345 msg.map_int32_int32[3] = 6 2346 msg.map_int32_int32[4] = 8 2347 2348 it = msg.map_int32_int32.items() 2349 del msg 2350 2351 matching_dict = {2: 4, 3: 6, 4: 8} 2352 self.assertMapIterEquals(it, matching_dict) 2353 2354 def testMapConstruction(self): 2355 msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) 2356 self.assertEqual(2, msg.map_int32_int32[1]) 2357 self.assertEqual(4, msg.map_int32_int32[3]) 2358 2359 msg = map_unittest_pb2.TestMap( 2360 map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) 2361 self.assertEqual(5, msg.map_int32_foreign_message[3].c) 2362 2363 def testMapScalarFieldConstruction(self): 2364 msg1 = map_unittest_pb2.TestMap() 2365 msg1.map_int32_int32[1] = 42 2366 msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32) 2367 self.assertEqual(42, msg2.map_int32_int32[1]) 2368 2369 def testMapMessageFieldConstruction(self): 2370 msg1 = map_unittest_pb2.TestMap() 2371 msg1.map_string_foreign_message['test'].c = 42 2372 msg2 = map_unittest_pb2.TestMap( 2373 map_string_foreign_message=msg1.map_string_foreign_message) 2374 self.assertEqual(42, msg2.map_string_foreign_message['test'].c) 2375 2376 def testMapFieldRaisesCorrectError(self): 2377 # Should raise a TypeError when given a non-iterable. 2378 with self.assertRaises(TypeError): 2379 map_unittest_pb2.TestMap(map_string_foreign_message=1) 2380 2381 def testMapValidAfterFieldCleared(self): 2382 # Map needs to work even if field is cleared. 
2383 # For the C++ implementation this tests the correctness of 2384 # MapContainer::Release() 2385 msg = map_unittest_pb2.TestMap() 2386 int32_map = msg.map_int32_int32 2387 2388 int32_map[2] = 4 2389 int32_map[3] = 6 2390 int32_map[4] = 8 2391 2392 msg.ClearField('map_int32_int32') 2393 self.assertEqual(b'', msg.SerializeToString()) 2394 matching_dict = {2: 4, 3: 6, 4: 8} 2395 self.assertMapIterEquals(int32_map.items(), matching_dict) 2396 2397 def testMessageMapValidAfterFieldCleared(self): 2398 # Map needs to work even if field is cleared. 2399 # For the C++ implementation this tests the correctness of 2400 # MapContainer::Release() 2401 msg = map_unittest_pb2.TestMap() 2402 int32_foreign_message = msg.map_int32_foreign_message 2403 2404 int32_foreign_message[2].c = 5 2405 2406 msg.ClearField('map_int32_foreign_message') 2407 self.assertEqual(b'', msg.SerializeToString()) 2408 self.assertTrue(2 in int32_foreign_message.keys()) 2409 2410 def testMessageMapItemValidAfterTopMessageCleared(self): 2411 # Message map item needs to work even if it is cleared. 2412 # For the C++ implementation this tests the correctness of 2413 # MapContainer::Release() 2414 msg = map_unittest_pb2.TestMap() 2415 msg.map_int32_all_types[2].optional_string = 'bar' 2416 2417 if api_implementation.Type() == 'cpp': 2418 # Need to keep the map reference because of b/27942626. 2419 # TODO(jieluo): Remove it. 2420 unused_map = msg.map_int32_all_types # pylint: disable=unused-variable 2421 msg_value = msg.map_int32_all_types[2] 2422 msg.Clear() 2423 2424 # Reset to trigger sync between repeated field and map in c++. 2425 msg.map_int32_all_types[3].optional_string = 'foo' 2426 self.assertEqual(msg_value.optional_string, 'bar') 2427 2428 def testMapIterInvalidatedByClearField(self): 2429 # Map iterator is invalidated when field is cleared. 2430 # But this case does need to not crash the interpreter. 
2431 # For the C++ implementation this tests the correctness of 2432 # ScalarMapContainer::Release() 2433 msg = map_unittest_pb2.TestMap() 2434 2435 it = iter(msg.map_int32_int32) 2436 2437 msg.ClearField('map_int32_int32') 2438 with self.assertRaises(RuntimeError): 2439 for _ in it: 2440 pass 2441 2442 it = iter(msg.map_int32_foreign_message) 2443 msg.ClearField('map_int32_foreign_message') 2444 with self.assertRaises(RuntimeError): 2445 for _ in it: 2446 pass 2447 2448 def testMapDelete(self): 2449 msg = map_unittest_pb2.TestMap() 2450 2451 self.assertEqual(0, len(msg.map_int32_int32)) 2452 2453 msg.map_int32_int32[4] = 6 2454 self.assertEqual(1, len(msg.map_int32_int32)) 2455 2456 with self.assertRaises(KeyError): 2457 del msg.map_int32_int32[88] 2458 2459 del msg.map_int32_int32[4] 2460 self.assertEqual(0, len(msg.map_int32_int32)) 2461 2462 with self.assertRaises(KeyError): 2463 del msg.map_int32_all_types[32] 2464 2465 def testMapsAreMapping(self): 2466 msg = map_unittest_pb2.TestMap() 2467 self.assertIsInstance(msg.map_int32_int32, collections_abc.Mapping) 2468 self.assertIsInstance(msg.map_int32_int32, collections_abc.MutableMapping) 2469 self.assertIsInstance(msg.map_int32_foreign_message, collections_abc.Mapping) 2470 self.assertIsInstance(msg.map_int32_foreign_message, 2471 collections_abc.MutableMapping) 2472 2473 def testMapsCompare(self): 2474 msg = map_unittest_pb2.TestMap() 2475 msg.map_int32_int32[-123] = -456 2476 self.assertEqual(msg.map_int32_int32, msg.map_int32_int32) 2477 self.assertEqual(msg.map_int32_foreign_message, 2478 msg.map_int32_foreign_message) 2479 self.assertNotEqual(msg.map_int32_int32, 0) 2480 2481 def testMapFindInitializationErrorsSmokeTest(self): 2482 msg = map_unittest_pb2.TestMap() 2483 msg.map_string_string['abc'] = '123' 2484 msg.map_int32_int32[35] = 64 2485 msg.map_string_foreign_message['foo'].c = 5 2486 self.assertEqual(0, len(msg.FindInitializationErrors())) 2487 2488 @unittest.skipIf(sys.maxunicode == 
UCS2_MAXUNICODE, 'Skip for ucs2') 2489 def testStrictUtf8Check(self): 2490 # Test u'\ud801' is rejected at parser in both python2 and python3. 2491 serialized = (b'r\x03\xed\xa0\x81') 2492 msg = unittest_proto3_arena_pb2.TestAllTypes() 2493 with self.assertRaises(Exception) as context: 2494 msg.MergeFromString(serialized) 2495 if api_implementation.Type() == 'python': 2496 self.assertIn('optional_string', str(context.exception)) 2497 else: 2498 self.assertIn('Error parsing message', str(context.exception)) 2499 2500 # Test optional_string=u'' is accepted. 2501 serialized = unittest_proto3_arena_pb2.TestAllTypes( 2502 optional_string=u'').SerializeToString() 2503 msg2 = unittest_proto3_arena_pb2.TestAllTypes() 2504 msg2.MergeFromString(serialized) 2505 self.assertEqual(msg2.optional_string, u'') 2506 2507 msg = unittest_proto3_arena_pb2.TestAllTypes( 2508 optional_string=u'\ud001') 2509 self.assertEqual(msg.optional_string, u'\ud001') 2510 2511 @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2') 2512 def testSurrogatesInPython3(self): 2513 # Surrogates like U+D83D is an invalid unicode character, it is 2514 # supported by Python2 only because in some builds, unicode strings 2515 # use 2-bytes code units. Since Python 3.3, we don't have this problem. 2516 # 2517 # Surrogates are utf16 code units, in a unicode string they are invalid 2518 # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf 2519 # Python3 reject such cases at setters and parsers. Python2 accpect it 2520 # to keep same features with the language itself. 'Unpaired pairs' 2521 # like u'\ud801' are rejected at parsers when strict utf8 check is enabled 2522 # in proto3 to keep same behavior with c extension. 2523 2524 # Surrogates are rejected at setters in Python3. 
2525 with self.assertRaises(ValueError): 2526 unittest_proto3_arena_pb2.TestAllTypes( 2527 optional_string=u'\ud801\udc01') 2528 with self.assertRaises(ValueError): 2529 unittest_proto3_arena_pb2.TestAllTypes( 2530 optional_string=b'\xed\xa0\x81') 2531 with self.assertRaises(ValueError): 2532 unittest_proto3_arena_pb2.TestAllTypes( 2533 optional_string=u'\ud801') 2534 with self.assertRaises(ValueError): 2535 unittest_proto3_arena_pb2.TestAllTypes( 2536 optional_string=u'\ud801\ud801') 2537 2538 @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE, 2539 'Surrogates are rejected at setters in Python3') 2540 def testSurrogatesInPython2(self): 2541 # Test optional_string=u'\ud801\udc01'. 2542 # surrogate pair is acceptable in python2. 2543 msg = unittest_proto3_arena_pb2.TestAllTypes( 2544 optional_string=u'\ud801\udc01') 2545 # TODO(jieluo): Change pure python to have same behavior with c extension. 2546 # Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are 2547 # equal, some are not equal. 2548 if api_implementation.Type() == 'python': 2549 self.assertEqual(msg.optional_string, u'\ud801\udc01') 2550 else: 2551 self.assertEqual(msg.optional_string, u'\U00010401') 2552 serialized = msg.SerializeToString() 2553 msg2 = unittest_proto3_arena_pb2.TestAllTypes() 2554 msg2.MergeFromString(serialized) 2555 self.assertEqual(msg2.optional_string, u'\U00010401') 2556 2557 # Python2 does not reject surrogates at setters. 2558 msg = unittest_proto3_arena_pb2.TestAllTypes( 2559 optional_string=b'\xed\xa0\x81') 2560 unittest_proto3_arena_pb2.TestAllTypes( 2561 optional_string=u'\ud801') 2562 unittest_proto3_arena_pb2.TestAllTypes( 2563 optional_string=u'\ud801\ud801') 2564 2565 2566@testing_refleaks.TestCase 2567class ValidTypeNamesTest(unittest.TestCase): 2568 2569 def assertImportFromName(self, msg, base_name): 2570 # Parse <type 'module.class_name'> to extra 'some.name' as a string. 
2571 tp_name = str(type(msg)).split("'")[1] 2572 valid_names = ('Repeated%sContainer' % base_name, 2573 'Repeated%sFieldContainer' % base_name) 2574 self.assertTrue(any(tp_name.endswith(v) for v in valid_names), 2575 '%r does end with any of %r' % (tp_name, valid_names)) 2576 2577 parts = tp_name.split('.') 2578 class_name = parts[-1] 2579 module_name = '.'.join(parts[:-1]) 2580 __import__(module_name, fromlist=[class_name]) 2581 2582 def testTypeNamesCanBeImported(self): 2583 # If import doesn't work, pickling won't work either. 2584 pb = unittest_pb2.TestAllTypes() 2585 self.assertImportFromName(pb.repeated_int32, 'Scalar') 2586 self.assertImportFromName(pb.repeated_nested_message, 'Composite') 2587 2588@testing_refleaks.TestCase 2589class PackedFieldTest(unittest.TestCase): 2590 2591 def setMessage(self, message): 2592 message.repeated_int32.append(1) 2593 message.repeated_int64.append(1) 2594 message.repeated_uint32.append(1) 2595 message.repeated_uint64.append(1) 2596 message.repeated_sint32.append(1) 2597 message.repeated_sint64.append(1) 2598 message.repeated_fixed32.append(1) 2599 message.repeated_fixed64.append(1) 2600 message.repeated_sfixed32.append(1) 2601 message.repeated_sfixed64.append(1) 2602 message.repeated_float.append(1.0) 2603 message.repeated_double.append(1.0) 2604 message.repeated_bool.append(True) 2605 message.repeated_nested_enum.append(1) 2606 2607 def testPackedFields(self): 2608 message = packed_field_test_pb2.TestPackedTypes() 2609 self.setMessage(message) 2610 golden_data = (b'\x0A\x01\x01' 2611 b'\x12\x01\x01' 2612 b'\x1A\x01\x01' 2613 b'\x22\x01\x01' 2614 b'\x2A\x01\x02' 2615 b'\x32\x01\x02' 2616 b'\x3A\x04\x01\x00\x00\x00' 2617 b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2618 b'\x4A\x04\x01\x00\x00\x00' 2619 b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' 2620 b'\x5A\x04\x00\x00\x80\x3f' 2621 b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' 2622 b'\x6A\x01\x01' 2623 b'\x72\x01\x01') 2624 self.assertEqual(golden_data, 
message.SerializeToString()) 2625 2626 def testUnpackedFields(self): 2627 message = packed_field_test_pb2.TestUnpackedTypes() 2628 self.setMessage(message) 2629 golden_data = (b'\x08\x01' 2630 b'\x10\x01' 2631 b'\x18\x01' 2632 b'\x20\x01' 2633 b'\x28\x02' 2634 b'\x30\x02' 2635 b'\x3D\x01\x00\x00\x00' 2636 b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' 2637 b'\x4D\x01\x00\x00\x00' 2638 b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' 2639 b'\x5D\x00\x00\x80\x3f' 2640 b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' 2641 b'\x68\x01' 2642 b'\x70\x01') 2643 self.assertEqual(golden_data, message.SerializeToString()) 2644 2645 2646@unittest.skipIf(api_implementation.Type() != 'cpp' or 2647 sys.version_info < (2, 7), 2648 'explicit tests of the C++ implementation for PY27 and above') 2649@testing_refleaks.TestCase 2650class OversizeProtosTest(unittest.TestCase): 2651 2652 @classmethod 2653 def setUpClass(cls): 2654 # At the moment, reference cycles between DescriptorPool and Message classes 2655 # are not detected and these objects are never freed. 2656 # To avoid errors with ReferenceLeakChecker, we create the class only once. 
2657 file_desc = """ 2658 name: "f/f.msg2" 2659 package: "f" 2660 message_type { 2661 name: "msg1" 2662 field { 2663 name: "payload" 2664 number: 1 2665 label: LABEL_OPTIONAL 2666 type: TYPE_STRING 2667 } 2668 } 2669 message_type { 2670 name: "msg2" 2671 field { 2672 name: "field" 2673 number: 1 2674 label: LABEL_OPTIONAL 2675 type: TYPE_MESSAGE 2676 type_name: "msg1" 2677 } 2678 } 2679 """ 2680 pool = descriptor_pool.DescriptorPool() 2681 desc = descriptor_pb2.FileDescriptorProto() 2682 text_format.Parse(file_desc, desc) 2683 pool.Add(desc) 2684 cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype( 2685 pool.FindMessageTypeByName('f.msg2')) 2686 2687 def setUp(self): 2688 self.p = self.proto_cls() 2689 self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) 2690 self.p_serialized = self.p.SerializeToString() 2691 2692 def testAssertOversizeProto(self): 2693 from google.protobuf.pyext._message import SetAllowOversizeProtos 2694 SetAllowOversizeProtos(False) 2695 q = self.proto_cls() 2696 try: 2697 q.ParseFromString(self.p_serialized) 2698 except message.DecodeError as e: 2699 self.assertEqual(str(e), 'Error parsing message') 2700 2701 def testSucceedOversizeProto(self): 2702 from google.protobuf.pyext._message import SetAllowOversizeProtos 2703 SetAllowOversizeProtos(True) 2704 q = self.proto_cls() 2705 q.ParseFromString(self.p_serialized) 2706 self.assertEqual(self.p.field.payload, q.field.payload) 2707 2708if __name__ == '__main__': 2709 unittest.main() 2710