1#! /usr/bin/env python 2# 3# Protocol Buffers - Google's data interchange format 4# Copyright 2008 Google Inc. All rights reserved. 5# https://developers.google.com/protocol-buffers/ 6# 7# Redistribution and use in source and binary forms, with or without 8# modification, are permitted provided that the following conditions are 9# met: 10# 11# * Redistributions of source code must retain the above copyright 12# notice, this list of conditions and the following disclaimer. 13# * Redistributions in binary form must reproduce the above 14# copyright notice, this list of conditions and the following disclaimer 15# in the documentation and/or other materials provided with the 16# distribution. 17# * Neither the name of Google Inc. nor the names of its 18# contributors may be used to endorse or promote products derived from 19# this software without specific prior written permission. 20# 21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 33"""Tests python protocol buffers against the golden message. 34 35Note that the golden messages exercise every known field type, thus this 36test ends up exercising and verifying nearly all of the parsing and 37serialization code in the whole library. 38 39TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of 40sense to call this a test of the "message" module, which only declares an 41abstract interface. 42""" 43 44__author__ = 'gps@google.com (Gregory P. Smith)' 45 46 47import collections 48import copy 49import math 50import operator 51import pickle 52import six 53import sys 54 55try: 56 import unittest2 as unittest #PY26 57except ImportError: 58 import unittest 59 60from google.protobuf import map_unittest_pb2 61from google.protobuf import unittest_pb2 62from google.protobuf import unittest_proto3_arena_pb2 63from google.protobuf import descriptor_pb2 64from google.protobuf import descriptor_pool 65from google.protobuf import message_factory 66from google.protobuf import text_format 67from google.protobuf.internal import api_implementation 68from google.protobuf.internal import packed_field_test_pb2 69from google.protobuf.internal import test_util 70from google.protobuf import message 71from google.protobuf.internal import _parameterized 72 73if six.PY3: 74 long = int 75 76 77# Python pre-2.6 does not have isinf() or isnan() functions, so we have 78# to provide our own. 79def isnan(val): 80 # NaN is never equal to itself. 81 return val != val 82def isinf(val): 83 # Infinity times zero equals NaN. 84 return not isnan(val) and isnan(val * 0) 85def IsPosInf(val): 86 return isinf(val) and (val > 0) 87def IsNegInf(val): 88 return isinf(val) and (val < 0) 89 90 91@_parameterized.Parameters( 92 (unittest_pb2), 93 (unittest_proto3_arena_pb2)) 94class MessageTest(unittest.TestCase): 95 96 def testBadUtf8String(self, message_module): 97 if api_implementation.Type() != 'python': 98 self.skipTest("Skipping testBadUtf8String, currently only the python " 99 "api implementation raises UnicodeDecodeError when a " 100 "string field contains bad utf-8.") 101 bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') 102 with self.assertRaises(UnicodeDecodeError) as context: 103 message_module.TestAllTypes.FromString(bad_utf8_data) 104 self.assertIn('TestAllTypes.optional_string', str(context.exception)) 105 106 def testGoldenMessage(self, message_module): 107 # Proto3 doesn't have the "default_foo" members or foreign enums, 108 # and doesn't preserve unknown fields, so for proto3 we use a golden 109 # message that doesn't have these fields set. 110 if message_module is unittest_pb2: 111 golden_data = test_util.GoldenFileData( 112 'golden_message_oneof_implemented') 113 else: 114 golden_data = test_util.GoldenFileData('golden_message_proto3') 115 116 golden_message = message_module.TestAllTypes() 117 golden_message.ParseFromString(golden_data) 118 if message_module is unittest_pb2: 119 test_util.ExpectAllFieldsSet(self, golden_message) 120 self.assertEqual(golden_data, golden_message.SerializeToString()) 121 golden_copy = copy.deepcopy(golden_message) 122 self.assertEqual(golden_data, golden_copy.SerializeToString()) 123 124 def testGoldenPackedMessage(self, message_module): 125 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 126 golden_message = message_module.TestPackedTypes() 127 golden_message.ParseFromString(golden_data) 128 all_set = message_module.TestPackedTypes() 129 test_util.SetAllPackedFields(all_set) 130 self.assertEqual(all_set, golden_message) 131 self.assertEqual(golden_data, all_set.SerializeToString()) 132 golden_copy = copy.deepcopy(golden_message) 133 self.assertEqual(golden_data, golden_copy.SerializeToString()) 134 135 def testPickleSupport(self, message_module): 136 golden_data = test_util.GoldenFileData('golden_message') 137 golden_message = message_module.TestAllTypes() 138 golden_message.ParseFromString(golden_data) 139 pickled_message = pickle.dumps(golden_message) 140 141 unpickled_message = pickle.loads(pickled_message) 142 self.assertEqual(unpickled_message, golden_message) 143 144 def testPositiveInfinity(self, message_module): 145 if message_module is unittest_pb2: 146 golden_data = (b'\x5D\x00\x00\x80\x7F' 147 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 148 b'\xCD\x02\x00\x00\x80\x7F' 149 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') 150 else: 151 golden_data = (b'\x5D\x00\x00\x80\x7F' 152 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' 153 b'\xCA\x02\x04\x00\x00\x80\x7F' 154 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 155 156 golden_message = message_module.TestAllTypes() 157 golden_message.ParseFromString(golden_data) 158 self.assertTrue(IsPosInf(golden_message.optional_float)) 159 self.assertTrue(IsPosInf(golden_message.optional_double)) 160 self.assertTrue(IsPosInf(golden_message.repeated_float[0])) 161 self.assertTrue(IsPosInf(golden_message.repeated_double[0])) 162 self.assertEqual(golden_data, golden_message.SerializeToString()) 163 164 def testNegativeInfinity(self, message_module): 165 if message_module is unittest_pb2: 166 golden_data = (b'\x5D\x00\x00\x80\xFF' 167 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 168 b'\xCD\x02\x00\x00\x80\xFF' 169 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') 170 else: 171 golden_data = (b'\x5D\x00\x00\x80\xFF' 172 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' 173 b'\xCA\x02\x04\x00\x00\x80\xFF' 174 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 175 176 golden_message = message_module.TestAllTypes() 177 golden_message.ParseFromString(golden_data) 178 self.assertTrue(IsNegInf(golden_message.optional_float)) 179 self.assertTrue(IsNegInf(golden_message.optional_double)) 180 self.assertTrue(IsNegInf(golden_message.repeated_float[0])) 181 self.assertTrue(IsNegInf(golden_message.repeated_double[0])) 182 self.assertEqual(golden_data, golden_message.SerializeToString()) 183 184 def testNotANumber(self, message_module): 185 golden_data = (b'\x5D\x00\x00\xC0\x7F' 186 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' 187 b'\xCD\x02\x00\x00\xC0\x7F' 188 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') 189 golden_message = message_module.TestAllTypes() 190 golden_message.ParseFromString(golden_data) 191 self.assertTrue(isnan(golden_message.optional_float)) 192 self.assertTrue(isnan(golden_message.optional_double)) 193 self.assertTrue(isnan(golden_message.repeated_float[0])) 194 self.assertTrue(isnan(golden_message.repeated_double[0])) 195 196 # The protocol buffer may serialize to any one of multiple different 197 # representations of a NaN. Rather than verify a specific representation, 198 # verify the serialized string can be converted into a correctly 199 # behaving protocol buffer. 200 serialized = golden_message.SerializeToString() 201 message = message_module.TestAllTypes() 202 message.ParseFromString(serialized) 203 self.assertTrue(isnan(message.optional_float)) 204 self.assertTrue(isnan(message.optional_double)) 205 self.assertTrue(isnan(message.repeated_float[0])) 206 self.assertTrue(isnan(message.repeated_double[0])) 207 208 def testPositiveInfinityPacked(self, message_module): 209 golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' 210 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') 211 golden_message = message_module.TestPackedTypes() 212 golden_message.ParseFromString(golden_data) 213 self.assertTrue(IsPosInf(golden_message.packed_float[0])) 214 self.assertTrue(IsPosInf(golden_message.packed_double[0])) 215 self.assertEqual(golden_data, golden_message.SerializeToString()) 216 217 def testNegativeInfinityPacked(self, message_module): 218 golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' 219 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') 220 golden_message = message_module.TestPackedTypes() 221 golden_message.ParseFromString(golden_data) 222 self.assertTrue(IsNegInf(golden_message.packed_float[0])) 223 self.assertTrue(IsNegInf(golden_message.packed_double[0])) 224 self.assertEqual(golden_data, golden_message.SerializeToString()) 225 226 def testNotANumberPacked(self, message_module): 227 golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' 228 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') 229 golden_message = message_module.TestPackedTypes() 230 golden_message.ParseFromString(golden_data) 231 self.assertTrue(isnan(golden_message.packed_float[0])) 232 self.assertTrue(isnan(golden_message.packed_double[0])) 233 234 serialized = golden_message.SerializeToString() 235 message = message_module.TestPackedTypes() 236 message.ParseFromString(serialized) 237 self.assertTrue(isnan(message.packed_float[0])) 238 self.assertTrue(isnan(message.packed_double[0])) 239 240 def testExtremeFloatValues(self, message_module): 241 message = message_module.TestAllTypes() 242 243 # Most positive exponent, no significand bits set. 244 kMostPosExponentNoSigBits = math.pow(2, 127) 245 message.optional_float = kMostPosExponentNoSigBits 246 message.ParseFromString(message.SerializeToString()) 247 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) 248 249 # Most positive exponent, one significand bit set. 250 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) 251 message.optional_float = kMostPosExponentOneSigBit 252 message.ParseFromString(message.SerializeToString()) 253 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) 254 255 # Repeat last two cases with values of same magnitude, but negative. 256 message.optional_float = -kMostPosExponentNoSigBits 257 message.ParseFromString(message.SerializeToString()) 258 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) 259 260 message.optional_float = -kMostPosExponentOneSigBit 261 message.ParseFromString(message.SerializeToString()) 262 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) 263 264 # Most negative exponent, no significand bits set. 265 kMostNegExponentNoSigBits = math.pow(2, -127) 266 message.optional_float = kMostNegExponentNoSigBits 267 message.ParseFromString(message.SerializeToString()) 268 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) 269 270 # Most negative exponent, one significand bit set. 271 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) 272 message.optional_float = kMostNegExponentOneSigBit 273 message.ParseFromString(message.SerializeToString()) 274 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) 275 276 # Repeat last two cases with values of the same magnitude, but negative. 277 message.optional_float = -kMostNegExponentNoSigBits 278 message.ParseFromString(message.SerializeToString()) 279 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) 280 281 message.optional_float = -kMostNegExponentOneSigBit 282 message.ParseFromString(message.SerializeToString()) 283 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) 284 285 def testExtremeDoubleValues(self, message_module): 286 message = message_module.TestAllTypes() 287 288 # Most positive exponent, no significand bits set. 289 kMostPosExponentNoSigBits = math.pow(2, 1023) 290 message.optional_double = kMostPosExponentNoSigBits 291 message.ParseFromString(message.SerializeToString()) 292 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) 293 294 # Most positive exponent, one significand bit set. 295 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) 296 message.optional_double = kMostPosExponentOneSigBit 297 message.ParseFromString(message.SerializeToString()) 298 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) 299 300 # Repeat last two cases with values of same magnitude, but negative. 301 message.optional_double = -kMostPosExponentNoSigBits 302 message.ParseFromString(message.SerializeToString()) 303 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) 304 305 message.optional_double = -kMostPosExponentOneSigBit 306 message.ParseFromString(message.SerializeToString()) 307 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) 308 309 # Most negative exponent, no significand bits set. 310 kMostNegExponentNoSigBits = math.pow(2, -1023) 311 message.optional_double = kMostNegExponentNoSigBits 312 message.ParseFromString(message.SerializeToString()) 313 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) 314 315 # Most negative exponent, one significand bit set. 316 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) 317 message.optional_double = kMostNegExponentOneSigBit 318 message.ParseFromString(message.SerializeToString()) 319 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) 320 321 # Repeat last two cases with values of the same magnitude, but negative. 322 message.optional_double = -kMostNegExponentNoSigBits 323 message.ParseFromString(message.SerializeToString()) 324 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) 325 326 message.optional_double = -kMostNegExponentOneSigBit 327 message.ParseFromString(message.SerializeToString()) 328 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) 329 330 def testFloatPrinting(self, message_module): 331 message = message_module.TestAllTypes() 332 message.optional_float = 2.0 333 self.assertEqual(str(message), 'optional_float: 2.0\n') 334 335 def testHighPrecisionFloatPrinting(self, message_module): 336 message = message_module.TestAllTypes() 337 message.optional_double = 0.12345678912345678 338 if sys.version_info >= (3,): 339 self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') 340 else: 341 self.assertEqual(str(message), 'optional_double: 0.123456789123\n') 342 343 def testUnknownFieldPrinting(self, message_module): 344 populated = message_module.TestAllTypes() 345 test_util.SetAllNonLazyFields(populated) 346 empty = message_module.TestEmptyMessage() 347 empty.ParseFromString(populated.SerializeToString()) 348 self.assertEqual(str(empty), '') 349 350 def testRepeatedNestedFieldIteration(self, message_module): 351 msg = message_module.TestAllTypes() 352 msg.repeated_nested_message.add(bb=1) 353 msg.repeated_nested_message.add(bb=2) 354 msg.repeated_nested_message.add(bb=3) 355 msg.repeated_nested_message.add(bb=4) 356 357 self.assertEqual([1, 2, 3, 4], 358 [m.bb for m in msg.repeated_nested_message]) 359 self.assertEqual([4, 3, 2, 1], 360 [m.bb for m in reversed(msg.repeated_nested_message)]) 361 self.assertEqual([4, 3, 2, 1], 362 [m.bb for m in msg.repeated_nested_message[::-1]]) 363 364 def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): 365 """Check some different types with the default comparator.""" 366 message = message_module.TestAllTypes() 367 368 # TODO(mattp): would testing more scalar types strengthen test? 369 message.repeated_int32.append(1) 370 message.repeated_int32.append(3) 371 message.repeated_int32.append(2) 372 message.repeated_int32.sort() 373 self.assertEqual(message.repeated_int32[0], 1) 374 self.assertEqual(message.repeated_int32[1], 2) 375 self.assertEqual(message.repeated_int32[2], 3) 376 377 message.repeated_float.append(1.1) 378 message.repeated_float.append(1.3) 379 message.repeated_float.append(1.2) 380 message.repeated_float.sort() 381 self.assertAlmostEqual(message.repeated_float[0], 1.1) 382 self.assertAlmostEqual(message.repeated_float[1], 1.2) 383 self.assertAlmostEqual(message.repeated_float[2], 1.3) 384 385 message.repeated_string.append('a') 386 message.repeated_string.append('c') 387 message.repeated_string.append('b') 388 message.repeated_string.sort() 389 self.assertEqual(message.repeated_string[0], 'a') 390 self.assertEqual(message.repeated_string[1], 'b') 391 self.assertEqual(message.repeated_string[2], 'c') 392 393 message.repeated_bytes.append(b'a') 394 message.repeated_bytes.append(b'c') 395 message.repeated_bytes.append(b'b') 396 message.repeated_bytes.sort() 397 self.assertEqual(message.repeated_bytes[0], b'a') 398 self.assertEqual(message.repeated_bytes[1], b'b') 399 self.assertEqual(message.repeated_bytes[2], b'c') 400 401 def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): 402 """Check some different types with custom comparator.""" 403 message = message_module.TestAllTypes() 404 405 message.repeated_int32.append(-3) 406 message.repeated_int32.append(-2) 407 message.repeated_int32.append(-1) 408 message.repeated_int32.sort(key=abs) 409 self.assertEqual(message.repeated_int32[0], -1) 410 self.assertEqual(message.repeated_int32[1], -2) 411 self.assertEqual(message.repeated_int32[2], -3) 412 413 message.repeated_string.append('aaa') 414 message.repeated_string.append('bb') 415 message.repeated_string.append('c') 416 message.repeated_string.sort(key=len) 417 self.assertEqual(message.repeated_string[0], 'c') 418 self.assertEqual(message.repeated_string[1], 'bb') 419 self.assertEqual(message.repeated_string[2], 'aaa') 420 421 def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): 422 """Check passing a custom comparator to sort a repeated composite field.""" 423 message = message_module.TestAllTypes() 424 425 message.repeated_nested_message.add().bb = 1 426 message.repeated_nested_message.add().bb = 3 427 message.repeated_nested_message.add().bb = 2 428 message.repeated_nested_message.add().bb = 6 429 message.repeated_nested_message.add().bb = 5 430 message.repeated_nested_message.add().bb = 4 431 message.repeated_nested_message.sort(key=operator.attrgetter('bb')) 432 self.assertEqual(message.repeated_nested_message[0].bb, 1) 433 self.assertEqual(message.repeated_nested_message[1].bb, 2) 434 self.assertEqual(message.repeated_nested_message[2].bb, 3) 435 self.assertEqual(message.repeated_nested_message[3].bb, 4) 436 self.assertEqual(message.repeated_nested_message[4].bb, 5) 437 self.assertEqual(message.repeated_nested_message[5].bb, 6) 438 439 def testSortingRepeatedCompositeFieldsStable(self, message_module): 440 """Check passing a custom comparator to sort a repeated composite field.""" 441 message = message_module.TestAllTypes() 442 443 message.repeated_nested_message.add().bb = 21 444 message.repeated_nested_message.add().bb = 20 445 message.repeated_nested_message.add().bb = 13 446 message.repeated_nested_message.add().bb = 33 447 message.repeated_nested_message.add().bb = 11 448 message.repeated_nested_message.add().bb = 24 449 message.repeated_nested_message.add().bb = 10 450 message.repeated_nested_message.sort(key=lambda z: z.bb // 10) 451 self.assertEqual( 452 [13, 11, 10, 21, 20, 24, 33], 453 [n.bb for n in message.repeated_nested_message]) 454 455 # Make sure that for the C++ implementation, the underlying fields 456 # are actually reordered. 457 pb = message.SerializeToString() 458 message.Clear() 459 message.MergeFromString(pb) 460 self.assertEqual( 461 [13, 11, 10, 21, 20, 24, 33], 462 [n.bb for n in message.repeated_nested_message]) 463 464 def testRepeatedCompositeFieldSortArguments(self, message_module): 465 """Check sorting a repeated composite field using list.sort() arguments.""" 466 message = message_module.TestAllTypes() 467 468 get_bb = operator.attrgetter('bb') 469 cmp_bb = lambda a, b: cmp(a.bb, b.bb) 470 message.repeated_nested_message.add().bb = 1 471 message.repeated_nested_message.add().bb = 3 472 message.repeated_nested_message.add().bb = 2 473 message.repeated_nested_message.add().bb = 6 474 message.repeated_nested_message.add().bb = 5 475 message.repeated_nested_message.add().bb = 4 476 message.repeated_nested_message.sort(key=get_bb) 477 self.assertEqual([k.bb for k in message.repeated_nested_message], 478 [1, 2, 3, 4, 5, 6]) 479 message.repeated_nested_message.sort(key=get_bb, reverse=True) 480 self.assertEqual([k.bb for k in message.repeated_nested_message], 481 [6, 5, 4, 3, 2, 1]) 482 if sys.version_info >= (3,): return # No cmp sorting in PY3. 483 message.repeated_nested_message.sort(sort_function=cmp_bb) 484 self.assertEqual([k.bb for k in message.repeated_nested_message], 485 [1, 2, 3, 4, 5, 6]) 486 message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) 487 self.assertEqual([k.bb for k in message.repeated_nested_message], 488 [6, 5, 4, 3, 2, 1]) 489 490 def testRepeatedScalarFieldSortArguments(self, message_module): 491 """Check sorting a scalar field using list.sort() arguments.""" 492 message = message_module.TestAllTypes() 493 494 message.repeated_int32.append(-3) 495 message.repeated_int32.append(-2) 496 message.repeated_int32.append(-1) 497 message.repeated_int32.sort(key=abs) 498 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 499 message.repeated_int32.sort(key=abs, reverse=True) 500 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 501 if sys.version_info < (3,): # No cmp sorting in PY3. 502 abs_cmp = lambda a, b: cmp(abs(a), abs(b)) 503 message.repeated_int32.sort(sort_function=abs_cmp) 504 self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) 505 message.repeated_int32.sort(cmp=abs_cmp, reverse=True) 506 self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) 507 508 message.repeated_string.append('aaa') 509 message.repeated_string.append('bb') 510 message.repeated_string.append('c') 511 message.repeated_string.sort(key=len) 512 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 513 message.repeated_string.sort(key=len, reverse=True) 514 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 515 if sys.version_info < (3,): # No cmp sorting in PY3. 516 len_cmp = lambda a, b: cmp(len(a), len(b)) 517 message.repeated_string.sort(sort_function=len_cmp) 518 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) 519 message.repeated_string.sort(cmp=len_cmp, reverse=True) 520 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) 521 522 def testRepeatedFieldsComparable(self, message_module): 523 m1 = message_module.TestAllTypes() 524 m2 = message_module.TestAllTypes() 525 m1.repeated_int32.append(0) 526 m1.repeated_int32.append(1) 527 m1.repeated_int32.append(2) 528 m2.repeated_int32.append(0) 529 m2.repeated_int32.append(1) 530 m2.repeated_int32.append(2) 531 m1.repeated_nested_message.add().bb = 1 532 m1.repeated_nested_message.add().bb = 2 533 m1.repeated_nested_message.add().bb = 3 534 m2.repeated_nested_message.add().bb = 1 535 m2.repeated_nested_message.add().bb = 2 536 m2.repeated_nested_message.add().bb = 3 537 538 if sys.version_info >= (3,): return # No cmp() in PY3. 539 540 # These comparisons should not raise errors. 541 _ = m1 < m2 542 _ = m1.repeated_nested_message < m2.repeated_nested_message 543 544 # Make sure cmp always works. If it wasn't defined, these would be 545 # id() comparisons and would all fail. 546 self.assertEqual(cmp(m1, m2), 0) 547 self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) 548 self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) 549 self.assertEqual(cmp(m1.repeated_nested_message, 550 m2.repeated_nested_message), 0) 551 with self.assertRaises(TypeError): 552 # Can't compare repeated composite containers to lists. 553 cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) 554 555 # TODO(anuraag): Implement extensiondict comparison in C++ and then add test 556 557 def testRepeatedFieldsAreSequences(self, message_module): 558 m = message_module.TestAllTypes() 559 self.assertIsInstance(m.repeated_int32, collections.MutableSequence) 560 self.assertIsInstance(m.repeated_nested_message, 561 collections.MutableSequence) 562 563 def ensureNestedMessageExists(self, msg, attribute): 564 """Make sure that a nested message object exists. 565 566 As soon as a nested message attribute is accessed, it will be present in the 567 _fields dict, without being marked as actually being set. 568 """ 569 getattr(msg, attribute) 570 self.assertFalse(msg.HasField(attribute)) 571 572 def testOneofGetCaseNonexistingField(self, message_module): 573 m = message_module.TestAllTypes() 574 self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') 575 576 def testOneofDefaultValues(self, message_module): 577 m = message_module.TestAllTypes() 578 self.assertIs(None, m.WhichOneof('oneof_field')) 579 self.assertFalse(m.HasField('oneof_uint32')) 580 581 # Oneof is set even when setting it to a default value. 582 m.oneof_uint32 = 0 583 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 584 self.assertTrue(m.HasField('oneof_uint32')) 585 self.assertFalse(m.HasField('oneof_string')) 586 587 m.oneof_string = "" 588 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 589 self.assertTrue(m.HasField('oneof_string')) 590 self.assertFalse(m.HasField('oneof_uint32')) 591 592 def testOneofSemantics(self, message_module): 593 m = message_module.TestAllTypes() 594 self.assertIs(None, m.WhichOneof('oneof_field')) 595 596 m.oneof_uint32 = 11 597 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 598 self.assertTrue(m.HasField('oneof_uint32')) 599 600 m.oneof_string = u'foo' 601 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 602 self.assertFalse(m.HasField('oneof_uint32')) 603 self.assertTrue(m.HasField('oneof_string')) 604 605 # Read nested message accessor without accessing submessage. 606 m.oneof_nested_message 607 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 608 self.assertTrue(m.HasField('oneof_string')) 609 self.assertFalse(m.HasField('oneof_nested_message')) 610 611 # Read accessor of nested message without accessing submessage. 612 m.oneof_nested_message.bb 613 self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) 614 self.assertTrue(m.HasField('oneof_string')) 615 self.assertFalse(m.HasField('oneof_nested_message')) 616 617 m.oneof_nested_message.bb = 11 618 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 619 self.assertFalse(m.HasField('oneof_string')) 620 self.assertTrue(m.HasField('oneof_nested_message')) 621 622 m.oneof_bytes = b'bb' 623 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 624 self.assertFalse(m.HasField('oneof_nested_message')) 625 self.assertTrue(m.HasField('oneof_bytes')) 626 627 def testOneofCompositeFieldReadAccess(self, message_module): 628 m = message_module.TestAllTypes() 629 m.oneof_uint32 = 11 630 631 self.ensureNestedMessageExists(m, 'oneof_nested_message') 632 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 633 self.assertEqual(11, m.oneof_uint32) 634 635 def testOneofWhichOneof(self, message_module): 636 m = message_module.TestAllTypes() 637 self.assertIs(None, m.WhichOneof('oneof_field')) 638 if message_module is unittest_pb2: 639 self.assertFalse(m.HasField('oneof_field')) 640 641 m.oneof_uint32 = 11 642 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 643 if message_module is unittest_pb2: 644 self.assertTrue(m.HasField('oneof_field')) 645 646 m.oneof_bytes = b'bb' 647 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 648 649 m.ClearField('oneof_bytes') 650 self.assertIs(None, m.WhichOneof('oneof_field')) 651 if message_module is unittest_pb2: 652 self.assertFalse(m.HasField('oneof_field')) 653 654 def testOneofClearField(self, message_module): 655 m = message_module.TestAllTypes() 656 m.oneof_uint32 = 11 657 m.ClearField('oneof_field') 658 if message_module is unittest_pb2: 659 self.assertFalse(m.HasField('oneof_field')) 660 self.assertFalse(m.HasField('oneof_uint32')) 661 self.assertIs(None, m.WhichOneof('oneof_field')) 662 663 def testOneofClearSetField(self, message_module): 664 m = message_module.TestAllTypes() 665 m.oneof_uint32 = 11 666 m.ClearField('oneof_uint32') 667 if message_module is unittest_pb2: 668 self.assertFalse(m.HasField('oneof_field')) 669 self.assertFalse(m.HasField('oneof_uint32')) 670 self.assertIs(None, m.WhichOneof('oneof_field')) 671 672 def testOneofClearUnsetField(self, message_module): 673 m = message_module.TestAllTypes() 674 m.oneof_uint32 = 11 675 self.ensureNestedMessageExists(m, 'oneof_nested_message') 676 m.ClearField('oneof_nested_message') 677 self.assertEqual(11, m.oneof_uint32) 678 if message_module is unittest_pb2: 679 self.assertTrue(m.HasField('oneof_field')) 680 self.assertTrue(m.HasField('oneof_uint32')) 681 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) 682 683 def testOneofDeserialize(self, message_module): 684 m = message_module.TestAllTypes() 685 m.oneof_uint32 = 11 686 m2 = message_module.TestAllTypes() 687 m2.ParseFromString(m.SerializeToString()) 688 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 689 690 def testOneofCopyFrom(self, message_module): 691 m = message_module.TestAllTypes() 692 m.oneof_uint32 = 11 693 m2 = message_module.TestAllTypes() 694 m2.CopyFrom(m) 695 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) 696 697 def testOneofNestedMergeFrom(self, message_module): 698 m = message_module.NestedTestAllTypes() 699 m.payload.oneof_uint32 = 11 700 m2 = message_module.NestedTestAllTypes() 701 m2.payload.oneof_bytes = b'bb' 702 m2.child.payload.oneof_bytes = b'bb' 703 m2.MergeFrom(m) 704 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) 705 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) 706 707 def testOneofMessageMergeFrom(self, message_module): 708 m = message_module.NestedTestAllTypes() 709 m.payload.oneof_nested_message.bb = 11 710 m.child.payload.oneof_nested_message.bb = 12 711 m2 = message_module.NestedTestAllTypes() 712 m2.payload.oneof_uint32 = 13 713 m2.MergeFrom(m) 714 self.assertEqual('oneof_nested_message', 715 m2.payload.WhichOneof('oneof_field')) 716 self.assertEqual('oneof_nested_message', 717 m2.child.payload.WhichOneof('oneof_field')) 718 719 def testOneofNestedMessageInit(self, message_module): 720 m = message_module.TestAllTypes( 721 oneof_nested_message=message_module.TestAllTypes.NestedMessage()) 722 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) 723 724 def testOneofClear(self, message_module): 725 m = message_module.TestAllTypes() 726 m.oneof_uint32 = 11 727 m.Clear() 728 self.assertIsNone(m.WhichOneof('oneof_field')) 729 m.oneof_bytes = b'bb' 730 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) 731 732 def testAssignByteStringToUnicodeField(self, message_module): 733 """Assigning a byte string to a string field should result 734 in the value being converted to a Unicode string.""" 735 m = message_module.TestAllTypes() 736 m.optional_string = str('') 737 self.assertIsInstance(m.optional_string, six.text_type) 738 739 def testLongValuedSlice(self, message_module): 740 """It should be possible to use long-valued indicies in slices 741 742 This didn't used to work in the v2 C++ implementation. 743 """ 744 m = message_module.TestAllTypes() 745 746 # Repeated scalar 747 m.repeated_int32.append(1) 748 sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] 749 self.assertEqual(len(m.repeated_int32), len(sl)) 750 751 # Repeated composite 752 m.repeated_nested_message.add().bb = 3 753 sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] 754 self.assertEqual(len(m.repeated_nested_message), len(sl)) 755 756 def testExtendShouldNotSwallowExceptions(self, message_module): 757 """This didn't use to work in the v2 C++ implementation.""" 758 m = message_module.TestAllTypes() 759 with self.assertRaises(NameError) as _: 760 m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable 761 with self.assertRaises(NameError) as _: 762 m.repeated_nested_enum.extend( 763 a for i in range(10)) # pylint: disable=undefined-variable 764 765 FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] 766 767 def testExtendInt32WithNothing(self, message_module): 768 """Test no-ops extending repeated int32 fields.""" 769 m = message_module.TestAllTypes() 770 self.assertSequenceEqual([], m.repeated_int32) 771 772 # TODO(ptucker): Deprecate this behavior. b/18413862 773 for falsy_value in MessageTest.FALSY_VALUES: 774 m.repeated_int32.extend(falsy_value) 775 self.assertSequenceEqual([], m.repeated_int32) 776 777 m.repeated_int32.extend([]) 778 self.assertSequenceEqual([], m.repeated_int32) 779 780 def testExtendFloatWithNothing(self, message_module): 781 """Test no-ops extending repeated float fields.""" 782 m = message_module.TestAllTypes() 783 self.assertSequenceEqual([], m.repeated_float) 784 785 # TODO(ptucker): Deprecate this behavior. b/18413862 786 for falsy_value in MessageTest.FALSY_VALUES: 787 m.repeated_float.extend(falsy_value) 788 self.assertSequenceEqual([], m.repeated_float) 789 790 m.repeated_float.extend([]) 791 self.assertSequenceEqual([], m.repeated_float) 792 793 def testExtendStringWithNothing(self, message_module): 794 """Test no-ops extending repeated string fields.""" 795 m = message_module.TestAllTypes() 796 self.assertSequenceEqual([], m.repeated_string) 797 798 # TODO(ptucker): Deprecate this behavior. b/18413862 799 for falsy_value in MessageTest.FALSY_VALUES: 800 m.repeated_string.extend(falsy_value) 801 self.assertSequenceEqual([], m.repeated_string) 802 803 m.repeated_string.extend([]) 804 self.assertSequenceEqual([], m.repeated_string) 805 806 def testExtendInt32WithPythonList(self, message_module): 807 """Test extending repeated int32 fields with python lists.""" 808 m = message_module.TestAllTypes() 809 self.assertSequenceEqual([], m.repeated_int32) 810 m.repeated_int32.extend([0]) 811 self.assertSequenceEqual([0], m.repeated_int32) 812 m.repeated_int32.extend([1, 2]) 813 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 814 m.repeated_int32.extend([3, 4]) 815 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 816 817 def testExtendFloatWithPythonList(self, message_module): 818 """Test extending repeated float fields with python lists.""" 819 m = message_module.TestAllTypes() 820 self.assertSequenceEqual([], m.repeated_float) 821 m.repeated_float.extend([0.0]) 822 self.assertSequenceEqual([0.0], m.repeated_float) 823 m.repeated_float.extend([1.0, 2.0]) 824 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 825 m.repeated_float.extend([3.0, 4.0]) 826 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 827 828 def testExtendStringWithPythonList(self, message_module): 829 """Test extending repeated string fields with python lists.""" 830 m = message_module.TestAllTypes() 831 self.assertSequenceEqual([], m.repeated_string) 832 m.repeated_string.extend(['']) 833 self.assertSequenceEqual([''], m.repeated_string) 834 m.repeated_string.extend(['11', '22']) 835 self.assertSequenceEqual(['', '11', '22'], m.repeated_string) 836 m.repeated_string.extend(['33', '44']) 837 self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) 838 839 def testExtendStringWithString(self, message_module): 840 """Test extending repeated string fields with characters from a string.""" 841 m = message_module.TestAllTypes() 842 self.assertSequenceEqual([], m.repeated_string) 843 m.repeated_string.extend('abc') 844 self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) 845 846 class TestIterable(object): 847 """This iterable object mimics the behavior of numpy.array. 848 849 __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. 850 851 """ 852 853 def __init__(self, values=None): 854 self._list = values or [] 855 856 def __nonzero__(self): 857 size = len(self._list) 858 if size == 0: 859 return False 860 if size == 1: 861 return bool(self._list[0]) 862 raise ValueError('Truth value is ambiguous.') 863 864 def __len__(self): 865 return len(self._list) 866 867 def __iter__(self): 868 return self._list.__iter__() 869 870 def testExtendInt32WithIterable(self, message_module): 871 """Test extending repeated int32 fields with iterable.""" 872 m = message_module.TestAllTypes() 873 self.assertSequenceEqual([], m.repeated_int32) 874 m.repeated_int32.extend(MessageTest.TestIterable([])) 875 self.assertSequenceEqual([], m.repeated_int32) 876 m.repeated_int32.extend(MessageTest.TestIterable([0])) 877 self.assertSequenceEqual([0], m.repeated_int32) 878 m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) 879 self.assertSequenceEqual([0, 1, 2], m.repeated_int32) 880 m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) 881 self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) 882 883 def testExtendFloatWithIterable(self, message_module): 884 """Test extending repeated float fields with iterable.""" 885 m = message_module.TestAllTypes() 886 self.assertSequenceEqual([], m.repeated_float) 887 m.repeated_float.extend(MessageTest.TestIterable([])) 888 self.assertSequenceEqual([], m.repeated_float) 889 m.repeated_float.extend(MessageTest.TestIterable([0.0])) 890 self.assertSequenceEqual([0.0], m.repeated_float) 891 m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) 892 self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) 893 m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) 894 self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) 895 896 def testExtendStringWithIterable(self, message_module): 897 """Test extending repeated string fields with iterable.""" 898 m = message_module.TestAllTypes() 899 self.assertSequenceEqual([], m.repeated_string) 900 m.repeated_string.extend(MessageTest.TestIterable([])) 901 self.assertSequenceEqual([], m.repeated_string) 902 m.repeated_string.extend(MessageTest.TestIterable([''])) 903 self.assertSequenceEqual([''], m.repeated_string) 904 m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) 905 self.assertSequenceEqual(['', '1', '2'], m.repeated_string) 906 m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) 907 self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) 908 909 def testPickleRepeatedScalarContainer(self, message_module): 910 # TODO(tibell): The pure-Python implementation support pickling of 911 # scalar containers in *some* cases. For now the cpp2 version 912 # throws an exception to avoid a segfault. Investigate if we 913 # want to support pickling of these fields. 914 # 915 # For more information see: https://b2.corp.google.com/u/0/issues/18677897 916 if (api_implementation.Type() != 'cpp' or 917 api_implementation.Version() == 2): 918 return 919 m = message_module.TestAllTypes() 920 with self.assertRaises(pickle.PickleError) as _: 921 pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) 922 923 def testSortEmptyRepeatedCompositeContainer(self, message_module): 924 """Exercise a scenario that has led to segfaults in the past. 925 """ 926 m = message_module.TestAllTypes() 927 m.repeated_nested_message.sort() 928 929 def testHasFieldOnRepeatedField(self, message_module): 930 """Using HasField on a repeated field should raise an exception. 931 """ 932 m = message_module.TestAllTypes() 933 with self.assertRaises(ValueError) as _: 934 m.HasField('repeated_int32') 935 936 def testRepeatedScalarFieldPop(self, message_module): 937 m = message_module.TestAllTypes() 938 with self.assertRaises(IndexError) as _: 939 m.repeated_int32.pop() 940 m.repeated_int32.extend(range(5)) 941 self.assertEqual(4, m.repeated_int32.pop()) 942 self.assertEqual(0, m.repeated_int32.pop(0)) 943 self.assertEqual(2, m.repeated_int32.pop(1)) 944 self.assertEqual([1, 3], m.repeated_int32) 945 946 def testRepeatedCompositeFieldPop(self, message_module): 947 m = message_module.TestAllTypes() 948 with self.assertRaises(IndexError) as _: 949 m.repeated_nested_message.pop() 950 for i in range(5): 951 n = m.repeated_nested_message.add() 952 n.bb = i 953 self.assertEqual(4, m.repeated_nested_message.pop().bb) 954 self.assertEqual(0, m.repeated_nested_message.pop(0).bb) 955 self.assertEqual(2, m.repeated_nested_message.pop(1).bb) 956 self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) 957 958 959# Class to test proto2-only features (required, extensions, etc.) 960class Proto2Test(unittest.TestCase): 961 962 def testFieldPresence(self): 963 message = unittest_pb2.TestAllTypes() 964 965 self.assertFalse(message.HasField("optional_int32")) 966 self.assertFalse(message.HasField("optional_bool")) 967 self.assertFalse(message.HasField("optional_nested_message")) 968 969 with self.assertRaises(ValueError): 970 message.HasField("field_doesnt_exist") 971 972 with self.assertRaises(ValueError): 973 message.HasField("repeated_int32") 974 with self.assertRaises(ValueError): 975 message.HasField("repeated_nested_message") 976 977 self.assertEqual(0, message.optional_int32) 978 self.assertEqual(False, message.optional_bool) 979 self.assertEqual(0, message.optional_nested_message.bb) 980 981 # Fields are set even when setting the values to default values. 982 message.optional_int32 = 0 983 message.optional_bool = False 984 message.optional_nested_message.bb = 0 985 self.assertTrue(message.HasField("optional_int32")) 986 self.assertTrue(message.HasField("optional_bool")) 987 self.assertTrue(message.HasField("optional_nested_message")) 988 989 # Set the fields to non-default values. 990 message.optional_int32 = 5 991 message.optional_bool = True 992 message.optional_nested_message.bb = 15 993 994 self.assertTrue(message.HasField("optional_int32")) 995 self.assertTrue(message.HasField("optional_bool")) 996 self.assertTrue(message.HasField("optional_nested_message")) 997 998 # Clearing the fields unsets them and resets their value to default. 999 message.ClearField("optional_int32") 1000 message.ClearField("optional_bool") 1001 message.ClearField("optional_nested_message") 1002 1003 self.assertFalse(message.HasField("optional_int32")) 1004 self.assertFalse(message.HasField("optional_bool")) 1005 self.assertFalse(message.HasField("optional_nested_message")) 1006 self.assertEqual(0, message.optional_int32) 1007 self.assertEqual(False, message.optional_bool) 1008 self.assertEqual(0, message.optional_nested_message.bb) 1009 1010 # TODO(tibell): The C++ implementations actually allows assignment 1011 # of unknown enum values to *scalar* fields (but not repeated 1012 # fields). Once checked enum fields becomes the default in the 1013 # Python implementation, the C++ implementation should follow suit. 1014 def testAssignInvalidEnum(self): 1015 """It should not be possible to assign an invalid enum number to an 1016 enum field.""" 1017 m = unittest_pb2.TestAllTypes() 1018 1019 with self.assertRaises(ValueError) as _: 1020 m.optional_nested_enum = 1234567 1021 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) 1022 1023 def testGoldenExtensions(self): 1024 golden_data = test_util.GoldenFileData('golden_message') 1025 golden_message = unittest_pb2.TestAllExtensions() 1026 golden_message.ParseFromString(golden_data) 1027 all_set = unittest_pb2.TestAllExtensions() 1028 test_util.SetAllExtensions(all_set) 1029 self.assertEqual(all_set, golden_message) 1030 self.assertEqual(golden_data, golden_message.SerializeToString()) 1031 golden_copy = copy.deepcopy(golden_message) 1032 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1033 1034 def testGoldenPackedExtensions(self): 1035 golden_data = test_util.GoldenFileData('golden_packed_fields_message') 1036 golden_message = unittest_pb2.TestPackedExtensions() 1037 golden_message.ParseFromString(golden_data) 1038 all_set = unittest_pb2.TestPackedExtensions() 1039 test_util.SetAllPackedExtensions(all_set) 1040 self.assertEqual(all_set, golden_message) 1041 self.assertEqual(golden_data, all_set.SerializeToString()) 1042 golden_copy = copy.deepcopy(golden_message) 1043 self.assertEqual(golden_data, golden_copy.SerializeToString()) 1044 1045 def testPickleIncompleteProto(self): 1046 golden_message = unittest_pb2.TestRequired(a=1) 1047 pickled_message = pickle.dumps(golden_message) 1048 1049 unpickled_message = pickle.loads(pickled_message) 1050 self.assertEqual(unpickled_message, golden_message) 1051 self.assertEqual(unpickled_message.a, 1) 1052 # This is still an incomplete proto - so serializing should fail 1053 self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) 1054 1055 1056 # TODO(haberman): this isn't really a proto2-specific test except that this 1057 # message has a required field in it. Should probably be factored out so 1058 # that we can test the other parts with proto3. 1059 def testParsingMerge(self): 1060 """Check the merge behavior when a required or optional field appears 1061 multiple times in the input.""" 1062 messages = [ 1063 unittest_pb2.TestAllTypes(), 1064 unittest_pb2.TestAllTypes(), 1065 unittest_pb2.TestAllTypes() ] 1066 messages[0].optional_int32 = 1 1067 messages[1].optional_int64 = 2 1068 messages[2].optional_int32 = 3 1069 messages[2].optional_string = 'hello' 1070 1071 merged_message = unittest_pb2.TestAllTypes() 1072 merged_message.optional_int32 = 3 1073 merged_message.optional_int64 = 2 1074 merged_message.optional_string = 'hello' 1075 1076 generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() 1077 generator.field1.extend(messages) 1078 generator.field2.extend(messages) 1079 generator.field3.extend(messages) 1080 generator.ext1.extend(messages) 1081 generator.ext2.extend(messages) 1082 generator.group1.add().field1.MergeFrom(messages[0]) 1083 generator.group1.add().field1.MergeFrom(messages[1]) 1084 generator.group1.add().field1.MergeFrom(messages[2]) 1085 generator.group2.add().field1.MergeFrom(messages[0]) 1086 generator.group2.add().field1.MergeFrom(messages[1]) 1087 generator.group2.add().field1.MergeFrom(messages[2]) 1088 1089 data = generator.SerializeToString() 1090 parsing_merge = unittest_pb2.TestParsingMerge() 1091 parsing_merge.ParseFromString(data) 1092 1093 # Required and optional fields should be merged. 1094 self.assertEqual(parsing_merge.required_all_types, merged_message) 1095 self.assertEqual(parsing_merge.optional_all_types, merged_message) 1096 self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, 1097 merged_message) 1098 self.assertEqual(parsing_merge.Extensions[ 1099 unittest_pb2.TestParsingMerge.optional_ext], 1100 merged_message) 1101 1102 # Repeated fields should not be merged. 1103 self.assertEqual(len(parsing_merge.repeated_all_types), 3) 1104 self.assertEqual(len(parsing_merge.repeatedgroup), 3) 1105 self.assertEqual(len(parsing_merge.Extensions[ 1106 unittest_pb2.TestParsingMerge.repeated_ext]), 3) 1107 1108 def testPythonicInit(self): 1109 message = unittest_pb2.TestAllTypes( 1110 optional_int32=100, 1111 optional_fixed32=200, 1112 optional_float=300.5, 1113 optional_bytes=b'x', 1114 optionalgroup={'a': 400}, 1115 optional_nested_message={'bb': 500}, 1116 optional_nested_enum='BAZ', 1117 repeatedgroup=[{'a': 600}, 1118 {'a': 700}], 1119 repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], 1120 default_int32=800, 1121 oneof_string='y') 1122 self.assertIsInstance(message, unittest_pb2.TestAllTypes) 1123 self.assertEqual(100, message.optional_int32) 1124 self.assertEqual(200, message.optional_fixed32) 1125 self.assertEqual(300.5, message.optional_float) 1126 self.assertEqual(b'x', message.optional_bytes) 1127 self.assertEqual(400, message.optionalgroup.a) 1128 self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) 1129 self.assertEqual(500, message.optional_nested_message.bb) 1130 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1131 message.optional_nested_enum) 1132 self.assertEqual(2, len(message.repeatedgroup)) 1133 self.assertEqual(600, message.repeatedgroup[0].a) 1134 self.assertEqual(700, message.repeatedgroup[1].a) 1135 self.assertEqual(2, len(message.repeated_nested_enum)) 1136 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 1137 message.repeated_nested_enum[0]) 1138 self.assertEqual(unittest_pb2.TestAllTypes.BAR, 1139 message.repeated_nested_enum[1]) 1140 self.assertEqual(800, message.default_int32) 1141 self.assertEqual('y', message.oneof_string) 1142 self.assertFalse(message.HasField('optional_int64')) 1143 self.assertEqual(0, len(message.repeated_float)) 1144 self.assertEqual(42, message.default_int64) 1145 1146 message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') 1147 self.assertEqual(unittest_pb2.TestAllTypes.BAZ, 1148 message.optional_nested_enum) 1149 1150 with self.assertRaises(ValueError): 1151 unittest_pb2.TestAllTypes( 1152 optional_nested_message={'INVALID_NESTED_FIELD': 17}) 1153 1154 with self.assertRaises(TypeError): 1155 unittest_pb2.TestAllTypes( 1156 optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) 1157 1158 with self.assertRaises(ValueError): 1159 unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') 1160 1161 with self.assertRaises(ValueError): 1162 unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') 1163 1164 1165 1166# Class to test proto3-only features/behavior (updated field presence & enums) 1167class Proto3Test(unittest.TestCase): 1168 1169 # Utility method for comparing equality with a map. 1170 def assertMapIterEquals(self, map_iter, dict_value): 1171 # Avoid mutating caller's copy. 1172 dict_value = dict(dict_value) 1173 1174 for k, v in map_iter: 1175 self.assertEqual(v, dict_value[k]) 1176 del dict_value[k] 1177 1178 self.assertEqual({}, dict_value) 1179 1180 def testFieldPresence(self): 1181 message = unittest_proto3_arena_pb2.TestAllTypes() 1182 1183 # We can't test presence of non-repeated, non-submessage fields. 1184 with self.assertRaises(ValueError): 1185 message.HasField('optional_int32') 1186 with self.assertRaises(ValueError): 1187 message.HasField('optional_float') 1188 with self.assertRaises(ValueError): 1189 message.HasField('optional_string') 1190 with self.assertRaises(ValueError): 1191 message.HasField('optional_bool') 1192 1193 # But we can still test presence of submessage fields. 1194 self.assertFalse(message.HasField('optional_nested_message')) 1195 1196 # As with proto2, we can't test presence of fields that don't exist, or 1197 # repeated fields. 1198 with self.assertRaises(ValueError): 1199 message.HasField('field_doesnt_exist') 1200 1201 with self.assertRaises(ValueError): 1202 message.HasField('repeated_int32') 1203 with self.assertRaises(ValueError): 1204 message.HasField('repeated_nested_message') 1205 1206 # Fields should default to their type-specific default. 1207 self.assertEqual(0, message.optional_int32) 1208 self.assertEqual(0, message.optional_float) 1209 self.assertEqual('', message.optional_string) 1210 self.assertEqual(False, message.optional_bool) 1211 self.assertEqual(0, message.optional_nested_message.bb) 1212 1213 # Setting a submessage should still return proper presence information. 1214 message.optional_nested_message.bb = 0 1215 self.assertTrue(message.HasField('optional_nested_message')) 1216 1217 # Set the fields to non-default values. 1218 message.optional_int32 = 5 1219 message.optional_float = 1.1 1220 message.optional_string = 'abc' 1221 message.optional_bool = True 1222 message.optional_nested_message.bb = 15 1223 1224 # Clearing the fields unsets them and resets their value to default. 1225 message.ClearField('optional_int32') 1226 message.ClearField('optional_float') 1227 message.ClearField('optional_string') 1228 message.ClearField('optional_bool') 1229 message.ClearField('optional_nested_message') 1230 1231 self.assertEqual(0, message.optional_int32) 1232 self.assertEqual(0, message.optional_float) 1233 self.assertEqual('', message.optional_string) 1234 self.assertEqual(False, message.optional_bool) 1235 self.assertEqual(0, message.optional_nested_message.bb) 1236 1237 def testAssignUnknownEnum(self): 1238 """Assigning an unknown enum value is allowed and preserves the value.""" 1239 m = unittest_proto3_arena_pb2.TestAllTypes() 1240 1241 m.optional_nested_enum = 1234567 1242 self.assertEqual(1234567, m.optional_nested_enum) 1243 m.repeated_nested_enum.append(22334455) 1244 self.assertEqual(22334455, m.repeated_nested_enum[0]) 1245 # Assignment is a different code path than append for the C++ impl. 1246 m.repeated_nested_enum[0] = 7654321 1247 self.assertEqual(7654321, m.repeated_nested_enum[0]) 1248 serialized = m.SerializeToString() 1249 1250 m2 = unittest_proto3_arena_pb2.TestAllTypes() 1251 m2.ParseFromString(serialized) 1252 self.assertEqual(1234567, m2.optional_nested_enum) 1253 self.assertEqual(7654321, m2.repeated_nested_enum[0]) 1254 1255 # Map isn't really a proto3-only feature. But there is no proto2 equivalent 1256 # of google/protobuf/map_unittest.proto right now, so it's not easy to 1257 # test both with the same test like we do for the other proto2/proto3 tests. 1258 # (google/protobuf/map_protobuf_unittest.proto is very different in the set 1259 # of messages and fields it contains). 1260 def testScalarMapDefaults(self): 1261 msg = map_unittest_pb2.TestMap() 1262 1263 # Scalars start out unset. 1264 self.assertFalse(-123 in msg.map_int32_int32) 1265 self.assertFalse(-2**33 in msg.map_int64_int64) 1266 self.assertFalse(123 in msg.map_uint32_uint32) 1267 self.assertFalse(2**33 in msg.map_uint64_uint64) 1268 self.assertFalse(123 in msg.map_int32_double) 1269 self.assertFalse(False in msg.map_bool_bool) 1270 self.assertFalse('abc' in msg.map_string_string) 1271 self.assertFalse(111 in msg.map_int32_bytes) 1272 self.assertFalse(888 in msg.map_int32_enum) 1273 1274 # Accessing an unset key returns the default. 1275 self.assertEqual(0, msg.map_int32_int32[-123]) 1276 self.assertEqual(0, msg.map_int64_int64[-2**33]) 1277 self.assertEqual(0, msg.map_uint32_uint32[123]) 1278 self.assertEqual(0, msg.map_uint64_uint64[2**33]) 1279 self.assertEqual(0.0, msg.map_int32_double[123]) 1280 self.assertTrue(isinstance(msg.map_int32_double[123], float)) 1281 self.assertEqual(False, msg.map_bool_bool[False]) 1282 self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) 1283 self.assertEqual('', msg.map_string_string['abc']) 1284 self.assertEqual(b'', msg.map_int32_bytes[111]) 1285 self.assertEqual(0, msg.map_int32_enum[888]) 1286 1287 # It also sets the value in the map 1288 self.assertTrue(-123 in msg.map_int32_int32) 1289 self.assertTrue(-2**33 in msg.map_int64_int64) 1290 self.assertTrue(123 in msg.map_uint32_uint32) 1291 self.assertTrue(2**33 in msg.map_uint64_uint64) 1292 self.assertTrue(123 in msg.map_int32_double) 1293 self.assertTrue(False in msg.map_bool_bool) 1294 self.assertTrue('abc' in msg.map_string_string) 1295 self.assertTrue(111 in msg.map_int32_bytes) 1296 self.assertTrue(888 in msg.map_int32_enum) 1297 1298 self.assertIsInstance(msg.map_string_string['abc'], six.text_type) 1299 1300 # Accessing an unset key still throws TypeError if the type of the key 1301 # is incorrect. 1302 with self.assertRaises(TypeError): 1303 msg.map_string_string[123] 1304 1305 with self.assertRaises(TypeError): 1306 123 in msg.map_string_string 1307 1308 def testMapGet(self): 1309 # Need to test that get() properly returns the default, even though the dict 1310 # has defaultdict-like semantics. 1311 msg = map_unittest_pb2.TestMap() 1312 1313 self.assertIsNone(msg.map_int32_int32.get(5)) 1314 self.assertEqual(10, msg.map_int32_int32.get(5, 10)) 1315 self.assertIsNone(msg.map_int32_int32.get(5)) 1316 1317 msg.map_int32_int32[5] = 15 1318 self.assertEqual(15, msg.map_int32_int32.get(5)) 1319 1320 self.assertIsNone(msg.map_int32_foreign_message.get(5)) 1321 self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10)) 1322 1323 submsg = msg.map_int32_foreign_message[5] 1324 self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) 1325 1326 def testScalarMap(self): 1327 msg = map_unittest_pb2.TestMap() 1328 1329 self.assertEqual(0, len(msg.map_int32_int32)) 1330 self.assertFalse(5 in msg.map_int32_int32) 1331 1332 msg.map_int32_int32[-123] = -456 1333 msg.map_int64_int64[-2**33] = -2**34 1334 msg.map_uint32_uint32[123] = 456 1335 msg.map_uint64_uint64[2**33] = 2**34 1336 msg.map_string_string['abc'] = '123' 1337 msg.map_int32_enum[888] = 2 1338 1339 self.assertEqual([], msg.FindInitializationErrors()) 1340 1341 self.assertEqual(1, len(msg.map_string_string)) 1342 1343 # Bad key. 1344 with self.assertRaises(TypeError): 1345 msg.map_string_string[123] = '123' 1346 1347 # Verify that trying to assign a bad key doesn't actually add a member to 1348 # the map. 1349 self.assertEqual(1, len(msg.map_string_string)) 1350 1351 # Bad value. 1352 with self.assertRaises(TypeError): 1353 msg.map_string_string['123'] = 123 1354 1355 serialized = msg.SerializeToString() 1356 msg2 = map_unittest_pb2.TestMap() 1357 msg2.ParseFromString(serialized) 1358 1359 # Bad key. 1360 with self.assertRaises(TypeError): 1361 msg2.map_string_string[123] = '123' 1362 1363 # Bad value. 1364 with self.assertRaises(TypeError): 1365 msg2.map_string_string['123'] = 123 1366 1367 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1368 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1369 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1370 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1371 self.assertEqual('123', msg2.map_string_string['abc']) 1372 self.assertEqual(2, msg2.map_int32_enum[888]) 1373 1374 def testStringUnicodeConversionInMap(self): 1375 msg = map_unittest_pb2.TestMap() 1376 1377 unicode_obj = u'\u1234' 1378 bytes_obj = unicode_obj.encode('utf8') 1379 1380 msg.map_string_string[bytes_obj] = bytes_obj 1381 1382 (key, value) = list(msg.map_string_string.items())[0] 1383 1384 self.assertEqual(key, unicode_obj) 1385 self.assertEqual(value, unicode_obj) 1386 1387 self.assertIsInstance(key, six.text_type) 1388 self.assertIsInstance(value, six.text_type) 1389 1390 def testMessageMap(self): 1391 msg = map_unittest_pb2.TestMap() 1392 1393 self.assertEqual(0, len(msg.map_int32_foreign_message)) 1394 self.assertFalse(5 in msg.map_int32_foreign_message) 1395 1396 msg.map_int32_foreign_message[123] 1397 # get_or_create() is an alias for getitem. 1398 msg.map_int32_foreign_message.get_or_create(-456) 1399 1400 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1401 self.assertIn(123, msg.map_int32_foreign_message) 1402 self.assertIn(-456, msg.map_int32_foreign_message) 1403 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1404 1405 # Bad key. 1406 with self.assertRaises(TypeError): 1407 msg.map_int32_foreign_message['123'] 1408 1409 # Can't assign directly to submessage. 1410 with self.assertRaises(ValueError): 1411 msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] 1412 1413 # Verify that trying to assign a bad key doesn't actually add a member to 1414 # the map. 1415 self.assertEqual(2, len(msg.map_int32_foreign_message)) 1416 1417 serialized = msg.SerializeToString() 1418 msg2 = map_unittest_pb2.TestMap() 1419 msg2.ParseFromString(serialized) 1420 1421 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1422 self.assertIn(123, msg2.map_int32_foreign_message) 1423 self.assertIn(-456, msg2.map_int32_foreign_message) 1424 self.assertEqual(2, len(msg2.map_int32_foreign_message)) 1425 1426 def testMergeFrom(self): 1427 msg = map_unittest_pb2.TestMap() 1428 msg.map_int32_int32[12] = 34 1429 msg.map_int32_int32[56] = 78 1430 msg.map_int64_int64[22] = 33 1431 msg.map_int32_foreign_message[111].c = 5 1432 msg.map_int32_foreign_message[222].c = 10 1433 1434 msg2 = map_unittest_pb2.TestMap() 1435 msg2.map_int32_int32[12] = 55 1436 msg2.map_int64_int64[88] = 99 1437 msg2.map_int32_foreign_message[222].c = 15 1438 msg2.map_int32_foreign_message[222].d = 20 1439 old_map_value = msg2.map_int32_foreign_message[222] 1440 1441 msg2.MergeFrom(msg) 1442 1443 self.assertEqual(34, msg2.map_int32_int32[12]) 1444 self.assertEqual(78, msg2.map_int32_int32[56]) 1445 self.assertEqual(33, msg2.map_int64_int64[22]) 1446 self.assertEqual(99, msg2.map_int64_int64[88]) 1447 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 1448 self.assertEqual(10, msg2.map_int32_foreign_message[222].c) 1449 self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) 1450 self.assertEqual(15, old_map_value.c) 1451 1452 # Verify that there is only one entry per key, even though the MergeFrom 1453 # may have internally created multiple entries for a single key in the 1454 # list representation. 1455 as_dict = {} 1456 for key in msg2.map_int32_foreign_message: 1457 self.assertFalse(key in as_dict) 1458 as_dict[key] = msg2.map_int32_foreign_message[key].c 1459 1460 self.assertEqual({111: 5, 222: 10}, as_dict) 1461 1462 # Special case: test that delete of item really removes the item, even if 1463 # there might have physically been duplicate keys due to the previous merge. 1464 # This is only a special case for the C++ implementation which stores the 1465 # map as an array. 1466 del msg2.map_int32_int32[12] 1467 self.assertFalse(12 in msg2.map_int32_int32) 1468 1469 del msg2.map_int32_foreign_message[222] 1470 self.assertFalse(222 in msg2.map_int32_foreign_message) 1471 1472 def testMergeFromBadType(self): 1473 msg = map_unittest_pb2.TestMap() 1474 with self.assertRaisesRegexp( 1475 TypeError, 1476 r'Parameter to MergeFrom\(\) must be instance of same class: expected ' 1477 r'.*TestMap got int\.'): 1478 msg.MergeFrom(1) 1479 1480 def testCopyFromBadType(self): 1481 msg = map_unittest_pb2.TestMap() 1482 with self.assertRaisesRegexp( 1483 TypeError, 1484 r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' 1485 r'expected .*TestMap got int\.'): 1486 msg.CopyFrom(1) 1487 1488 def testIntegerMapWithLongs(self): 1489 msg = map_unittest_pb2.TestMap() 1490 msg.map_int32_int32[long(-123)] = long(-456) 1491 msg.map_int64_int64[long(-2**33)] = long(-2**34) 1492 msg.map_uint32_uint32[long(123)] = long(456) 1493 msg.map_uint64_uint64[long(2**33)] = long(2**34) 1494 1495 serialized = msg.SerializeToString() 1496 msg2 = map_unittest_pb2.TestMap() 1497 msg2.ParseFromString(serialized) 1498 1499 self.assertEqual(-456, msg2.map_int32_int32[-123]) 1500 self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) 1501 self.assertEqual(456, msg2.map_uint32_uint32[123]) 1502 self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) 1503 1504 def testMapAssignmentCausesPresence(self): 1505 msg = map_unittest_pb2.TestMapSubmessage() 1506 msg.test_map.map_int32_int32[123] = 456 1507 1508 serialized = msg.SerializeToString() 1509 msg2 = map_unittest_pb2.TestMapSubmessage() 1510 msg2.ParseFromString(serialized) 1511 1512 self.assertEqual(msg, msg2) 1513 1514 # Now test that various mutations of the map properly invalidate the 1515 # cached size of the submessage. 1516 msg.test_map.map_int32_int32[888] = 999 1517 serialized = msg.SerializeToString() 1518 msg2.ParseFromString(serialized) 1519 self.assertEqual(msg, msg2) 1520 1521 msg.test_map.map_int32_int32.clear() 1522 serialized = msg.SerializeToString() 1523 msg2.ParseFromString(serialized) 1524 self.assertEqual(msg, msg2) 1525 1526 def testMapAssignmentCausesPresenceForSubmessages(self): 1527 msg = map_unittest_pb2.TestMapSubmessage() 1528 msg.test_map.map_int32_foreign_message[123].c = 5 1529 1530 serialized = msg.SerializeToString() 1531 msg2 = map_unittest_pb2.TestMapSubmessage() 1532 msg2.ParseFromString(serialized) 1533 1534 self.assertEqual(msg, msg2) 1535 1536 # Now test that various mutations of the map properly invalidate the 1537 # cached size of the submessage. 1538 msg.test_map.map_int32_foreign_message[888].c = 7 1539 serialized = msg.SerializeToString() 1540 msg2.ParseFromString(serialized) 1541 self.assertEqual(msg, msg2) 1542 1543 msg.test_map.map_int32_foreign_message[888].MergeFrom( 1544 msg.test_map.map_int32_foreign_message[123]) 1545 serialized = msg.SerializeToString() 1546 msg2.ParseFromString(serialized) 1547 self.assertEqual(msg, msg2) 1548 1549 msg.test_map.map_int32_foreign_message.clear() 1550 serialized = msg.SerializeToString() 1551 msg2.ParseFromString(serialized) 1552 self.assertEqual(msg, msg2) 1553 1554 def testModifyMapWhileIterating(self): 1555 msg = map_unittest_pb2.TestMap() 1556 1557 string_string_iter = iter(msg.map_string_string) 1558 int32_foreign_iter = iter(msg.map_int32_foreign_message) 1559 1560 msg.map_string_string['abc'] = '123' 1561 msg.map_int32_foreign_message[5].c = 5 1562 1563 with self.assertRaises(RuntimeError): 1564 for key in string_string_iter: 1565 pass 1566 1567 with self.assertRaises(RuntimeError): 1568 for key in int32_foreign_iter: 1569 pass 1570 1571 def testSubmessageMap(self): 1572 msg = map_unittest_pb2.TestMap() 1573 1574 submsg = msg.map_int32_foreign_message[111] 1575 self.assertIs(submsg, msg.map_int32_foreign_message[111]) 1576 self.assertIsInstance(submsg, unittest_pb2.ForeignMessage) 1577 1578 submsg.c = 5 1579 1580 serialized = msg.SerializeToString() 1581 msg2 = map_unittest_pb2.TestMap() 1582 msg2.ParseFromString(serialized) 1583 1584 self.assertEqual(5, msg2.map_int32_foreign_message[111].c) 1585 1586 # Doesn't allow direct submessage assignment. 1587 with self.assertRaises(ValueError): 1588 msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() 1589 1590 def testMapIteration(self): 1591 msg = map_unittest_pb2.TestMap() 1592 1593 for k, v in msg.map_int32_int32.items(): 1594 # Should not be reached. 1595 self.assertTrue(False) 1596 1597 msg.map_int32_int32[2] = 4 1598 msg.map_int32_int32[3] = 6 1599 msg.map_int32_int32[4] = 8 1600 self.assertEqual(3, len(msg.map_int32_int32)) 1601 1602 matching_dict = {2: 4, 3: 6, 4: 8} 1603 self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) 1604 1605 def testMapItems(self): 1606 # Map items used to have strange behaviors when use c extension. Because 1607 # [] may reorder the map and invalidate any exsting iterators. 1608 # TODO(jieluo): Check if [] reordering the map is a bug or intended 1609 # behavior. 1610 msg = map_unittest_pb2.TestMap() 1611 msg.map_string_string['local_init_op'] = '' 1612 msg.map_string_string['trainable_variables'] = '' 1613 msg.map_string_string['variables'] = '' 1614 msg.map_string_string['init_op'] = '' 1615 msg.map_string_string['summaries'] = '' 1616 items1 = msg.map_string_string.items() 1617 items2 = msg.map_string_string.items() 1618 self.assertEqual(items1, items2) 1619 1620 def testMapIterationClearMessage(self): 1621 # Iterator needs to work even if message and map are deleted. 1622 msg = map_unittest_pb2.TestMap() 1623 1624 msg.map_int32_int32[2] = 4 1625 msg.map_int32_int32[3] = 6 1626 msg.map_int32_int32[4] = 8 1627 1628 it = msg.map_int32_int32.items() 1629 del msg 1630 1631 matching_dict = {2: 4, 3: 6, 4: 8} 1632 self.assertMapIterEquals(it, matching_dict) 1633 1634 def testMapConstruction(self): 1635 msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) 1636 self.assertEqual(2, msg.map_int32_int32[1]) 1637 self.assertEqual(4, msg.map_int32_int32[3]) 1638 1639 msg = map_unittest_pb2.TestMap( 1640 map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) 1641 self.assertEqual(5, msg.map_int32_foreign_message[3].c) 1642 1643 def testMapValidAfterFieldCleared(self): 1644 # Map needs to work even if field is cleared. 1645 # For the C++ implementation this tests the correctness of 1646 # ScalarMapContainer::Release() 1647 msg = map_unittest_pb2.TestMap() 1648 int32_map = msg.map_int32_int32 1649 1650 int32_map[2] = 4 1651 int32_map[3] = 6 1652 int32_map[4] = 8 1653 1654 msg.ClearField('map_int32_int32') 1655 self.assertEqual(b'', msg.SerializeToString()) 1656 matching_dict = {2: 4, 3: 6, 4: 8} 1657 self.assertMapIterEquals(int32_map.items(), matching_dict) 1658 1659 def testMessageMapValidAfterFieldCleared(self): 1660 # Map needs to work even if field is cleared. 1661 # For the C++ implementation this tests the correctness of 1662 # ScalarMapContainer::Release() 1663 msg = map_unittest_pb2.TestMap() 1664 int32_foreign_message = msg.map_int32_foreign_message 1665 1666 int32_foreign_message[2].c = 5 1667 1668 msg.ClearField('map_int32_foreign_message') 1669 self.assertEqual(b'', msg.SerializeToString()) 1670 self.assertTrue(2 in int32_foreign_message.keys()) 1671 1672 def testMapIterInvalidatedByClearField(self): 1673 # Map iterator is invalidated when field is cleared. 1674 # But this case does need to not crash the interpreter. 1675 # For the C++ implementation this tests the correctness of 1676 # ScalarMapContainer::Release() 1677 msg = map_unittest_pb2.TestMap() 1678 1679 it = iter(msg.map_int32_int32) 1680 1681 msg.ClearField('map_int32_int32') 1682 with self.assertRaises(RuntimeError): 1683 for _ in it: 1684 pass 1685 1686 it = iter(msg.map_int32_foreign_message) 1687 msg.ClearField('map_int32_foreign_message') 1688 with self.assertRaises(RuntimeError): 1689 for _ in it: 1690 pass 1691 1692 def testMapDelete(self): 1693 msg = map_unittest_pb2.TestMap() 1694 1695 self.assertEqual(0, len(msg.map_int32_int32)) 1696 1697 msg.map_int32_int32[4] = 6 1698 self.assertEqual(1, len(msg.map_int32_int32)) 1699 1700 with self.assertRaises(KeyError): 1701 del msg.map_int32_int32[88] 1702 1703 del msg.map_int32_int32[4] 1704 self.assertEqual(0, len(msg.map_int32_int32)) 1705 1706 def testMapsAreMapping(self): 1707 msg = map_unittest_pb2.TestMap() 1708 self.assertIsInstance(msg.map_int32_int32, collections.Mapping) 1709 self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) 1710 self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) 1711 self.assertIsInstance(msg.map_int32_foreign_message, 1712 collections.MutableMapping) 1713 1714 def testMapFindInitializationErrorsSmokeTest(self): 1715 msg = map_unittest_pb2.TestMap() 1716 msg.map_string_string['abc'] = '123' 1717 msg.map_int32_int32[35] = 64 1718 msg.map_string_foreign_message['foo'].c = 5 1719 self.assertEqual(0, len(msg.FindInitializationErrors())) 1720 1721 1722 1723class ValidTypeNamesTest(unittest.TestCase): 1724 1725 def assertImportFromName(self, msg, base_name): 1726 # Parse <type 'module.class_name'> to extra 'some.name' as a string. 1727 tp_name = str(type(msg)).split("'")[1] 1728 valid_names = ('Repeated%sContainer' % base_name, 1729 'Repeated%sFieldContainer' % base_name) 1730 self.assertTrue(any(tp_name.endswith(v) for v in valid_names), 1731 '%r does end with any of %r' % (tp_name, valid_names)) 1732 1733 parts = tp_name.split('.') 1734 class_name = parts[-1] 1735 module_name = '.'.join(parts[:-1]) 1736 __import__(module_name, fromlist=[class_name]) 1737 1738 def testTypeNamesCanBeImported(self): 1739 # If import doesn't work, pickling won't work either. 1740 pb = unittest_pb2.TestAllTypes() 1741 self.assertImportFromName(pb.repeated_int32, 'Scalar') 1742 self.assertImportFromName(pb.repeated_nested_message, 'Composite') 1743 1744class PackedFieldTest(unittest.TestCase): 1745 1746 def setMessage(self, message): 1747 message.repeated_int32.append(1) 1748 message.repeated_int64.append(1) 1749 message.repeated_uint32.append(1) 1750 message.repeated_uint64.append(1) 1751 message.repeated_sint32.append(1) 1752 message.repeated_sint64.append(1) 1753 message.repeated_fixed32.append(1) 1754 message.repeated_fixed64.append(1) 1755 message.repeated_sfixed32.append(1) 1756 message.repeated_sfixed64.append(1) 1757 message.repeated_float.append(1.0) 1758 message.repeated_double.append(1.0) 1759 message.repeated_bool.append(True) 1760 message.repeated_nested_enum.append(1) 1761 1762 def testPackedFields(self): 1763 message = packed_field_test_pb2.TestPackedTypes() 1764 self.setMessage(message) 1765 golden_data = (b'\x0A\x01\x01' 1766 b'\x12\x01\x01' 1767 b'\x1A\x01\x01' 1768 b'\x22\x01\x01' 1769 b'\x2A\x01\x02' 1770 b'\x32\x01\x02' 1771 b'\x3A\x04\x01\x00\x00\x00' 1772 b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' 1773 b'\x4A\x04\x01\x00\x00\x00' 1774 b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' 1775 b'\x5A\x04\x00\x00\x80\x3f' 1776 b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' 1777 b'\x6A\x01\x01' 1778 b'\x72\x01\x01') 1779 self.assertEqual(golden_data, message.SerializeToString()) 1780 1781 def testUnpackedFields(self): 1782 message = packed_field_test_pb2.TestUnpackedTypes() 1783 self.setMessage(message) 1784 golden_data = (b'\x08\x01' 1785 b'\x10\x01' 1786 b'\x18\x01' 1787 b'\x20\x01' 1788 b'\x28\x02' 1789 b'\x30\x02' 1790 b'\x3D\x01\x00\x00\x00' 1791 b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' 1792 b'\x4D\x01\x00\x00\x00' 1793 b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' 1794 b'\x5D\x00\x00\x80\x3f' 1795 b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' 1796 b'\x68\x01' 1797 b'\x70\x01') 1798 self.assertEqual(golden_data, message.SerializeToString()) 1799 1800 1801@unittest.skipIf(api_implementation.Type() != 'cpp', 1802 'explicit tests of the C++ implementation') 1803class OversizeProtosTest(unittest.TestCase): 1804 1805 def setUp(self): 1806 self.file_desc = """ 1807 name: "f/f.msg2" 1808 package: "f" 1809 message_type { 1810 name: "msg1" 1811 field { 1812 name: "payload" 1813 number: 1 1814 label: LABEL_OPTIONAL 1815 type: TYPE_STRING 1816 } 1817 } 1818 message_type { 1819 name: "msg2" 1820 field { 1821 name: "field" 1822 number: 1 1823 label: LABEL_OPTIONAL 1824 type: TYPE_MESSAGE 1825 type_name: "msg1" 1826 } 1827 } 1828 """ 1829 pool = descriptor_pool.DescriptorPool() 1830 desc = descriptor_pb2.FileDescriptorProto() 1831 text_format.Parse(self.file_desc, desc) 1832 pool.Add(desc) 1833 self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( 1834 pool.FindMessageTypeByName('f.msg2')) 1835 self.p = self.proto_cls() 1836 self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) 1837 self.p_serialized = self.p.SerializeToString() 1838 1839 def testAssertOversizeProto(self): 1840 from google.protobuf.pyext._message import SetAllowOversizeProtos 1841 SetAllowOversizeProtos(False) 1842 q = self.proto_cls() 1843 try: 1844 q.ParseFromString(self.p_serialized) 1845 except message.DecodeError as e: 1846 self.assertEqual(str(e), 'Error parsing message') 1847 1848 def testSucceedOversizeProto(self): 1849 from google.protobuf.pyext._message import SetAllowOversizeProtos 1850 SetAllowOversizeProtos(True) 1851 q = self.proto_cls() 1852 q.ParseFromString(self.p_serialized) 1853 self.assertEqual(self.p.field.payload, q.field.payload) 1854 1855if __name__ == '__main__': 1856 unittest.main() 1857