1#! /usr/bin/env python 2# 3# Protocol Buffers - Google's data interchange format 4# Copyright 2008 Google Inc. All rights reserved. 5# https://developers.google.com/protocol-buffers/ 6# 7# Redistribution and use in source and binary forms, with or without 8# modification, are permitted provided that the following conditions are 9# met: 10# 11# * Redistributions of source code must retain the above copyright 12# notice, this list of conditions and the following disclaimer. 13# * Redistributions in binary form must reproduce the above 14# copyright notice, this list of conditions and the following disclaimer 15# in the documentation and/or other materials provided with the 16# distribution. 17# * Neither the name of Google Inc. nor the names of its 18# contributors may be used to endorse or promote products derived from 19# this software without specific prior written permission. 20# 21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 33"""Tests python protocol buffers against the golden message. 34 35Note that the golden messages exercise every known field type, thus this 36test ends up exercising and verifying nearly all of the parsing and 37serialization code in the whole library. 38 39TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of 40sense to call this a test of the "message" module, which only declares an 41abstract interface. 42""" 43 44__author__ = 'gps@google.com (Gregory P. Smith)' 45 46 47import collections 48import copy 49import math 50import operator 51import pickle 52import six 53import sys 54 55try: 56 import unittest2 as unittest #PY26 57except ImportError: 58 import unittest 59 60from google.protobuf import map_unittest_pb2 61from google.protobuf import unittest_pb2 62from google.protobuf import unittest_proto3_arena_pb2 63from google.protobuf import descriptor_pb2 64from google.protobuf import descriptor_pool 65from google.protobuf import message_factory 66from google.protobuf import text_format 67from google.protobuf.internal import api_implementation 68from google.protobuf.internal import packed_field_test_pb2 69from google.protobuf.internal import test_util 70from google.protobuf import message 71from google.protobuf.internal import _parameterized 72 73if six.PY3: 74 long = int 75 76 77# Python pre-2.6 does not have isinf() or isnan() functions, so we have 78# to provide our own. 79def isnan(val): 80 # NaN is never equal to itself. 81 return val != val 82def isinf(val): 83 # Infinity times zero equals NaN. 84 return not isnan(val) and isnan(val * 0) 85def IsPosInf(val): 86 return isinf(val) and (val > 0) 87def IsNegInf(val): 88 return isinf(val) and (val < 0) 89 90 91@_parameterized.Parameters( 92 (unittest_pb2), 93 (unittest_proto3_arena_pb2)) 94class MessageTest(unittest.TestCase): 95 96 def testBadUtf8String(self, message_module): 97 if api_implementation.Type() != 'python': 98 self.skipTest("Skipping testBadUtf8String, currently only the python " 99 "api implementation raises UnicodeDecodeError when a " 100 "string field contains bad utf-8.") 101 bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') 102 with self.assertRaises(UnicodeDecodeError) as context: 103 message_module.TestAllTypes.FromString(bad_utf8_data) 104 self.assertIn('TestAllTypes.optional_string', str(context.exception)) 105 106 def testGoldenMessage(self, message_module): 107 # Proto3 doesn't have the "default_foo" members or foreign enums, 108 # and doesn't preserve unknown fields, so for proto3 we use a golden 109 # message that doesn't have these fields set. 110 if message_module is unittest_pb2: 111 golden_data = test_util.GoldenFileData( 112 'golden_message_oneof_implemented') 113 else: 114 golden_data = test_util.GoldenFileData('golden_message_proto3') 115 116 golden_message = message_module.TestAllTypes() 117 golden_message.ParseFromString(golden_data) 118 if message_module is unittest_pb2: 119 test_util.ExpectAllFieldsSet(self, golden_message) 120 self.assertEqual(golden_data, golden_message.SerializeToString()) 121 golden_copy = copy.deepcopy(golden_message) 122 self.assertEqual(golden_data, golden_copy.SerializeToString()) 123 124 def testGoldenPackedMessage(self, message_module): 125 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 126 golden_message = message_module.TestPackedTypes() 127 golden_message.ParseFromString(golden_data) 128 all_set = message_module.TestPackedTypes() 129 test_util.SetAllPackedFields(all_set) 130 self.assertEqual(all_set, golden_message) 131 self.assertEqual(golden_data, all_set.SerializeToString()) 132 golden_copy = copy.deepcopy(golden_message) 133 self.assertEqual(golden_data, golden_copy.SerializeToString()) 134 135 def testPickleSupport(self, message_module): 136 golden_data = test_util.GoldenFileData('golden_message') 137 golden_message = message_module.TestAllTypes() 138 golden_message.ParseFromString(golden_data) 139 pickled_message = pickle.dumps(golden_message) 140 141 unpickled_message = pickle.loads(pickled_message) 142 self.assertEqual(unpickled_message, golden_message) 143 144 def testPositiveInfinity(self, message_module): 145 if message_module is unittest_pb2: 146 golden_data = (b'\x5D\x00\x00\x80\x7F' 147 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 148 b'\xCD\x02\x00\x00\x80\x7F' 149 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') 150 else: 151 golden_data = (b'\x5D\x00\x00\x80\x7F' 152 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 153 b'\xCA\x02\x04\x00\x00\x80\x7F' 154 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 155 156 golden_message = message_module.TestAllTypes() 157 golden_message.ParseFromString(golden_data) 158 self.assertTrue(IsPosInf(golden_message.optional_float)) 159 self.assertTrue(IsPosInf(golden_message.optional_double)) 160 self.assertTrue(IsPosInf(golden_message.repeated_float[0])) 161 self.assertTrue(IsPosInf(golden_message.repeated_double[0])) 162 self.assertEqual(golden_data, golden_message.SerializeToString()) 163 164 def testNegativeInfinity(self, message_module): 165 if message_module is unittest_pb2: 166 golden_data = (b'\x5D\x00\x00\x80\xFF' 167 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 168 b'\xCD\x02\x00\x00\x80\xFF' 169 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') 170 else: 171 golden_data = (b'\x5D\x00\x00\x80\xFF' 172 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 173 b'\xCA\x02\x04\x00\x00\x80\xFF' 174 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 175 176 golden_message = message_module.TestAllTypes() 177 golden_message.ParseFromString(golden_data) 178 self.assertTrue(IsNegInf(golden_message.optional_float)) 179 self.assertTrue(IsNegInf(golden_message.optional_double)) 180 self.assertTrue(IsNegInf(golden_message.repeated_float[0])) 181 self.assertTrue(IsNegInf(golden_message.repeated_double[0])) 182 self.assertEqual(golden_data, golden_message.SerializeToString()) 183 184 def testNotANumber(self, message_module): 185 golden_data = (b'\x5D\x00\x00\xC0\x7F' 186 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' 187 b'\xCD\x02\x00\x00\xC0\x7F' 188 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') 189 golden_message = message_module.TestAllTypes() 190 golden_message.ParseFromString(golden_data) 191 self.assertTrue(isnan(golden_message.optional_float)) 192 self.assertTrue(isnan(golden_message.optional_double)) 193 self.assertTrue(isnan(golden_message.repeated_float[0])) 194 self.assertTrue(isnan(golden_message.repeated_double[0])) 195 196 # The protocol buffer may serialize to any one of multiple different 197 # representations of a NaN. Rather than verify a specific representation, 198 # verify the serialized string can be converted into a correctly 199 # behaving protocol buffer. 200 serialized = golden_message.SerializeToString() 201 message = message_module.TestAllTypes() 202 message.ParseFromString(serialized) 203 self.assertTrue(isnan(message.optional_float)) 204 self.assertTrue(isnan(message.optional_double)) 205 self.assertTrue(isnan(message.repeated_float[0])) 206 self.assertTrue(isnan(message.repeated_double[0])) 207 208 def testPositiveInfinityPacked(self, message_module): 209 golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' 210 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 211 golden_message = message_module.TestPackedTypes() 212 golden_message.ParseFromString(golden_data) 213 self.assertTrue(IsPosInf(golden_message.packed_float[0])) 214 self.assertTrue(IsPosInf(golden_message.packed_double[0])) 215 self.assertEqual(golden_data, golden_message.SerializeToString()) 216 217 def testNegativeInfinityPacked(self, message_module): 218 golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' 219 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 220 golden_message = message_module.TestPackedTypes() 221 golden_message.ParseFromString(golden_data) 222 self.assertTrue(IsNegInf(golden_message.packed_float[0])) 223 self.assertTrue(IsNegInf(golden_message.packed_double[0])) 224 self.assertEqual(golden_data, golden_message.SerializeToString()) 225 226 def testNotANumberPacked(self, message_module): 227 golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' 228 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') 229 golden_message = message_module.TestPackedTypes() 230 golden_message.ParseFromString(golden_data) 231 self.assertTrue(isnan(golden_message.packed_float[0])) 232 self.assertTrue(isnan(golden_message.packed_double[0])) 233 234 serialized = golden_message.SerializeToString() 235 message = message_module.TestPackedTypes() 236 message.ParseFromString(serialized) 237 self.assertTrue(isnan(message.packed_float[0])) 238 self.assertTrue(isnan(message.packed_double[0])) 239 240 def testExtremeFloatValues(self, message_module): 241 message = message_module.TestAllTypes() 242 243 # Most positive exponent, no significand bits set. 244 kMostPosExponentNoSigBits = math.pow(2, 127) 245 message.optional_float = kMostPosExponentNoSigBits 246 message.ParseFromString(message.SerializeToString()) 247 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) 248 249 # Most positive exponent, one significand bit set. 250 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) 251 message.optional_float = kMostPosExponentOneSigBit 252 message.ParseFromString(message.SerializeToString()) 253 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) 254 255 # Repeat last two cases with values of same magnitude, but negative. 256 message.optional_float = -kMostPosExponentNoSigBits 257 message.ParseFromString(message.SerializeToString()) 258 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) 259 260 message.optional_float = -kMostPosExponentOneSigBit 261 message.ParseFromString(message.SerializeToString()) 262 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) 263 264 # Most negative exponent, no significand bits set. 265 kMostNegExponentNoSigBits = math.pow(2, -127) 266 message.optional_float = kMostNegExponentNoSigBits 267 message.ParseFromString(message.SerializeToString()) 268 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) 269 270 # Most negative exponent, one significand bit set. 271 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) 272 message.optional_float = kMostNegExponentOneSigBit 273 message.ParseFromString(message.SerializeToString()) 274 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) 275 276 # Repeat last two cases with values of the same magnitude, but negative. 277 message.optional_float = -kMostNegExponentNoSigBits 278 message.ParseFromString(message.SerializeToString()) 279 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) 280 281 message.optional_float = -kMostNegExponentOneSigBit 282 message.ParseFromString(message.SerializeToString()) 283 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) 284 285 def testExtremeDoubleValues(self, message_module): 286 message = message_module.TestAllTypes() 287 288 # Most positive exponent, no significand bits set. 289 kMostPosExponentNoSigBits = math.pow(2, 1023) 290 message.optional_double = kMostPosExponentNoSigBits 291 message.ParseFromString(message.SerializeToString()) 292 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) 293 294 # Most positive exponent, one significand bit set. 295 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) 296 message.optional_double = kMostPosExponentOneSigBit 297 message.ParseFromString(message.SerializeToString()) 298 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) 299 300 # Repeat last two cases with values of same magnitude, but negative. 301 message.optional_double = -kMostPosExponentNoSigBits 302 message.ParseFromString(message.SerializeToString()) 303 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) 304 305 message.optional_double = -kMostPosExponentOneSigBit 306 message.ParseFromString(message.SerializeToString()) 307 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) 308 309 # Most negative exponent, no significand bits set. 310 kMostNegExponentNoSigBits = math.pow(2, -1023) 311 message.optional_double = kMostNegExponentNoSigBits 312 message.ParseFromString(message.SerializeToString()) 313 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) 314 315 # Most negative exponent, one significand bit set. 316 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) 317 message.optional_double = kMostNegExponentOneSigBit 318 message.ParseFromString(message.SerializeToString()) 319 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) 320 321 # Repeat last two cases with values of the same magnitude, but negative. 322 message.optional_double = -kMostNegExponentNoSigBits 323 message.ParseFromString(message.SerializeToString()) 324 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) 325 326 message.optional_double = -kMostNegExponentOneSigBit 327 message.ParseFromString(message.SerializeToString()) 328 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) 329 330 def testFloatPrinting(self, message_module): 331 message = message_module.TestAllTypes() 332 message.optional_float = 2.0 333 self.assertEqual(str(message), 'optional_float: 2.0\n') 334 335 def testHighPrecisionFloatPrinting(self, message_module): 336 message = message_module.TestAllTypes() 337 message.optional_double = 0.12345678912345678 338 if sys.version_info >= (3,): 339 self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') 340 else: 341 self.assertEqual(str(message), 'optional_double: 0.123456789123\n') 342 343 def testUnknownFieldPrinting(self, message_module): 344 populated = message_module.TestAllTypes() 345 test_util.SetAllNonLazyFields(populated) 346 empty = message_module.TestEmptyMessage() 347 empty.ParseFromString(populated.SerializeToString()) 348 self.assertEqual(str(empty), '') 349 350 def testRepeatedNestedFieldIteration(self, message_module): 351 msg = message_module.TestAllTypes() 352 msg.repeated_nested_message.add(bb=1) 353 msg.repeated_nested_message.add(bb=2) 354 msg.repeated_nested_message.add(bb=3) 355 msg.repeated_nested_message.add(bb=4) 356 357 self.assertEqual([1, 2, 3, 4], 358 [m.bb for m in msg.repeated_nested_message]) 359 self.assertEqual([4, 3, 2, 1], 360 [m.bb for m in reversed(msg.repeated_nested_message)]) 361 self.assertEqual([4, 3, 2, 1], 362 [m.bb for m in msg.repeated_nested_message[::-1]]) 363 364 def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): 365 """Check some different types with the default comparator.""" 366 message = message_module.TestAllTypes() 367 368 # TODO(mattp): would testing more scalar types strengthen test? 369 message.repeated_int32.append(1) 370 message.repeated_int32.append(3) 371 message.repeated_int32.append(2) 372 message.repeated_int32.sort() 373 self.assertEqual(message.repeated_int32[0], 1) 374 self.assertEqual(message.repeated_int32[1], 2) 375 self.assertEqual(message.repeated_int32[2], 3) 376 377 message.repeated_float.append(1.1) 378 message.repeated_float.append(1.3) 379 message.repeated_float.append(1.2) 380 message.repeated_float.sort() 381 self.assertAlmostEqual(message.repeated_float[0], 1.1) 382 self.assertAlmostEqual(message.repeated_float[1], 1.2) 383 self.assertAlmostEqual(message.repeated_float[2], 1.3) 384 385 message.repeated_string.append('a') 386 message.repeated_string.append('c') 387 message.repeated_string.append('b') 388 message.repeated_string.sort() 389 self.assertEqual(message.repeated_string[0], 'a') 390 self.assertEqual(message.repeated_string[1], 'b') 391 self.assertEqual(message.repeated_string[2], 'c') 392 393 message.repeated_bytes.append(b'a') 394 message.repeated_bytes.append(b'c') 395 message.repeated_bytes.append(b'b') 396 message.repeated_bytes.sort() 397 self.assertEqual(message.repeated_bytes[0], b'a') 398 self.assertEqual(message.repeated_bytes[1], b'b') 399 self.assertEqual(message.repeated_bytes[2], b'c') 400 401 def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): 402 """Check some different types with custom comparator.""" 403 message = message_module.TestAllTypes() 404 405 message.repeated_int32.append(-3) 406 message.repeated_int32.append(-2) 407 message.repeated_int32.append(-1) 408 message.repeated_int32.sort(key=abs) 409 self.assertEqual(message.repeated_int32[0], -1) 410 self.assertEqual(message.repeated_int32[1], -2) 411 self.assertEqual(message.repeated_int32[2], -3) 412 413 message.repeated_string.append('aaa') 414 message.repeated_string.append('bb') 415 message.repeated_string.append('c') 416 message.repeated_string.sort(key=len) 417 self.assertEqual(message.repeated_string[0], 'c') 418 self.assertEqual(message.repeated_string[1], 'bb') 419 self.assertEqual(message.repeated_string[2], 'aaa') 420 421 def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): 422 """Check passing a custom comparator to sort a repeated composite field.""" 423 message = message_module.TestAllTypes() 424 425 message.repeated_nested_message.add().bb = 1 426 message.repeated_nested_message.add().bb = 3 427 message.repeated_nested_message.add().bb = 2 428 message.repeated_nested_message.add().bb = 6 429 message.repeated_nested_message.add().bb = 5 430 message.repeated_nested_message.add().bb = 4 431 message.repeated_nested_message.sort(key=operator.attrgetter('bb')) 432 self.assertEqual(message.repeated_nested_message[0].bb, 1) 433 self.assertEqual(message.repeated_nested_message[1].bb, 2) 434 self.assertEqual(message.repeated_nested_message[2].bb, 3) 435 self.assertEqual(message.repeated_nested_message[3].bb, 4) 436 self.assertEqual(message.repeated_nested_message[4].bb, 5) 437 self.assertEqual(message.repeated_nested_message[5].bb, 6) 438 439 def testSortingRepeatedCompositeFieldsStable(self, message_module): 440 """Check passing a custom comparator to sort a repeated composite field.""" 441 message = message_module.TestAllTypes() 442 443 message.repeated_nested_message.add().bb = 21 444 message.repeated_nested_message.add().bb = 20 445 message.repeated_nested_message.add().bb = 13 446 message.repeated_nested_message.add().bb = 33 447 message.repeated_nested_message.add().bb = 11 448 message.repeated_nested_message.add().bb = 24 449 message.repeated_nested_message.add().bb = 10 450 message.repeated_nested_message.sort(key=lambda z: z.bb // 10) 451 self.assertEqual( 452 [13, 11, 10, 21, 20, 24, 33], 453 [n.bb for n in message.repeated_nested_message]) 454 455 # Make sure that for the C++ implementation, the underlying fields 456 # are actually reordered. 457 pb = message.SerializeToString() 458 message.Clear() 459 message.MergeFromString(pb) 460 self.assertEqual( 461 [13, 11, 10, 21, 20, 24, 33], 462 [n.bb for n in message.repeated_nested_message]) 463 464 def testRepeatedCompositeFieldSortArguments(self, message_module): 465 """Check sorting a repeated composite field using list.sort() arguments.""" 466 message = message_module.TestAllTypes() 467 468 get_bb = operator.attrgetter('bb') 469 cmp_bb = lambda a, b: cmp(a.bb, b.bb) 470 message.repeated_nested_message.add().bb = 1 471 message.repeated_nested_message.add().bb = 3 472 message.repeated_nested_message.add().bb = 2 473 message.repeated_nested_message.add().bb = 6 474 message.repeated_nested_message.add().bb = 5 475 message.repeated_nested_message.add().bb = 4 476 message.repeated_nested_message.sort(key=get_bb) 477 self.assertEqual([k.bb for k in message.repeated_nested_message], 478 [1, 2, 3, 4, 5, 6]) 479 message.repeated_nested_message.sort(key=get_bb, reverse=True) 480 self.assertEqual([k.bb for k in message.repeated_nested_message], 481 [6, 5, 4, 3, 2, 1]) 482 if sys.version_info >= (3,): return # No cmp sorting in PY3. 483 message.repeated_nested_message.sort(sort_function=cmp_bb) 484 self.assertEqual([k.bb for k in message.repeated_nested_message], 485 [1, 2, 3, 4, 5, 6]) 486 message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) 487 self.assertEqual([k.bb for k in message.repeated_nested_message], 488 [6, 5, 4, 3, 2, 1]) 489 490 def testRepeatedScalarFieldSortArguments(self, message_module): 491 """Check sorting a scalar field using list.sort() arguments.""" 492 message = message_module.TestAllTypes() 493 494 message.repeated_int32.append(-3) 495 message.repeated_int32.append(-2) 496 message.repeated_int32.append(-1) 497 message.repeated_int32.sort(key=abs) 498 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 499 message.repeated_int32.sort(key=abs, reverse=True) 500 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 501 if sys.version_info < (3,): # No cmp sorting in PY3. 502 abs_cmp = lambda a, b: cmp(abs(a), abs(b)) 503 message.repeated_int32.sort(sort_function=abs_cmp) 504 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 505 message.repeated_int32.sort(cmp=abs_cmp, reverse=True) 506 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 507 508 message.repeated_string.append('aaa') 509 message.repeated_string.append('bb') 510 message.repeated_string.append('c') 511 message.repeated_string.sort(key=len) 512 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 513 message.repeated_string.sort(key=len, reverse=True) 514 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 515 if sys.version_info < (3,): # No cmp sorting in PY3. 516 len_cmp = lambda a, b: cmp(len(a), len(b)) 517 message.repeated_string.sort(sort_function=len_cmp) 518 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 519 message.repeated_string.sort(cmp=len_cmp, reverse=True) 520 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 521 522 def testRepeatedFieldsComparable(self, message_module): 523 m1 = message_module.TestAllTypes() 524 m2 = message_module.TestAllTypes() 525 m1.repeated_int32.append(0) 526 m1.repeated_int32.append(1) 527 m1.repeated_int32.append(2) 528 m2.repeated_int32.append(0) 529 m2.repeated_int32.append(1) 530 m2.repeated_int32.append(2) 531 m1.repeated_nested_message.add().bb = 1 532 m1.repeated_nested_message.add().bb = 2 533 m1.repeated_nested_message.add().bb = 3 534 m2.repeated_nested_message.add().bb = 1 535 m2.repeated_nested_message.add().bb = 2 536 m2.repeated_nested_message.add().bb = 3 537 538 if sys.version_info >= (3,): return # No cmp() in PY3. 539 540 # These comparisons should not raise errors. 541 _ = m1 < m2 542 _ = m1.repeated_nested_message < m2.repeated_nested_message 543 544 # Make sure cmp always works. If it wasn't defined, these would be 545 # id() comparisons and would all fail. 546 self.assertEqual(cmp(m1, m2), 0) 547 self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) 548 self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) 549 self.assertEqual(cmp(m1.repeated_nested_message, 550 m2.repeated_nested_message), 0) 551 with self.assertRaises(TypeError): 552 # Can't compare repeated composite containers to lists. 553 cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) 554 555 # TODO(anuraag): Implement extensiondict comparison in C++ and then add test 556 557 def testRepeatedFieldsAreSequences(self, message_module): 558 m = message_module.TestAllTypes() 559 self.assertIsInstance(m.repeated_int32, collections.MutableSequence) 560 self.assertIsInstance(m.repeated_nested_message, 561 collections.MutableSequence) 562 563 def ensureNestedMessageExists(self, msg, attribute): 564 """Make sure that a nested message object exists. 565 566 As soon as a nested message attribute is accessed, it will be present in the 567 _fields dict, without being marked as actually being set. 568 """ 569 getattr(msg, attribute) 570 self.assertFalse(msg.HasField(attribute)) 571 572 def testOneofGetCaseNonexistingField(self, message_module): 573 m = message_module.TestAllTypes() 574 self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') 575 576 def testOneofDefaultValues(self, message_module): 577 m = message_module.TestAllTypes() 578 self.assertIs(None, m.WhichOneof('oneof_field')) 579 self.assertFalse(m.HasField('oneof_uint32')) 580 581 # Oneof is set even when setting it to a default value. 582 m.oneof_uint32 = 0 583 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 584 self.assertTrue(m.HasField('oneof_uint32')) 585 self.assertFalse(m.HasField('oneof_string')) 586 587 m.oneof_string = "" 588 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 589 self.assertTrue(m.HasField('oneof_string')) 590 self.assertFalse(m.HasField('oneof_uint32')) 591 592 def testOneofSemantics(self, message_module): 593 m = message_module.TestAllTypes() 594 self.assertIs(None, m.WhichOneof('oneof_field')) 595 596 m.oneof_uint32 = 11 597 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 598 self.assertTrue(m.HasField('oneof_uint32')) 599 600 m.oneof_string = u'foo' 601 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 602 self.assertFalse(m.HasField('oneof_uint32')) 603 self.assertTrue(m.HasField('oneof_string')) 604 605 # Read nested message accessor without accessing submessage. 606 m.oneof_nested_message 607 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 608 self.assertTrue(m.HasField('oneof_string')) 609 self.assertFalse(m.HasField('oneof_nested_message')) 610 611 # Read accessor of nested message without accessing submessage. 612 m.oneof_nested_message.bb 613 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 614 self.assertTrue(m.HasField('oneof_string')) 615 self.assertFalse(m.HasField('oneof_nested_message')) 616 617 m.oneof_nested_message.bb = 11 618 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 619 self.assertFalse(m.HasField('oneof_string')) 620 self.assertTrue(m.HasField('oneof_nested_message')) 621 622 m.oneof_bytes = b'bb' 623 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 624 self.assertFalse(m.HasField('oneof_nested_message')) 625 self.assertTrue(m.HasField('oneof_bytes')) 626 627 def testOneofCompositeFieldReadAccess(self, message_module): 628 m = message_module.TestAllTypes() 629 m.oneof_uint32 = 11 630 631 self.ensureNestedMessageExists(m, 'oneof_nested_message') 632 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 633 self.assertEqual(11, m.oneof_uint32) 634 635 def testOneofWhichOneof(self, message_module): 636 m = message_module.TestAllTypes() 637 self.assertIs(None, m.WhichOneof('oneof_field')) 638 if message_module is unittest_pb2: 639 self.assertFalse(m.HasField('oneof_field')) 640 641 m.oneof_uint32 = 11 642 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 643 if message_module is unittest_pb2: 644 self.assertTrue(m.HasField('oneof_field')) 645 646 m.oneof_bytes = b'bb' 647 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 648 649 m.ClearField('oneof_bytes') 650 self.assertIs(None, m.WhichOneof('oneof_field')) 651 if message_module is unittest_pb2: 652 self.assertFalse(m.HasField('oneof_field')) 653 654 def testOneofClearField(self, message_module): 655 m = message_module.TestAllTypes() 656 m.oneof_uint32 = 11 657 m.ClearField('oneof_field') 658 if message_module is unittest_pb2: 659 self.assertFalse(m.HasField('oneof_field')) 660 self.assertFalse(m.HasField('oneof_uint32')) 661 self.assertIs(None, m.WhichOneof('oneof_field')) 662 663 def testOneofClearSetField(self, message_module): 664 m = message_module.TestAllTypes() 665 m.oneof_uint32 = 11 666 m.ClearField('oneof_uint32') 667 if message_module is unittest_pb2: 668 self.assertFalse(m.HasField('oneof_field')) 669 self.assertFalse(m.HasField('oneof_uint32')) 670 self.assertIs(None, m.WhichOneof('oneof_field')) 671 672 def testOneofClearUnsetField(self, message_module): 673 m = message_module.TestAllTypes() 674 m.oneof_uint32 = 11 675 self.ensureNestedMessageExists(m, 'oneof_nested_message') 676 m.ClearField('oneof_nested_message') 677 self.assertEqual(11, m.oneof_uint32) 678 if message_module is unittest_pb2: 679 self.assertTrue(m.HasField('oneof_field')) 680 self.assertTrue(m.HasField('oneof_uint32')) 681 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 682 683 def testOneofDeserialize(self, message_module): 684 m = message_module.TestAllTypes() 685 m.oneof_uint32 = 11 686 m2 = message_module.TestAllTypes() 687 m2.ParseFromString(m.SerializeToString()) 688 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 689 690 def testOneofCopyFrom(self, message_module): 691 m = message_module.TestAllTypes() 692 m.oneof_uint32 = 11 693 m2 = message_module.TestAllTypes() 694 m2.CopyFrom(m) 695 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 696 697 def testOneofNestedMergeFrom(self, message_module): 698 m = message_module.NestedTestAllTypes() 699 m.payload.oneof_uint32 = 11 700 m2 = message_module.NestedTestAllTypes() 701 m2.payload.oneof_bytes = b'bb' 702 m2.child.payload.oneof_bytes = b'bb' 703 m2.MergeFrom(m) 704 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) 705 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) 706 707 def testOneofMessageMergeFrom(self, message_module): 708 m = message_module.NestedTestAllTypes() 709 m.payload.oneof_nested_message.bb = 11 710 m.child.payload.oneof_nested_message.bb = 12 711 m2 = message_module.NestedTestAllTypes() 712 m2.payload.oneof_uint32 = 13 713 m2.MergeFrom(m) 714 self.assertEqual('oneof_nested_message', 715 m2.payload.WhichOneof('oneof_field')) 716 self.assertEqual('oneof_nested_message', 717 m2.child.payload.WhichOneof('oneof_field')) 718 719 def testOneofNestedMessageInit(self, message_module): 720 m = message_module.TestAllTypes( 721 oneof_nested_message=message_module.TestAllTypes.NestedMessage()) 722 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 723 724 def testOneofClear(self, message_module): 725 m = message_module.TestAllTypes() 726 m.oneof_uint32 = 11 727 m.Clear() 728 self.assertIsNone(m.WhichOneof('oneof_field')) 729 m.oneof_bytes = b'bb' 730 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 731 732 def testAssignByteStringToUnicodeField(self, message_module): 733 """Assigning a byte string to a string field should result 734 in the value being converted to a Unicode string.""" 735 m = message_module.TestAllTypes() 736 m.optional_string = str('') 737 self.assertIsInstance(m.optional_string, six.text_type) 738 739 def testLongValuedSlice(self, message_module): 740 """It should be possible to use long-valued indicies in slices 741 742 This didn't used to work in the v2 C++ implementation. 743 """ 744 m = message_module.TestAllTypes() 745 746 # Repeated scalar 747 m.repeated_int32.append(1) 748 sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] 749 self.assertEqual(len(m.repeated_int32), len(sl)) 750 751 # Repeated composite 752 m.repeated_nested_message.add().bb = 3 753 sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] 754 self.assertEqual(len(m.repeated_nested_message), len(sl)) 755 756 def testExtendShouldNotSwallowExceptions(self, message_module): 757 """This didn't use to work in the v2 C++ implementation.""" 758 m = message_module.TestAllTypes() 759 with self.assertRaises(NameError) as _: 760 m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable 761 with self.assertRaises(NameError) as _: 762 m.repeated_nested_enum.extend( 763 a for i in range(10)) # pylint: disable=undefined-variable 764 765 FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] 766 767 def testExtendInt32WithNothing(self, message_module): 768 """Test no-ops extending repeated int32 fields.""" 769 m = message_module.TestAllTypes() 770 self.assertSequenceEqual([], m.repeated_int32) 771 772 # TODO(ptucker): Deprecate this behavior. b/18413862 773 for falsy_value in MessageTest.FALSY_VALUES: 774 m.repeated_int32.extend(falsy_value) 775 self.assertSequenceEqual([], m.repeated_int32) 776 777 m.repeated_int32.extend([]) 778 self.assertSequenceEqual([], m.repeated_int32) 779 780 def testExtendFloatWithNothing(self, message_module): 781 """Test no-ops extending repeated float fields.""" 782 m = message_module.TestAllTypes() 783 self.assertSequenceEqual([], m.repeated_float) 784 785 # TODO(ptucker): Deprecate this behavior. b/18413862 786 for falsy_value in MessageTest.FALSY_VALUES: 787 m.repeated_float.extend(falsy_value) 788 self.assertSequenceEqual([], m.repeated_float) 789 790 m.repeated_float.extend([]) 791 self.assertSequenceEqual([], m.repeated_float) 792 793 def testExtendStringWithNothing(self, message_module): 794 """Test no-ops extending repeated string fields.""" 795 m = message_module.TestAllTypes() 796 self.assertSequenceEqual([], m.repeated_string) 797 798 # TODO(ptucker): Deprecate this behavior. b/18413862 799 for falsy_value in MessageTest.FALSY_VALUES: 800 m.repeated_string.extend(falsy_value) 801 self.assertSequenceEqual([], m.repeated_string) 802 803 m.repeated_string.extend([]) 804 self.assertSequenceEqual([], m.repeated_string) 805 806 def testExtendInt32WithPythonList(self, message_module): 807 """Test extending repeated int32 fields with python lists.""" 808 m = message_module.TestAllTypes() 809 self.assertSequenceEqual([], m.repeated_int32) 810 m.repeated_int32.extend([0]) 811 self.assertSequenceEqual([0], m.repeated_int32) 812 m.repeated_int32.extend([1, 2]) 813 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 814 m.repeated_int32.extend([3, 4]) 815 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 816 817 def testExtendFloatWithPythonList(self, message_module): 818 """Test extending repeated float fields with python lists.""" 819 m = message_module.TestAllTypes() 820 self.assertSequenceEqual([], m.repeated_float) 821 m.repeated_float.extend([0.0]) 822 self.assertSequenceEqual([0.0], m.repeated_float) 823 m.repeated_float.extend([1.0, 2.0]) 824 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 825 m.repeated_float.extend([3.0, 4.0]) 826 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 827 828 def testExtendStringWithPythonList(self, message_module): 829 """Test extending repeated string fields with python lists.""" 830 m = message_module.TestAllTypes() 831 self.assertSequenceEqual([], m.repeated_string) 832 m.repeated_string.extend(['']) 833 self.assertSequenceEqual([''], m.repeated_string) 834 m.repeated_string.extend(['11', '22']) 835 self.assertSequenceEqual(['', '11', '22'], m.repeated_string) 836 m.repeated_string.extend(['33', '44']) 837 self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) 838 839 def testExtendStringWithString(self, message_module): 840 """Test extending repeated string fields with characters from a string.""" 841 m = message_module.TestAllTypes() 842 self.assertSequenceEqual([], m.repeated_string) 843 m.repeated_string.extend('abc') 844 self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) 845 846 class TestIterable(object): 847 """This iterable object mimics the behavior of numpy.array. 848 849 __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. 850 851 """ 852 853 def __init__(self, values=None): 854 self._list = values or [] 855 856 def __nonzero__(self): 857 size = len(self._list) 858 if size == 0: 859 return False 860 if size == 1: 861 return bool(self._list[0]) 862 raise ValueError('Truth value is ambiguous.') 863 864 def __len__(self): 865 return len(self._list) 866 867 def __iter__(self): 868 return self._list.__iter__() 869 870 def testExtendInt32WithIterable(self, message_module): 871 """Test extending repeated int32 fields with iterable.""" 872 m = message_module.TestAllTypes() 873 self.assertSequenceEqual([], m.repeated_int32) 874 m.repeated_int32.extend(MessageTest.TestIterable([])) 875 self.assertSequenceEqual([], m.repeated_int32) 876 m.repeated_int32.extend(MessageTest.TestIterable([0])) 877 self.assertSequenceEqual([0], m.repeated_int32) 878 m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) 879 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 880 m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) 881 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 882 883 def testExtendFloatWithIterable(self, message_module): 884 """Test extending repeated float fields with iterable.""" 885 m = message_module.TestAllTypes() 886 self.assertSequenceEqual([], m.repeated_float) 887 m.repeated_float.extend(MessageTest.TestIterable([])) 888 self.assertSequenceEqual([], m.repeated_float) 889 m.repeated_float.extend(MessageTest.TestIterable([0.0])) 890 self.assertSequenceEqual([0.0], m.repeated_float) 891 m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) 892 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 893 m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) 894 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 895 896 def testExtendStringWithIterable(self, message_module): 897 """Test extending repeated string fields with iterable.""" 898 m = message_module.TestAllTypes() 899 self.assertSequenceEqual([], m.repeated_string) 900 m.repeated_string.extend(MessageTest.TestIterable([])) 901 self.assertSequenceEqual([], m.repeated_string) 902 m.repeated_string.extend(MessageTest.TestIterable([''])) 903 self.assertSequenceEqual([''], m.repeated_string) 904 m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) 905 self.assertSequenceEqual(['', '1', '2'], m.repeated_string) 906 m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) 907 self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) 908 909 def testPickleRepeatedScalarContainer(self, message_module): 910 # TODO(tibell): The pure-Python implementation support pickling of 911 # scalar containers in *some* cases. For now the cpp2 version 912 # throws an exception to avoid a segfault. Investigate if we 913 # want to support pickling of these fields. 914 # 915 # For more information see: https://b2.corp.google.com/u/0/issues/18677897 916 if (api_implementation.Type() != 'cpp' or 917 api_implementation.Version() == 2): 918 return 919 m = message_module.TestAllTypes() 920 with self.assertRaises(pickle.PickleError) as _: 921 pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) 922 923 def testSortEmptyRepeatedCompositeContainer(self, message_module): 924 """Exercise a scenario that has led to segfaults in the past. 925 """ 926 m = message_module.TestAllTypes() 927 m.repeated_nested_message.sort() 928 929 def testHasFieldOnRepeatedField(self, message_module): 930 """Using HasField on a repeated field should raise an exception. 931 """ 932 m = message_module.TestAllTypes() 933 with self.assertRaises(ValueError) as _: 934 m.HasField('repeated_int32') 935 936 def testRepeatedScalarFieldPop(self, message_module): 937 m = message_module.TestAllTypes() 938 with self.assertRaises(IndexError) as _: 939 m.repeated_int32.pop() 940 m.repeated_int32.extend(range(5)) 941 self.assertEqual(4, m.repeated_int32.pop()) 942 self.assertEqual(0, m.repeated_int32.pop(0)) 943 self.assertEqual(2, m.repeated_int32.pop(1)) 944 self.assertEqual([1, 3], m.repeated_int32) 945 946 def testRepeatedCompositeFieldPop(self, message_module): 947 m = message_module.TestAllTypes() 948 with self.assertRaises(IndexError) as _: 949 m.repeated_nested_message.pop() 950 for i in range(5): 951 n = m.repeated_nested_message.add() 952 n.bb = i 953 self.assertEqual(4, m.repeated_nested_message.pop().bb) 954 self.assertEqual(0, m.repeated_nested_message.pop(0).bb) 955 self.assertEqual(2, m.repeated_nested_message.pop(1).bb) 956 self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) 957 958 959# Class to test proto2-only features (required, extensions, etc.) 960class Proto2Test(unittest.TestCase): 961 962 def testFieldPresence(self): 963 message = unittest_pb2.TestAllTypes() 964 965 self.assertFalse(message.HasField("optional_int32")) 966 self.assertFalse(message.HasField("optional_bool")) 967 self.assertFalse(message.HasField("optional_nested_message")) 968 969 with self.assertRaises(ValueError): 970 message.HasField("field_doesnt_exist") 971 972 with self.assertRaises(ValueError): 973 message.HasField("repeated_int32") 974 with self.assertRaises(ValueError): 975 message.HasField("repeated_nested_message") 976 977 self.assertEqual(0, message.optional_int32) 978 self.assertEqual(False, message.optional_bool) 979 self.assertEqual(0, message.optional_nested_message.bb) 980 981 # Fields are set even when setting the values to default values. 982 message.optional_int32 = 0 983 message.optional_bool = False 984 message.optional_nested_message.bb = 0 985 self.assertTrue(message.HasField("optional_int32")) 986 self.assertTrue(message.HasField("optional_bool")) 987 self.assertTrue(message.HasField("optional_nested_message")) 988 989 # Set the fields to non-default values. 990 message.optional_int32 = 5 991 message.optional_bool = True 992 message.optional_nested_message.bb = 15 993 994 self.assertTrue(message.HasField("optional_int32")) 995 self.assertTrue(message.HasField("optional_bool")) 996 self.assertTrue(message.HasField("optional_nested_message")) 997 998 # Clearing the fields unsets them and resets their value to default. 999 message.ClearField("optional_int32") 1000 message.ClearField("optional_bool") 1001 message.ClearField("optional_nested_message") 1002 1003 self.assertFalse(message.HasField("optional_int32")) 1004 self.assertFalse(message.HasField("optional_bool")) 1005 self.assertFalse(message.HasField("optional_nested_message")) 1006 self.assertEqual(0, message.optional_int32) 1007 self.assertEqual(False, message.optional_bool) 1008 self.assertEqual(0, message.optional_nested_message.bb) 1009 1010 # TODO(tibell): The C++ implementations actually allows assignment 1011 # of unknown enum values to *scalar* fields (but not repeated 1012 # fields). Once checked enum fields becomes the default in the 1013 # Python implementation, the C++ implementation should follow suit. 1014 def testAssignInvalidEnum(self): 1015 """It should not be possible to assign an invalid enum number to an 1016 enum field.""" 1017 m = unittest_pb2.TestAllTypes() 1018 1019 with self.assertRaises(ValueError) as _: 1020 m.optional_nested_enum = 1234567 1021 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) 1022 1023 def testGoldenExtensions(self): 1024 golden_data = test_util.GoldenFileData('golden_message') 1025 golden_message = unittest_pb2.TestAllExtensions() 1026 golden_message.ParseFromString(golden_data) 1027 all_set = unittest_pb2.TestAllExtensions() 1028 test_util.SetAllExtensions(all_set) 1029 self.assertEqual(all_set, golden_message) 1030 self.assertEqual(golden_data, golden_message.SerializeToString()) 1031 golden_copy = copy.deepcopy(golden_message) 1032 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1033 1034 def testGoldenPackedExtensions(self): 1035 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 1036 golden_message = unittest_pb2.TestPackedExtensions() 1037 golden_message.ParseFromString(golden_data) 1038 all_set = unittest_pb2.TestPackedExtensions() 1039 test_util.SetAllPackedExtensions(all_set) 1040 self.assertEqual(all_set, golden_message) 1041 self.assertEqual(golden_data, all_set.SerializeToString()) 1042 golden_copy = copy.deepcopy(golden_message) 1043 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1044 1045 def testPickleIncompleteProto(self): 1046 golden_message = unittest_pb2.TestRequired(a=1) 1047 pickled_message = pickle.dumps(golden_message) 1048 1049 unpickled_message = pickle.loads(pickled_message) 1050 self.assertEqual(unpickled_message, golden_message) 1051 self.assertEqual(unpickled_message.a, 1) 1052 # This is still an incomplete proto - so serializing should fail 1053 self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) 1054 1055 1056 # TODO(haberman): this isn't really a proto2-specific test except that this 1057 # message has a required field in it. Should probably be factored out so 1058 # that we can test the other parts with proto3. 1059 def testParsingMerge(self): 1060 """Check the merge behavior when a required or optional field appears 1061 multiple times in the input.""" 1062 messages = [ 1063 unittest_pb2.TestAllTypes(), 1064 unittest_pb2.TestAllTypes(), 1065 unittest_pb2.TestAllTypes() ] 1066 messages[0].optional_int32 = 1 1067 messages[1].optional_int64 = 2 1068 messages[2].optional_int32 = 3 1069 messages[2].optional_string = 'hello' 1070 1071 merged_message = unittest_pb2.TestAllTypes() 1072 merged_message.optional_int32 = 3 1073 merged_message.optional_int64 = 2 1074 merged_message.optional_string = 'hello' 1075 1076 generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() 1077 generator.field1.extend(messages) 1078 generator.field2.extend(messages) 1079 generator.field3.extend(messages) 1080 generator.ext1.extend(messages) 1081 generator.ext2.extend(messages) 1082 generator.group1.add().field1.MergeFrom(messages[0]) 1083 generator.group1.add().field1.MergeFrom(messages[1]) 1084 generator.group1.add().field1.MergeFrom(messages[2]) 1085 generator.group2.add().field1.MergeFrom(messages[0]) 1086 generator.group2.add().field1.MergeFrom(messages[1]) 1087 generator.group2.add().field1.MergeFrom(messages[2]) 1088 1089 data = generator.SerializeToString() 1090 parsing_merge = unittest_pb2.TestParsingMerge() 1091 parsing_merge.ParseFromString(data) 1092 1093 # Required and optional fields should be merged. 1094 self.assertEqual(parsing_merge.required_all_types, merged_message) 1095 self.assertEqual(parsing_merge.optional_all_types, merged_message) 1096 self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, 1097 merged_message) 1098 self.assertEqual(parsing_merge.Extensions[ 1099 unittest_pb2.TestParsingMerge.optional_ext], 1100 merged_message) 1101 1102 # Repeated fields should not be merged. 1103 self.assertEqual(len(parsing_merge.repeated_all_types), 3) 1104 self.assertEqual(len(parsing_merge.repeatedgroup), 3) 1105 self.assertEqual(len(parsing_merge.Extensions[ 1106 unittest_pb2.TestParsingMerge.repeated_ext]), 3) 1107 1108 def testPythonicInit(self): 1109 message = unittest_pb2.TestAllTypes( 1110 optional_int32=100, 1111 optional_fixed32=200, 1112 optional_float=300.5, 1113 optional_bytes=b'x', 1114 optionalgroup={'a': 400}, 1115 optional_nested_message={'bb': 500}, 1116 optional_nested_enum='BAZ', 1117 repeatedgroup=[{'a': 600}, 1118 {'a': 700}], 1119 repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], 1120 default_int32=800, 1121 oneof_string='y') 1122 self.assertIsInstance(message, unittest_pb2.TestAllTypes) 1123 self.assertEqual(100, message.optional_int32) 1124 self.assertEqual(200, message.optional_fixed32) 1125 self.assertEqual(300.5, message.optional_float) 1126 self.assertEqual(b'x', message.optional_bytes) 1127 self.assertEqual(400, message.optionalgroup.a) 1128 self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) 1129 self.assertEqual(500, message.optional_nested_message.bb) 1130 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1131 message.optional_nested_enum) 1132 self.assertEqual(2, len(message.repeatedgroup)) 1133 self.assertEqual(600, message.repeatedgroup[0].a) 1134 self.assertEqual(700, message.repeatedgroup[1].a) 1135 self.assertEqual(2, len(message.repeated_nested_enum)) 1136 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 1137 message.repeated_nested_enum[0]) 1138 self.assertEqual(unittest_pb2.TestAllTypes.BAR, 1139 message.repeated_nested_enum[1]) 1140 self.assertEqual(800, message.default_int32) 1141 self.assertEqual('y', message.oneof_string) 1142 self.assertFalse(message.HasField('optional_int64')) 1143 self.assertEqual(0, len(message.repeated_float)) 1144 self.assertEqual(42, message.default_int64) 1145 1146 message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') 1147 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1148 message.optional_nested_enum) 1149 1150 with self.assertRaises(ValueError): 1151 unittest_pb2.TestAllTypes( 1152 optional_nested_message={'INVALID_NESTED_FIELD': 17}) 1153 1154 with self.assertRaises(TypeError): 1155 unittest_pb2.TestAllTypes( 1156 optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) 1157 1158 with self.assertRaises(ValueError): 1159 unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') 1160 1161 with self.assertRaises(ValueError): 1162 unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') 1163 1164 1165 1166# Class to test proto3-only features/behavior (updated field presence & enums) 1167class Proto3Test(unittest.TestCase): 1168 1169 # Utility method for comparing equality with a map. 1170 def assertMapIterEquals(self, map_iter, dict_value): 1171 # Avoid mutating caller's copy. 1172 dict_value = dict(dict_value) 1173 1174 for k, v in map_iter: 1175 self.assertEqual(v, dict_value[k]) 1176 del dict_value[k] 1177 1178 self.assertEqual({}, dict_value) 1179 1180 def testFieldPresence(self): 1181 message = unittest_proto3_arena_pb2.TestAllTypes() 1182 1183 # We can't test presence of non-repeated, non-submessage fields. 1184 with self.assertRaises(ValueError): 1185 message.HasField('optional_int32') 1186 with self.assertRaises(ValueError): 1187 message.HasField('optional_float') 1188 with self.assertRaises(ValueError): 1189 message.HasField('optional_string') 1190 with self.assertRaises(ValueError): 1191 message.HasField('optional_bool') 1192 1193 # But we can still test presence of submessage fields. 1194 self.assertFalse(message.HasField('optional_nested_message')) 1195 1196 # As with proto2, we can't test presence of fields that don't exist, or 1197 # repeated fields. 1198 with self.assertRaises(ValueError): 1199 message.HasField('field_doesnt_exist') 1200 1201 with self.assertRaises(ValueError): 1202 message.HasField('repeated_int32') 1203 with self.assertRaises(ValueError): 1204 message.HasField('repeated_nested_message') 1205 1206 # Fields should default to their type-specific default. 1207 self.assertEqual(0, message.optional_int32) 1208 self.assertEqual(0, message.optional_float) 1209 self.assertEqual('', message.optional_string) 1210 self.assertEqual(False, message.optional_bool) 1211 self.assertEqual(0, message.optional_nested_message.bb) 1212 1213 # Setting a submessage should still return proper presence information. 1214 message.optional_nested_message.bb = 0 1215 self.assertTrue(message.HasField('optional_nested_message')) 1216 1217 # Set the fields to non-default values. 1218 message.optional_int32 = 5 1219 message.optional_float = 1.1 1220 message.optional_string = 'abc' 1221 message.optional_bool = True 1222 message.optional_nested_message.bb = 15 1223 1224 # Clearing the fields unsets them and resets their value to default. 1225 message.ClearField('optional_int32') 1226 message.ClearField('optional_float') 1227 message.ClearField('optional_string') 1228 message.ClearField('optional_bool') 1229 message.ClearField('optional_nested_message') 1230 1231 self.assertEqual(0, message.optional_int32) 1232 self.assertEqual(0, message.optional_float) 1233 self.assertEqual('', message.optional_string) 1234 self.assertEqual(False, message.optional_bool) 1235 self.assertEqual(0, message.optional_nested_message.bb) 1236 1237 def testAssignUnknownEnum(self): 1238 """Assigning an unknown enum value is allowed and preserves the value.""" 1239 m = unittest_proto3_arena_pb2.TestAllTypes() 1240 1241 m.optional_nested_enum = 1234567 1242 self.assertEqual(1234567, m.optional_nested_enum) 1243 m.repeated_nested_enum.append(22334455) 1244 self.assertEqual(22334455, m.repeated_nested_enum[0]) 1245 # Assignment is a different code path than append for the C++ impl. 1246 m.repeated_nested_enum[0] = 7654321 1247 self.assertEqual(7654321, m.repeated_nested_enum[0]) 1248 serialized = m.SerializeToString() 1249 1250 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1251 m2.ParseFromString(serialized) 1252 self.assertEqual(1234567, m2.optional_nested_enum) 1253 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1254 1255 # Map isn't really a proto3-only feature. But there is no proto2 equivalent 1256 # of google/protobuf/map_unittest.proto right now, so it's not easy to 1257 # test both with the same test like we do for the other proto2/proto3 tests. 1258 # (google/protobuf/map_protobuf_unittest.proto is very different in the set 1259 # of messages and fields it contains). 1260 def testScalarMapDefaults(self): 1261 msg = map_unittest_pb2.TestMap() 1262 1263 # Scalars start out unset. 1264 self.assertFalse(-123 in msg.map_int32_int32) 1265 self.assertFalse(-2**33 in msg.map_int64_int64) 1266 self.assertFalse(123 in msg.map_uint32_uint32) 1267 self.assertFalse(2**33 in msg.map_uint64_uint64) 1268 self.assertFalse(123 in msg.map_int32_double) 1269 self.assertFalse(False in msg.map_bool_bool) 1270 self.assertFalse('abc' in msg.map_string_string) 1271 self.assertFalse(111 in msg.map_int32_bytes) 1272 self.assertFalse(888 in msg.map_int32_enum) 1273 1274 # Accessing an unset key returns the default. 1275 self.assertEqual(0, msg.map_int32_int32[-123]) 1276 self.assertEqual(0, msg.map_int64_int64[-2**33]) 1277 self.assertEqual(0, msg.map_uint32_uint32[123]) 1278 self.assertEqual(0, msg.map_uint64_uint64[2**33]) 1279 self.assertEqual(0.0, msg.map_int32_double[123]) 1280 self.assertTrue(isinstance(msg.map_int32_double[123], float)) 1281 self.assertEqual(False, msg.map_bool_bool[False]) 1282 self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) 1283 self.assertEqual('', msg.map_string_string['abc']) 1284 self.assertEqual(b'', msg.map_int32_bytes[111]) 1285 self.assertEqual(0, msg.map_int32_enum[888]) 1286 1287 # It also sets the value in the map 1288 self.assertTrue(-123 in msg.map_int32_int32) 1289 self.assertTrue(-2**33 in msg.map_int64_int64) 1290 self.assertTrue(123 in msg.map_uint32_uint32) 1291 self.assertTrue(2**33 in msg.map_uint64_uint64) 1292 self.assertTrue(123 in msg.map_int32_double) 1293 self.assertTrue(False in msg.map_bool_bool) 1294 self.assertTrue('abc' in msg.map_string_string) 1295 self.assertTrue(111 in msg.map_int32_bytes) 1296 self.assertTrue(888 in msg.map_int32_enum) 1297 1298 self.assertIsInstance(msg.map_string_string['abc'], six.text_type) 1299 1300 # Accessing an unset key still throws TypeError if the type of the key 1301 # is incorrect. 1302 with self.assertRaises(TypeError): 1303 msg.map_string_string[123] 1304 1305 with self.assertRaises(TypeError): 1306 123 in msg.map_string_string 1307 1308 def testMapGet(self): 1309 # Need to test that get() properly returns the default, even though the dict 1310 # has defaultdict-like semantics. 1311 msg = map_unittest_pb2.TestMap() 1312 1313 self.assertIsNone(msg.map_int32_int32.get(5)) 1314 self.assertEqual(10, msg.map_int32_int32.get(5, 10)) 1315 self.assertIsNone(msg.map_int32_int32.get(5)) 1316 1317 msg.map_int32_int32[5] = 15 1318 self.assertEqual(15, msg.map_int32_int32.get(5)) 1319 1320 self.assertIsNone(msg.map_int32_foreign_message.get(5)) 1321 self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) 1322 1323 submsg = msg.map_int32_foreign_message[5] 1324 self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) 1325 1326 def testScalarMap(self): 1327 msg = map_unittest_pb2.TestMap() 1328 1329 self.assertEqual(0, len(msg.map_int32_int32)) 1330 self.assertFalse(5 in msg.map_int32_int32) 1331 1332 msg.map_int32_int32[-123] = -456 1333 msg.map_int64_int64[-2**33] = -2**34 1334 msg.map_uint32_uint32[123] = 456 1335 msg.map_uint64_uint64[2**33] = 2**34 1336 msg.map_string_string['abc'] = '123' 1337 msg.map_int32_enum[888] = 2 1338 1339 self.assertEqual([], msg.FindInitializationErrors()) 1340 1341 self.assertEqual(1, len(msg.map_string_string)) 1342 1343 # Bad key. 1344 with self.assertRaises(TypeError): 1345 msg.map_string_string[123] = '123' 1346 1347 # Verify that trying to assign a bad key doesn't actually add a member to 1348 # the map. 1349 self.assertEqual(1, len(msg.map_string_string)) 1350 1351 # Bad value. 1352 with self.assertRaises(TypeError): 1353 msg.map_string_string['123'] = 123 1354 1355 serialized = msg.SerializeToString() 1356 msg2 = map_unittest_pb2.TestMap() 1357 msg2.ParseFromString(serialized) 1358 1359 # Bad key. 1360 with self.assertRaises(TypeError): 1361 msg2.map_string_string[123] = '123' 1362 1363 # Bad value. 1364 with self.assertRaises(TypeError): 1365 msg2.map_string_string['123'] = 123 1366 1367 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1368 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1369 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1370 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1371 self.assertEqual('123', msg2.map_string_string['abc']) 1372 self.assertEqual(2, msg2.map_int32_enum[888]) 1373 1374 def testStringUnicodeConversionInMap(self): 1375 msg = map_unittest_pb2.TestMap() 1376 1377 unicode_obj = u'\u1234' 1378 bytes_obj = unicode_obj.encode('utf8') 1379 1380 msg.map_string_string[bytes_obj] = bytes_obj 1381 1382 (key, value) = list(msg.map_string_string.items())[0] 1383 1384 self.assertEqual(key, unicode_obj) 1385 self.assertEqual(value, unicode_obj) 1386 1387 self.assertIsInstance(key, six.text_type) 1388 self.assertIsInstance(value, six.text_type) 1389 1390 def testMessageMap(self): 1391 msg = map_unittest_pb2.TestMap() 1392 1393 self.assertEqual(0, len(msg.map_int32_foreign_message)) 1394 self.assertFalse(5 in msg.map_int32_foreign_message) 1395 1396 msg.map_int32_foreign_message[123] 1397 # get_or_create() is an alias for getitem. 1398 msg.map_int32_foreign_message.get_or_create(-456) 1399 1400 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1401 self.assertIn(123, msg.map_int32_foreign_message) 1402 self.assertIn(-456, msg.map_int32_foreign_message) 1403 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1404 1405 # Bad key. 1406 with self.assertRaises(TypeError): 1407 msg.map_int32_foreign_message['123'] 1408 1409 # Can't assign directly to submessage. 1410 with self.assertRaises(ValueError): 1411 msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] 1412 1413 # Verify that trying to assign a bad key doesn't actually add a member to 1414 # the map. 1415 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1416 1417 serialized = msg.SerializeToString() 1418 msg2 = map_unittest_pb2.TestMap() 1419 msg2.ParseFromString(serialized) 1420 1421 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1422 self.assertIn(123, msg2.map_int32_foreign_message) 1423 self.assertIn(-456, msg2.map_int32_foreign_message) 1424 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1425 1426 def testMergeFrom(self): 1427 msg = map_unittest_pb2.TestMap() 1428 msg.map_int32_int32[12] = 34 1429 msg.map_int32_int32[56] = 78 1430 msg.map_int64_int64[22] = 33 1431 msg.map_int32_foreign_message[111].c = 5 1432 msg.map_int32_foreign_message[222].c = 10 1433 1434 msg2 = map_unittest_pb2.TestMap() 1435 msg2.map_int32_int32[12] = 55 1436 msg2.map_int64_int64[88] = 99 1437 msg2.map_int32_foreign_message[222].c = 15 1438 1439 msg2.MergeFrom(msg) 1440 1441 self.assertEqual(34, msg2.map_int32_int32[12]) 1442 self.assertEqual(78, msg2.map_int32_int32[56]) 1443 self.assertEqual(33, msg2.map_int64_int64[22]) 1444 self.assertEqual(99, msg2.map_int64_int64[88]) 1445 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 1446 self.assertEqual(10, msg2.map_int32_foreign_message[222].c) 1447 1448 # Verify that there is only one entry per key, even though the MergeFrom 1449 # may have internally created multiple entries for a single key in the 1450 # list representation. 1451 as_dict = {} 1452 for key in msg2.map_int32_foreign_message: 1453 self.assertFalse(key in as_dict) 1454 as_dict[key] = msg2.map_int32_foreign_message[key].c 1455 1456 self.assertEqual({111: 5, 222: 10}, as_dict) 1457 1458 # Special case: test that delete of item really removes the item, even if 1459 # there might have physically been duplicate keys due to the previous merge. 1460 # This is only a special case for the C++ implementation which stores the 1461 # map as an array. 1462 del msg2.map_int32_int32[12] 1463 self.assertFalse(12 in msg2.map_int32_int32) 1464 1465 del msg2.map_int32_foreign_message[222] 1466 self.assertFalse(222 in msg2.map_int32_foreign_message) 1467 1468 def testMergeFromBadType(self): 1469 msg = map_unittest_pb2.TestMap() 1470 with self.assertRaisesRegexp( 1471 TypeError, 1472 r'Parameter to MergeFrom\(\) must be instance of same class: expected ' 1473 r'.*TestMap got int\.'): 1474 msg.MergeFrom(1) 1475 1476 def testCopyFromBadType(self): 1477 msg = map_unittest_pb2.TestMap() 1478 with self.assertRaisesRegexp( 1479 TypeError, 1480 r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' 1481 r'expected .*TestMap got int\.'): 1482 msg.CopyFrom(1) 1483 1484 def testIntegerMapWithLongs(self): 1485 msg = map_unittest_pb2.TestMap() 1486 msg.map_int32_int32[long(-123)] = long(-456) 1487 msg.map_int64_int64[long(-2**33)] = long(-2**34) 1488 msg.map_uint32_uint32[long(123)] = long(456) 1489 msg.map_uint64_uint64[long(2**33)] = long(2**34) 1490 1491 serialized = msg.SerializeToString() 1492 msg2 = map_unittest_pb2.TestMap() 1493 msg2.ParseFromString(serialized) 1494 1495 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1496 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1497 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1498 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1499 1500 def testMapAssignmentCausesPresence(self): 1501 msg = map_unittest_pb2.TestMapSubmessage() 1502 msg.test_map.map_int32_int32[123] = 456 1503 1504 serialized = msg.SerializeToString() 1505 msg2 = map_unittest_pb2.TestMapSubmessage() 1506 msg2.ParseFromString(serialized) 1507 1508 self.assertEqual(msg, msg2) 1509 1510 # Now test that various mutations of the map properly invalidate the 1511 # cached size of the submessage. 1512 msg.test_map.map_int32_int32[888] = 999 1513 serialized = msg.SerializeToString() 1514 msg2.ParseFromString(serialized) 1515 self.assertEqual(msg, msg2) 1516 1517 msg.test_map.map_int32_int32.clear() 1518 serialized = msg.SerializeToString() 1519 msg2.ParseFromString(serialized) 1520 self.assertEqual(msg, msg2) 1521 1522 def testMapAssignmentCausesPresenceForSubmessages(self): 1523 msg = map_unittest_pb2.TestMapSubmessage() 1524 msg.test_map.map_int32_foreign_message[123].c = 5 1525 1526 serialized = msg.SerializeToString() 1527 msg2 = map_unittest_pb2.TestMapSubmessage() 1528 msg2.ParseFromString(serialized) 1529 1530 self.assertEqual(msg, msg2) 1531 1532 # Now test that various mutations of the map properly invalidate the 1533 # cached size of the submessage. 1534 msg.test_map.map_int32_foreign_message[888].c = 7 1535 serialized = msg.SerializeToString() 1536 msg2.ParseFromString(serialized) 1537 self.assertEqual(msg, msg2) 1538 1539 msg.test_map.map_int32_foreign_message[888].MergeFrom( 1540 msg.test_map.map_int32_foreign_message[123]) 1541 serialized = msg.SerializeToString() 1542 msg2.ParseFromString(serialized) 1543 self.assertEqual(msg, msg2) 1544 1545 msg.test_map.map_int32_foreign_message.clear() 1546 serialized = msg.SerializeToString() 1547 msg2.ParseFromString(serialized) 1548 self.assertEqual(msg, msg2) 1549 1550 def testModifyMapWhileIterating(self): 1551 msg = map_unittest_pb2.TestMap() 1552 1553 string_string_iter = iter(msg.map_string_string) 1554 int32_foreign_iter = iter(msg.map_int32_foreign_message) 1555 1556 msg.map_string_string['abc'] = '123' 1557 msg.map_int32_foreign_message[5].c = 5 1558 1559 with self.assertRaises(RuntimeError): 1560 for key in string_string_iter: 1561 pass 1562 1563 with self.assertRaises(RuntimeError): 1564 for key in int32_foreign_iter: 1565 pass 1566 1567 def testSubmessageMap(self): 1568 msg = map_unittest_pb2.TestMap() 1569 1570 submsg = msg.map_int32_foreign_message[111] 1571 self.assertIs(submsg, msg.map_int32_foreign_message[111]) 1572 self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) 1573 1574 submsg.c = 5 1575 1576 serialized = msg.SerializeToString() 1577 msg2 = map_unittest_pb2.TestMap() 1578 msg2.ParseFromString(serialized) 1579 1580 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 1581 1582 # Doesn't allow direct submessage assignment. 1583 with self.assertRaises(ValueError): 1584 msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() 1585 1586 def testMapIteration(self): 1587 msg = map_unittest_pb2.TestMap() 1588 1589 for k, v in msg.map_int32_int32.items(): 1590 # Should not be reached. 1591 self.assertTrue(False) 1592 1593 msg.map_int32_int32[2] = 4 1594 msg.map_int32_int32[3] = 6 1595 msg.map_int32_int32[4] = 8 1596 self.assertEqual(3, len(msg.map_int32_int32)) 1597 1598 matching_dict = {2: 4, 3: 6, 4: 8} 1599 self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) 1600 1601 def testMapItems(self): 1602 # Map items used to have strange behaviors when use c extension. Because 1603 # [] may reorder the map and invalidate any exsting iterators. 1604 # TODO(jieluo): Check if [] reordering the map is a bug or intended 1605 # behavior. 1606 msg = map_unittest_pb2.TestMap() 1607 msg.map_string_string['local_init_op'] = '' 1608 msg.map_string_string['trainable_variables'] = '' 1609 msg.map_string_string['variables'] = '' 1610 msg.map_string_string['init_op'] = '' 1611 msg.map_string_string['summaries'] = '' 1612 items1 = msg.map_string_string.items() 1613 items2 = msg.map_string_string.items() 1614 self.assertEqual(items1, items2) 1615 1616 def testMapIterationClearMessage(self): 1617 # Iterator needs to work even if message and map are deleted. 1618 msg = map_unittest_pb2.TestMap() 1619 1620 msg.map_int32_int32[2] = 4 1621 msg.map_int32_int32[3] = 6 1622 msg.map_int32_int32[4] = 8 1623 1624 it = msg.map_int32_int32.items() 1625 del msg 1626 1627 matching_dict = {2: 4, 3: 6, 4: 8} 1628 self.assertMapIterEquals(it, matching_dict) 1629 1630 def testMapConstruction(self): 1631 msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) 1632 self.assertEqual(2, msg.map_int32_int32[1]) 1633 self.assertEqual(4, msg.map_int32_int32[3]) 1634 1635 msg = map_unittest_pb2.TestMap( 1636 map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) 1637 self.assertEqual(5, msg.map_int32_foreign_message[3].c) 1638 1639 def testMapValidAfterFieldCleared(self): 1640 # Map needs to work even if field is cleared. 1641 # For the C++ implementation this tests the correctness of 1642 # ScalarMapContainer::Release() 1643 msg = map_unittest_pb2.TestMap() 1644 int32_map = msg.map_int32_int32 1645 1646 int32_map[2] = 4 1647 int32_map[3] = 6 1648 int32_map[4] = 8 1649 1650 msg.ClearField('map_int32_int32') 1651 self.assertEqual(b'', msg.SerializeToString()) 1652 matching_dict = {2: 4, 3: 6, 4: 8} 1653 self.assertMapIterEquals(int32_map.items(), matching_dict) 1654 1655 def testMessageMapValidAfterFieldCleared(self): 1656 # Map needs to work even if field is cleared. 1657 # For the C++ implementation this tests the correctness of 1658 # ScalarMapContainer::Release() 1659 msg = map_unittest_pb2.TestMap() 1660 int32_foreign_message = msg.map_int32_foreign_message 1661 1662 int32_foreign_message[2].c = 5 1663 1664 msg.ClearField('map_int32_foreign_message') 1665 self.assertEqual(b'', msg.SerializeToString()) 1666 self.assertTrue(2 in int32_foreign_message.keys()) 1667 1668 def testMapIterInvalidatedByClearField(self): 1669 # Map iterator is invalidated when field is cleared. 1670 # But this case does need to not crash the interpreter. 1671 # For the C++ implementation this tests the correctness of 1672 # ScalarMapContainer::Release() 1673 msg = map_unittest_pb2.TestMap() 1674 1675 it = iter(msg.map_int32_int32) 1676 1677 msg.ClearField('map_int32_int32') 1678 with self.assertRaises(RuntimeError): 1679 for _ in it: 1680 pass 1681 1682 it = iter(msg.map_int32_foreign_message) 1683 msg.ClearField('map_int32_foreign_message') 1684 with self.assertRaises(RuntimeError): 1685 for _ in it: 1686 pass 1687 1688 def testMapDelete(self): 1689 msg = map_unittest_pb2.TestMap() 1690 1691 self.assertEqual(0, len(msg.map_int32_int32)) 1692 1693 msg.map_int32_int32[4] = 6 1694 self.assertEqual(1, len(msg.map_int32_int32)) 1695 1696 with self.assertRaises(KeyError): 1697 del msg.map_int32_int32[88] 1698 1699 del msg.map_int32_int32[4] 1700 self.assertEqual(0, len(msg.map_int32_int32)) 1701 1702 def testMapsAreMapping(self): 1703 msg = map_unittest_pb2.TestMap() 1704 self.assertIsInstance(msg.map_int32_int32, collections.Mapping) 1705 self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) 1706 self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) 1707 self.assertIsInstance(msg.map_int32_foreign_message, 1708 collections.MutableMapping) 1709 1710 def testMapFindInitializationErrorsSmokeTest(self): 1711 msg = map_unittest_pb2.TestMap() 1712 msg.map_string_string['abc'] = '123' 1713 msg.map_int32_int32[35] = 64 1714 msg.map_string_foreign_message['foo'].c = 5 1715 self.assertEqual(0, len(msg.FindInitializationErrors())) 1716 1717 1718 1719class ValidTypeNamesTest(unittest.TestCase): 1720 1721 def assertImportFromName(self, msg, base_name): 1722 # Parse <type 'module.class_name'> to extra 'some.name' as a string. 1723 tp_name = str(type(msg)).split("'")[1] 1724 valid_names = ('Repeated%sContainer' % base_name, 1725 'Repeated%sFieldContainer' % base_name) 1726 self.assertTrue(any(tp_name.endswith(v) for v in valid_names), 1727 '%r does end with any of %r' % (tp_name, valid_names)) 1728 1729 parts = tp_name.split('.') 1730 class_name = parts[-1] 1731 module_name = '.'.join(parts[:-1]) 1732 __import__(module_name, fromlist=[class_name]) 1733 1734 def testTypeNamesCanBeImported(self): 1735 # If import doesn't work, pickling won't work either. 1736 pb = unittest_pb2.TestAllTypes() 1737 self.assertImportFromName(pb.repeated_int32, 'Scalar') 1738 self.assertImportFromName(pb.repeated_nested_message, 'Composite') 1739 1740class PackedFieldTest(unittest.TestCase): 1741 1742 def setMessage(self, message): 1743 message.repeated_int32.append(1) 1744 message.repeated_int64.append(1) 1745 message.repeated_uint32.append(1) 1746 message.repeated_uint64.append(1) 1747 message.repeated_sint32.append(1) 1748 message.repeated_sint64.append(1) 1749 message.repeated_fixed32.append(1) 1750 message.repeated_fixed64.append(1) 1751 message.repeated_sfixed32.append(1) 1752 message.repeated_sfixed64.append(1) 1753 message.repeated_float.append(1.0) 1754 message.repeated_double.append(1.0) 1755 message.repeated_bool.append(True) 1756 message.repeated_nested_enum.append(1) 1757 1758 def testPackedFields(self): 1759 message = packed_field_test_pb2.TestPackedTypes() 1760 self.setMessage(message) 1761 golden_data = (b'\x0A\x01\x01' 1762 b'\x12\x01\x01' 1763 b'\x1A\x01\x01' 1764 b'\x22\x01\x01' 1765 b'\x2A\x01\x02' 1766 b'\x32\x01\x02' 1767 b'\x3A\x04\x01\x00\x00\x00' 1768 b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' 1769 b'\x4A\x04\x01\x00\x00\x00' 1770 b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' 1771 b'\x5A\x04\x00\x00\x80\x3f' 1772 b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' 1773 b'\x6A\x01\x01' 1774 b'\x72\x01\x01') 1775 self.assertEqual(golden_data, message.SerializeToString()) 1776 1777 def testUnpackedFields(self): 1778 message = packed_field_test_pb2.TestUnpackedTypes() 1779 self.setMessage(message) 1780 golden_data = (b'\x08\x01' 1781 b'\x10\x01' 1782 b'\x18\x01' 1783 b'\x20\x01' 1784 b'\x28\x02' 1785 b'\x30\x02' 1786 b'\x3D\x01\x00\x00\x00' 1787 b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' 1788 b'\x4D\x01\x00\x00\x00' 1789 b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' 1790 b'\x5D\x00\x00\x80\x3f' 1791 b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' 1792 b'\x68\x01' 1793 b'\x70\x01') 1794 self.assertEqual(golden_data, message.SerializeToString()) 1795 1796 1797@unittest.skipIf(api_implementation.Type() != 'cpp', 1798 'explicit tests of the C++ implementation') 1799class OversizeProtosTest(unittest.TestCase): 1800 1801 def setUp(self): 1802 self.file_desc = """ 1803 name: "f/f.msg2" 1804 package: "f" 1805 message_type { 1806 name: "msg1" 1807 field { 1808 name: "payload" 1809 number: 1 1810 label: LABEL_OPTIONAL 1811 type: TYPE_STRING 1812 } 1813 } 1814 message_type { 1815 name: "msg2" 1816 field { 1817 name: "field" 1818 number: 1 1819 label: LABEL_OPTIONAL 1820 type: TYPE_MESSAGE 1821 type_name: "msg1" 1822 } 1823 } 1824 """ 1825 pool = descriptor_pool.DescriptorPool() 1826 desc = descriptor_pb2.FileDescriptorProto() 1827 text_format.Parse(self.file_desc, desc) 1828 pool.Add(desc) 1829 self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( 1830 pool.FindMessageTypeByName('f.msg2')) 1831 self.p = self.proto_cls() 1832 self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) 1833 self.p_serialized = self.p.SerializeToString() 1834 1835 def testAssertOversizeProto(self): 1836 from google.protobuf.pyext._message import SetAllowOversizeProtos 1837 SetAllowOversizeProtos(False) 1838 q = self.proto_cls() 1839 try: 1840 q.ParseFromString(self.p_serialized) 1841 except message.DecodeError as e: 1842 self.assertEqual(str(e), 'Error parsing message') 1843 1844 def testSucceedOversizeProto(self): 1845 from google.protobuf.pyext._message import SetAllowOversizeProtos 1846 SetAllowOversizeProtos(True) 1847 q = self.proto_cls() 1848 q.ParseFromString(self.p_serialized) 1849 self.assertEqual(self.p.field.payload, q.field.payload) 1850 1851if __name__ == '__main__': 1852 unittest.main() 1853