• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Tests python protocol buffers against the golden message.
35
36Note that the golden messages exercise every known field type, thus this
37test ends up exercising and verifying nearly all of the parsing and
38serialization code in the whole library.
39
40TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
41sense to call this a test of the "message" module, which only declares an
42abstract interface.
43"""
44
45__author__ = 'gps@google.com (Gregory P. Smith)'
46
47
48import copy
49import math
50import operator
51import pickle
52import pydoc
53import six
54import sys
55import warnings
56
57try:
58  # Since python 3
59  import collections.abc as collections_abc
60except ImportError:
61  # Won't work after python 3.8
62  import collections as collections_abc
63
64try:
65  import unittest2 as unittest  # PY26
66except ImportError:
67  import unittest
try:
  cmp  # Python 2 still provides the builtin; nothing to do.
except NameError:
  # Python 3 removed the builtin cmp(); recreate its three-way result.
  def cmp(x, y):
    """Returns -1, 0 or 1 as x is less than, equal to, or greater than y."""
    return (x > y) - (x < y)
72
73from google.protobuf import map_proto2_unittest_pb2
74from google.protobuf import map_unittest_pb2
75from google.protobuf import unittest_pb2
76from google.protobuf import unittest_proto3_arena_pb2
77from google.protobuf import descriptor_pb2
78from google.protobuf import descriptor_pool
79from google.protobuf import message_factory
80from google.protobuf import text_format
81from google.protobuf.internal import api_implementation
82from google.protobuf.internal import encoder
83from google.protobuf.internal import more_extensions_pb2
84from google.protobuf.internal import packed_field_test_pb2
85from google.protobuf.internal import test_util
86from google.protobuf.internal import test_proto3_optional_pb2
87from google.protobuf.internal import testing_refleaks
88from google.protobuf import message
89from google.protobuf.internal import _parameterized
90
UCS2_MAXUNICODE = 0xFFFF  # highest code point a single UCS-2 unit can hold
if six.PY3:
  # Python 3 merged `long` into `int`; keep the old name usable in this file.
  long = int
94
95
96# Python pre-2.6 does not have isinf() or isnan() functions, so we have
97# to provide our own.
def isnan(val):
  """Returns True if `val` is a NaN.

  NaN is the only floating-point value that compares unequal to itself.
  """
  unequal_to_itself = val != val
  return unequal_to_itself


def isinf(val):
  """Returns True if `val` is positive or negative infinity.

  Infinity multiplied by zero yields NaN, while any finite value multiplied by
  zero yields zero. NaN is rejected up front because NaN * 0 is NaN as well.
  """
  if isnan(val):
    return False
  return isnan(val * 0)


def IsPosInf(val):
  """Returns True if `val` is positive infinity."""
  if not isinf(val):
    return False
  return val > 0


def IsNegInf(val):
  """Returns True if `val` is negative infinity."""
  if not isinf(val):
    return False
  return val < 0
108
109
# Turn every DeprecationWarning into a raised exception for this test module.
warnings.simplefilter(action='error', category=DeprecationWarning)
111
112
113@_parameterized.named_parameters(
114    ('_proto2', unittest_pb2),
115    ('_proto3', unittest_proto3_arena_pb2))
116@testing_refleaks.TestCase
117class MessageTest(unittest.TestCase):
118
119  def testBadUtf8String(self, message_module):
120    if api_implementation.Type() != 'python':
121      self.skipTest("Skipping testBadUtf8String, currently only the python "
122                    "api implementation raises UnicodeDecodeError when a "
123                    "string field contains bad utf-8.")
124    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
125    with self.assertRaises(UnicodeDecodeError) as context:
126      message_module.TestAllTypes.FromString(bad_utf8_data)
127    self.assertIn('TestAllTypes.optional_string', str(context.exception))
128
129  def testGoldenMessage(self, message_module):
130    # Proto3 doesn't have the "default_foo" members or foreign enums,
131    # and doesn't preserve unknown fields, so for proto3 we use a golden
132    # message that doesn't have these fields set.
133    if message_module is unittest_pb2:
134      golden_data = test_util.GoldenFileData(
135          'golden_message_oneof_implemented')
136    else:
137      golden_data = test_util.GoldenFileData('golden_message_proto3')
138
139    golden_message = message_module.TestAllTypes()
140    golden_message.ParseFromString(golden_data)
141    if message_module is unittest_pb2:
142      test_util.ExpectAllFieldsSet(self, golden_message)
143    self.assertEqual(golden_data, golden_message.SerializeToString())
144    golden_copy = copy.deepcopy(golden_message)
145    self.assertEqual(golden_data, golden_copy.SerializeToString())
146
147  def testGoldenPackedMessage(self, message_module):
148    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
149    golden_message = message_module.TestPackedTypes()
150    parsed_bytes = golden_message.ParseFromString(golden_data)
151    all_set = message_module.TestPackedTypes()
152    test_util.SetAllPackedFields(all_set)
153    self.assertEqual(parsed_bytes, len(golden_data))
154    self.assertEqual(all_set, golden_message)
155    self.assertEqual(golden_data, all_set.SerializeToString())
156    golden_copy = copy.deepcopy(golden_message)
157    self.assertEqual(golden_data, golden_copy.SerializeToString())
158
159  def testParseErrors(self, message_module):
160    msg = message_module.TestAllTypes()
161    self.assertRaises(TypeError, msg.FromString, 0)
162    self.assertRaises(Exception, msg.FromString, '0')
163    # TODO(jieluo): Fix cpp extension to raise error instead of warning.
164    # b/27494216
165    end_tag = encoder.TagBytes(1, 4)
166    if api_implementation.Type() == 'python':
167      with self.assertRaises(message.DecodeError) as context:
168        msg.FromString(end_tag)
169      self.assertEqual('Unexpected end-group tag.', str(context.exception))
170
171    # Field number 0 is illegal.
172    self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')
173
174  def testDeterminismParameters(self, message_module):
175    # This message is always deterministically serialized, even if determinism
176    # is disabled, so we can use it to verify that all the determinism
177    # parameters work correctly.
178    golden_data = (b'\xe2\x02\nOne string'
179                   b'\xe2\x02\nTwo string'
180                   b'\xe2\x02\nRed string'
181                   b'\xe2\x02\x0bBlue string')
182    golden_message = message_module.TestAllTypes()
183    golden_message.repeated_string.extend([
184        'One string',
185        'Two string',
186        'Red string',
187        'Blue string',
188    ])
189    self.assertEqual(golden_data,
190                     golden_message.SerializeToString(deterministic=None))
191    self.assertEqual(golden_data,
192                     golden_message.SerializeToString(deterministic=False))
193    self.assertEqual(golden_data,
194                     golden_message.SerializeToString(deterministic=True))
195
196    class BadArgError(Exception):
197      pass
198
199    class BadArg(object):
200
201      def __nonzero__(self):
202        raise BadArgError()
203
204      def __bool__(self):
205        raise BadArgError()
206
207    with self.assertRaises(BadArgError):
208      golden_message.SerializeToString(deterministic=BadArg())
209
210  def testPickleSupport(self, message_module):
211    golden_data = test_util.GoldenFileData('golden_message')
212    golden_message = message_module.TestAllTypes()
213    golden_message.ParseFromString(golden_data)
214    pickled_message = pickle.dumps(golden_message)
215
216    unpickled_message = pickle.loads(pickled_message)
217    self.assertEqual(unpickled_message, golden_message)
218
219  def testPickleNestedMessage(self, message_module):
220    golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1)
221    pickled_message = pickle.dumps(golden_message)
222    unpickled_message = pickle.loads(pickled_message)
223    self.assertEqual(unpickled_message, golden_message)
224
225  def testPickleNestedNestedMessage(self, message_module):
226    cls = message_module.TestPickleNestedMessage.NestedMessage
227    golden_message = cls.NestedNestedMessage(cc=1)
228    pickled_message = pickle.dumps(golden_message)
229    unpickled_message = pickle.loads(pickled_message)
230    self.assertEqual(unpickled_message, golden_message)
231
232  def testPositiveInfinity(self, message_module):
233    if message_module is unittest_pb2:
234      golden_data = (b'\x5D\x00\x00\x80\x7F'
235                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
236                     b'\xCD\x02\x00\x00\x80\x7F'
237                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
238    else:
239      golden_data = (b'\x5D\x00\x00\x80\x7F'
240                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
241                     b'\xCA\x02\x04\x00\x00\x80\x7F'
242                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
243
244    golden_message = message_module.TestAllTypes()
245    golden_message.ParseFromString(golden_data)
246    self.assertTrue(IsPosInf(golden_message.optional_float))
247    self.assertTrue(IsPosInf(golden_message.optional_double))
248    self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
249    self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
250    self.assertEqual(golden_data, golden_message.SerializeToString())
251
252  def testNegativeInfinity(self, message_module):
253    if message_module is unittest_pb2:
254      golden_data = (b'\x5D\x00\x00\x80\xFF'
255                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
256                     b'\xCD\x02\x00\x00\x80\xFF'
257                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
258    else:
259      golden_data = (b'\x5D\x00\x00\x80\xFF'
260                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
261                     b'\xCA\x02\x04\x00\x00\x80\xFF'
262                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
263
264    golden_message = message_module.TestAllTypes()
265    golden_message.ParseFromString(golden_data)
266    self.assertTrue(IsNegInf(golden_message.optional_float))
267    self.assertTrue(IsNegInf(golden_message.optional_double))
268    self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
269    self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
270    self.assertEqual(golden_data, golden_message.SerializeToString())
271
272  def testNotANumber(self, message_module):
273    golden_data = (b'\x5D\x00\x00\xC0\x7F'
274                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
275                   b'\xCD\x02\x00\x00\xC0\x7F'
276                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
277    golden_message = message_module.TestAllTypes()
278    golden_message.ParseFromString(golden_data)
279    self.assertTrue(isnan(golden_message.optional_float))
280    self.assertTrue(isnan(golden_message.optional_double))
281    self.assertTrue(isnan(golden_message.repeated_float[0]))
282    self.assertTrue(isnan(golden_message.repeated_double[0]))
283
284    # The protocol buffer may serialize to any one of multiple different
285    # representations of a NaN.  Rather than verify a specific representation,
286    # verify the serialized string can be converted into a correctly
287    # behaving protocol buffer.
288    serialized = golden_message.SerializeToString()
289    message = message_module.TestAllTypes()
290    message.ParseFromString(serialized)
291    self.assertTrue(isnan(message.optional_float))
292    self.assertTrue(isnan(message.optional_double))
293    self.assertTrue(isnan(message.repeated_float[0]))
294    self.assertTrue(isnan(message.repeated_double[0]))
295
296  def testPositiveInfinityPacked(self, message_module):
297    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
298                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
299    golden_message = message_module.TestPackedTypes()
300    golden_message.ParseFromString(golden_data)
301    self.assertTrue(IsPosInf(golden_message.packed_float[0]))
302    self.assertTrue(IsPosInf(golden_message.packed_double[0]))
303    self.assertEqual(golden_data, golden_message.SerializeToString())
304
305  def testNegativeInfinityPacked(self, message_module):
306    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
307                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
308    golden_message = message_module.TestPackedTypes()
309    golden_message.ParseFromString(golden_data)
310    self.assertTrue(IsNegInf(golden_message.packed_float[0]))
311    self.assertTrue(IsNegInf(golden_message.packed_double[0]))
312    self.assertEqual(golden_data, golden_message.SerializeToString())
313
314  def testNotANumberPacked(self, message_module):
315    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
316                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
317    golden_message = message_module.TestPackedTypes()
318    golden_message.ParseFromString(golden_data)
319    self.assertTrue(isnan(golden_message.packed_float[0]))
320    self.assertTrue(isnan(golden_message.packed_double[0]))
321
322    serialized = golden_message.SerializeToString()
323    message = message_module.TestPackedTypes()
324    message.ParseFromString(serialized)
325    self.assertTrue(isnan(message.packed_float[0]))
326    self.assertTrue(isnan(message.packed_double[0]))
327
328  def testExtremeFloatValues(self, message_module):
329    message = message_module.TestAllTypes()
330
331    # Most positive exponent, no significand bits set.
332    kMostPosExponentNoSigBits = math.pow(2, 127)
333    message.optional_float = kMostPosExponentNoSigBits
334    message.ParseFromString(message.SerializeToString())
335    self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
336
337    # Most positive exponent, one significand bit set.
338    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
339    message.optional_float = kMostPosExponentOneSigBit
340    message.ParseFromString(message.SerializeToString())
341    self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
342
343    # Repeat last two cases with values of same magnitude, but negative.
344    message.optional_float = -kMostPosExponentNoSigBits
345    message.ParseFromString(message.SerializeToString())
346    self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
347
348    message.optional_float = -kMostPosExponentOneSigBit
349    message.ParseFromString(message.SerializeToString())
350    self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
351
352    # Most negative exponent, no significand bits set.
353    kMostNegExponentNoSigBits = math.pow(2, -127)
354    message.optional_float = kMostNegExponentNoSigBits
355    message.ParseFromString(message.SerializeToString())
356    self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
357
358    # Most negative exponent, one significand bit set.
359    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
360    message.optional_float = kMostNegExponentOneSigBit
361    message.ParseFromString(message.SerializeToString())
362    self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
363
364    # Repeat last two cases with values of the same magnitude, but negative.
365    message.optional_float = -kMostNegExponentNoSigBits
366    message.ParseFromString(message.SerializeToString())
367    self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
368
369    message.optional_float = -kMostNegExponentOneSigBit
370    message.ParseFromString(message.SerializeToString())
371    self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
372
373    # Max 4 bytes float value
374    max_float = float.fromhex('0x1.fffffep+127')
375    message.optional_float = max_float
376    self.assertAlmostEqual(message.optional_float, max_float)
377    serialized_data = message.SerializeToString()
378    message.ParseFromString(serialized_data)
379    self.assertAlmostEqual(message.optional_float, max_float)
380
381    # Test set double to float field.
382    message.optional_float = 3.4028235e+39
383    self.assertEqual(message.optional_float, float('inf'))
384    serialized_data = message.SerializeToString()
385    message.ParseFromString(serialized_data)
386    self.assertEqual(message.optional_float, float('inf'))
387
388    message.optional_float = -3.4028235e+39
389    self.assertEqual(message.optional_float, float('-inf'))
390
391    message.optional_float = 1.4028235e-39
392    self.assertAlmostEqual(message.optional_float, 1.4028235e-39)
393
394  def testExtremeDoubleValues(self, message_module):
395    message = message_module.TestAllTypes()
396
397    # Most positive exponent, no significand bits set.
398    kMostPosExponentNoSigBits = math.pow(2, 1023)
399    message.optional_double = kMostPosExponentNoSigBits
400    message.ParseFromString(message.SerializeToString())
401    self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
402
403    # Most positive exponent, one significand bit set.
404    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
405    message.optional_double = kMostPosExponentOneSigBit
406    message.ParseFromString(message.SerializeToString())
407    self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
408
409    # Repeat last two cases with values of same magnitude, but negative.
410    message.optional_double = -kMostPosExponentNoSigBits
411    message.ParseFromString(message.SerializeToString())
412    self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
413
414    message.optional_double = -kMostPosExponentOneSigBit
415    message.ParseFromString(message.SerializeToString())
416    self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
417
418    # Most negative exponent, no significand bits set.
419    kMostNegExponentNoSigBits = math.pow(2, -1023)
420    message.optional_double = kMostNegExponentNoSigBits
421    message.ParseFromString(message.SerializeToString())
422    self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
423
424    # Most negative exponent, one significand bit set.
425    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
426    message.optional_double = kMostNegExponentOneSigBit
427    message.ParseFromString(message.SerializeToString())
428    self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
429
430    # Repeat last two cases with values of the same magnitude, but negative.
431    message.optional_double = -kMostNegExponentNoSigBits
432    message.ParseFromString(message.SerializeToString())
433    self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
434
435    message.optional_double = -kMostNegExponentOneSigBit
436    message.ParseFromString(message.SerializeToString())
437    self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
438
439  def testFloatPrinting(self, message_module):
440    message = message_module.TestAllTypes()
441    message.optional_float = 2.0
442    self.assertEqual(str(message), 'optional_float: 2.0\n')
443
444  def testHighPrecisionFloatPrinting(self, message_module):
445    msg = message_module.TestAllTypes()
446    msg.optional_float = 0.12345678912345678
447    old_float = msg.optional_float
448    msg.ParseFromString(msg.SerializeToString())
449    self.assertEqual(old_float, msg.optional_float)
450
451  def testHighPrecisionDoublePrinting(self, message_module):
452    msg = message_module.TestAllTypes()
453    msg.optional_double = 0.12345678912345678
454    if sys.version_info >= (3,):
455      self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')
456    else:
457      self.assertEqual(str(msg), 'optional_double: 0.123456789123\n')
458
459  def testUnknownFieldPrinting(self, message_module):
460    populated = message_module.TestAllTypes()
461    test_util.SetAllNonLazyFields(populated)
462    empty = message_module.TestEmptyMessage()
463    empty.ParseFromString(populated.SerializeToString())
464    self.assertEqual(str(empty), '')
465
466  def testAppendRepeatedCompositeField(self, message_module):
467    msg = message_module.TestAllTypes()
468    msg.repeated_nested_message.append(
469        message_module.TestAllTypes.NestedMessage(bb=1))
470    nested = message_module.TestAllTypes.NestedMessage(bb=2)
471    msg.repeated_nested_message.append(nested)
472    try:
473      msg.repeated_nested_message.append(1)
474    except TypeError:
475      pass
476    self.assertEqual(2, len(msg.repeated_nested_message))
477    self.assertEqual([1, 2],
478                     [m.bb for m in msg.repeated_nested_message])
479
480  def testInsertRepeatedCompositeField(self, message_module):
481    msg = message_module.TestAllTypes()
482    msg.repeated_nested_message.insert(
483        -1, message_module.TestAllTypes.NestedMessage(bb=1))
484    sub_msg = msg.repeated_nested_message[0]
485    msg.repeated_nested_message.insert(
486        0, message_module.TestAllTypes.NestedMessage(bb=2))
487    msg.repeated_nested_message.insert(
488        99, message_module.TestAllTypes.NestedMessage(bb=3))
489    msg.repeated_nested_message.insert(
490        -2, message_module.TestAllTypes.NestedMessage(bb=-1))
491    msg.repeated_nested_message.insert(
492        -1000, message_module.TestAllTypes.NestedMessage(bb=-1000))
493    try:
494      msg.repeated_nested_message.insert(1, 999)
495    except TypeError:
496      pass
497    self.assertEqual(5, len(msg.repeated_nested_message))
498    self.assertEqual([-1000, 2, -1, 1, 3],
499                     [m.bb for m in msg.repeated_nested_message])
500    self.assertEqual(str(msg),
501                     'repeated_nested_message {\n'
502                     '  bb: -1000\n'
503                     '}\n'
504                     'repeated_nested_message {\n'
505                     '  bb: 2\n'
506                     '}\n'
507                     'repeated_nested_message {\n'
508                     '  bb: -1\n'
509                     '}\n'
510                     'repeated_nested_message {\n'
511                     '  bb: 1\n'
512                     '}\n'
513                     'repeated_nested_message {\n'
514                     '  bb: 3\n'
515                     '}\n')
516    self.assertEqual(sub_msg.bb, 1)
517
518  def testMergeFromRepeatedField(self, message_module):
519    msg = message_module.TestAllTypes()
520    msg.repeated_int32.append(1)
521    msg.repeated_int32.append(3)
522    msg.repeated_nested_message.add(bb=1)
523    msg.repeated_nested_message.add(bb=2)
524    other_msg = message_module.TestAllTypes()
525    other_msg.repeated_nested_message.add(bb=3)
526    other_msg.repeated_nested_message.add(bb=4)
527    other_msg.repeated_int32.append(5)
528    other_msg.repeated_int32.append(7)
529
530    msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
531    self.assertEqual(4, len(msg.repeated_int32))
532
533    msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
534    self.assertEqual([1, 2, 3, 4],
535                     [m.bb for m in msg.repeated_nested_message])
536
537  def testAddWrongRepeatedNestedField(self, message_module):
538    msg = message_module.TestAllTypes()
539    try:
540      msg.repeated_nested_message.add('wrong')
541    except TypeError:
542      pass
543    try:
544      msg.repeated_nested_message.add(value_field='wrong')
545    except ValueError:
546      pass
547    self.assertEqual(len(msg.repeated_nested_message), 0)
548
549  def testRepeatedContains(self, message_module):
550    msg = message_module.TestAllTypes()
551    msg.repeated_int32.extend([1, 2, 3])
552    self.assertIn(2, msg.repeated_int32)
553    self.assertNotIn(0, msg.repeated_int32)
554
555    msg.repeated_nested_message.add(bb=1)
556    sub_msg1 = msg.repeated_nested_message[0]
557    sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2)
558    sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3)
559    msg.repeated_nested_message.append(sub_msg2)
560    msg.repeated_nested_message.insert(0, sub_msg3)
561    self.assertIn(sub_msg1, msg.repeated_nested_message)
562    self.assertIn(sub_msg2, msg.repeated_nested_message)
563    self.assertIn(sub_msg3, msg.repeated_nested_message)
564
565  def testRepeatedScalarIterable(self, message_module):
566    msg = message_module.TestAllTypes()
567    msg.repeated_int32.extend([1, 2, 3])
568    add = 0
569    for item in msg.repeated_int32:
570      add += item
571    self.assertEqual(add, 6)
572
573  def testRepeatedNestedFieldIteration(self, message_module):
574    msg = message_module.TestAllTypes()
575    msg.repeated_nested_message.add(bb=1)
576    msg.repeated_nested_message.add(bb=2)
577    msg.repeated_nested_message.add(bb=3)
578    msg.repeated_nested_message.add(bb=4)
579
580    self.assertEqual([1, 2, 3, 4],
581                     [m.bb for m in msg.repeated_nested_message])
582    self.assertEqual([4, 3, 2, 1],
583                     [m.bb for m in reversed(msg.repeated_nested_message)])
584    self.assertEqual([4, 3, 2, 1],
585                     [m.bb for m in msg.repeated_nested_message[::-1]])
586
587  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
588    """Check some different types with the default comparator."""
589    message = message_module.TestAllTypes()
590
591    # TODO(mattp): would testing more scalar types strengthen test?
592    message.repeated_int32.append(1)
593    message.repeated_int32.append(3)
594    message.repeated_int32.append(2)
595    message.repeated_int32.sort()
596    self.assertEqual(message.repeated_int32[0], 1)
597    self.assertEqual(message.repeated_int32[1], 2)
598    self.assertEqual(message.repeated_int32[2], 3)
599    self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
600
601    message.repeated_float.append(1.1)
602    message.repeated_float.append(1.3)
603    message.repeated_float.append(1.2)
604    message.repeated_float.sort()
605    self.assertAlmostEqual(message.repeated_float[0], 1.1)
606    self.assertAlmostEqual(message.repeated_float[1], 1.2)
607    self.assertAlmostEqual(message.repeated_float[2], 1.3)
608
609    message.repeated_string.append('a')
610    message.repeated_string.append('c')
611    message.repeated_string.append('b')
612    message.repeated_string.sort()
613    self.assertEqual(message.repeated_string[0], 'a')
614    self.assertEqual(message.repeated_string[1], 'b')
615    self.assertEqual(message.repeated_string[2], 'c')
616    self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
617
618    message.repeated_bytes.append(b'a')
619    message.repeated_bytes.append(b'c')
620    message.repeated_bytes.append(b'b')
621    message.repeated_bytes.sort()
622    self.assertEqual(message.repeated_bytes[0], b'a')
623    self.assertEqual(message.repeated_bytes[1], b'b')
624    self.assertEqual(message.repeated_bytes[2], b'c')
625    self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
626
627  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
628    """Check some different types with custom comparator."""
629    message = message_module.TestAllTypes()
630
631    message.repeated_int32.append(-3)
632    message.repeated_int32.append(-2)
633    message.repeated_int32.append(-1)
634    message.repeated_int32.sort(key=abs)
635    self.assertEqual(message.repeated_int32[0], -1)
636    self.assertEqual(message.repeated_int32[1], -2)
637    self.assertEqual(message.repeated_int32[2], -3)
638
639    message.repeated_string.append('aaa')
640    message.repeated_string.append('bb')
641    message.repeated_string.append('c')
642    message.repeated_string.sort(key=len)
643    self.assertEqual(message.repeated_string[0], 'c')
644    self.assertEqual(message.repeated_string[1], 'bb')
645    self.assertEqual(message.repeated_string[2], 'aaa')
646
647  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
648    """Check passing a custom comparator to sort a repeated composite field."""
649    message = message_module.TestAllTypes()
650
651    message.repeated_nested_message.add().bb = 1
652    message.repeated_nested_message.add().bb = 3
653    message.repeated_nested_message.add().bb = 2
654    message.repeated_nested_message.add().bb = 6
655    message.repeated_nested_message.add().bb = 5
656    message.repeated_nested_message.add().bb = 4
657    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
658    self.assertEqual(message.repeated_nested_message[0].bb, 1)
659    self.assertEqual(message.repeated_nested_message[1].bb, 2)
660    self.assertEqual(message.repeated_nested_message[2].bb, 3)
661    self.assertEqual(message.repeated_nested_message[3].bb, 4)
662    self.assertEqual(message.repeated_nested_message[4].bb, 5)
663    self.assertEqual(message.repeated_nested_message[5].bb, 6)
664    self.assertEqual(str(message.repeated_nested_message),
665                     '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
666
667  def testSortingRepeatedCompositeFieldsStable(self, message_module):
668    """Check passing a custom comparator to sort a repeated composite field."""
669    message = message_module.TestAllTypes()
670
671    message.repeated_nested_message.add().bb = 21
672    message.repeated_nested_message.add().bb = 20
673    message.repeated_nested_message.add().bb = 13
674    message.repeated_nested_message.add().bb = 33
675    message.repeated_nested_message.add().bb = 11
676    message.repeated_nested_message.add().bb = 24
677    message.repeated_nested_message.add().bb = 10
678    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
679    self.assertEqual(
680        [13, 11, 10, 21, 20, 24, 33],
681        [n.bb for n in message.repeated_nested_message])
682
683    # Make sure that for the C++ implementation, the underlying fields
684    # are actually reordered.
685    pb = message.SerializeToString()
686    message.Clear()
687    message.MergeFromString(pb)
688    self.assertEqual(
689        [13, 11, 10, 21, 20, 24, 33],
690        [n.bb for n in message.repeated_nested_message])
691
  def testRepeatedCompositeFieldSortArguments(self, message_module):
    """Check sorting a repeated composite field using list.sort() arguments.

    key= and reverse= are exercised on every Python version; the
    comparator-style sort_function= and cmp= arguments only on Python 2.
    """
    message = message_module.TestAllTypes()

    get_bb = operator.attrgetter('bb')
    # Only ever called on Python 2: `cmp` is not a builtin on Python 3, and
    # the PY3 path returns before this lambda is used.
    cmp_bb = lambda a, b: cmp(a.bb, b.bb)
    message.repeated_nested_message.add().bb = 1
    message.repeated_nested_message.add().bb = 3
    message.repeated_nested_message.add().bb = 2
    message.repeated_nested_message.add().bb = 6
    message.repeated_nested_message.add().bb = 5
    message.repeated_nested_message.add().bb = 4
    message.repeated_nested_message.sort(key=get_bb)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [1, 2, 3, 4, 5, 6])
    message.repeated_nested_message.sort(key=get_bb, reverse=True)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [6, 5, 4, 3, 2, 1])
    if sys.version_info >= (3,): return  # No cmp sorting in PY3.
    # Both sort_function= and cmp= take a cmp-style comparator.
    message.repeated_nested_message.sort(sort_function=cmp_bb)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [1, 2, 3, 4, 5, 6])
    message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
    self.assertEqual([k.bb for k in message.repeated_nested_message],
                     [6, 5, 4, 3, 2, 1])
717
  def testRepeatedScalarFieldSortArguments(self, message_module):
    """Check sorting a scalar field using list.sort() arguments.

    Covers int32 (ordered by absolute value) and string (ordered by length)
    fields with key= and reverse= on every Python version, plus the
    comparator-style sort_function= / cmp= arguments on Python 2 only.
    """
    message = message_module.TestAllTypes()

    message.repeated_int32.append(-3)
    message.repeated_int32.append(-2)
    message.repeated_int32.append(-1)
    message.repeated_int32.sort(key=abs)
    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
    message.repeated_int32.sort(key=abs, reverse=True)
    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
    if sys.version_info < (3,):  # No cmp sorting in PY3.
      abs_cmp = lambda a, b: cmp(abs(a), abs(b))
      message.repeated_int32.sort(sort_function=abs_cmp)
      self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
      message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
      self.assertEqual(list(message.repeated_int32), [-3, -2, -1])

    # Strings are ordered by length here, not lexicographically.
    message.repeated_string.append('aaa')
    message.repeated_string.append('bb')
    message.repeated_string.append('c')
    message.repeated_string.sort(key=len)
    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
    message.repeated_string.sort(key=len, reverse=True)
    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
    if sys.version_info < (3,):  # No cmp sorting in PY3.
      len_cmp = lambda a, b: cmp(len(a), len(b))
      message.repeated_string.sort(sort_function=len_cmp)
      self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
      message.repeated_string.sort(cmp=len_cmp, reverse=True)
      self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
749
750  def testRepeatedFieldsComparable(self, message_module):
751    m1 = message_module.TestAllTypes()
752    m2 = message_module.TestAllTypes()
753    m1.repeated_int32.append(0)
754    m1.repeated_int32.append(1)
755    m1.repeated_int32.append(2)
756    m2.repeated_int32.append(0)
757    m2.repeated_int32.append(1)
758    m2.repeated_int32.append(2)
759    m1.repeated_nested_message.add().bb = 1
760    m1.repeated_nested_message.add().bb = 2
761    m1.repeated_nested_message.add().bb = 3
762    m2.repeated_nested_message.add().bb = 1
763    m2.repeated_nested_message.add().bb = 2
764    m2.repeated_nested_message.add().bb = 3
765
766    if sys.version_info >= (3,): return  # No cmp() in PY3.
767
768    # These comparisons should not raise errors.
769    _ = m1 < m2
770    _ = m1.repeated_nested_message < m2.repeated_nested_message
771
772    # Make sure cmp always works. If it wasn't defined, these would be
773    # id() comparisons and would all fail.
774    self.assertEqual(cmp(m1, m2), 0)
775    self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
776    self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
777    self.assertEqual(cmp(m1.repeated_nested_message,
778                         m2.repeated_nested_message), 0)
779    with self.assertRaises(TypeError):
780      # Can't compare repeated composite containers to lists.
781      cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
782
783    # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
784
785  def testRepeatedFieldsAreSequences(self, message_module):
786    m = message_module.TestAllTypes()
787    self.assertIsInstance(m.repeated_int32, collections_abc.MutableSequence)
788    self.assertIsInstance(m.repeated_nested_message,
789                          collections_abc.MutableSequence)
790
791  def testRepeatedFieldsNotHashable(self, message_module):
792    m = message_module.TestAllTypes()
793    with self.assertRaises(TypeError):
794      hash(m.repeated_int32)
795    with self.assertRaises(TypeError):
796      hash(m.repeated_nested_message)
797
798  def testRepeatedFieldInsideNestedMessage(self, message_module):
799    m = message_module.NestedTestAllTypes()
800    m.payload.repeated_int32.extend([])
801    self.assertTrue(m.HasField('payload'))
802
803  def testMergeFrom(self, message_module):
804    m1 = message_module.TestAllTypes()
805    m2 = message_module.TestAllTypes()
806    # Cpp extension will lazily create a sub message which is immutable.
807    nested = m1.optional_nested_message
808    self.assertEqual(0, nested.bb)
809    m2.optional_nested_message.bb = 1
810    # Make sure cmessage pointing to a mutable message after merge instead of
811    # the lazily created message.
812    m1.MergeFrom(m2)
813    self.assertEqual(1, nested.bb)
814
815    # Test more nested sub message.
816    msg1 = message_module.NestedTestAllTypes()
817    msg2 = message_module.NestedTestAllTypes()
818    nested = msg1.child.payload.optional_nested_message
819    self.assertEqual(0, nested.bb)
820    msg2.child.payload.optional_nested_message.bb = 1
821    msg1.MergeFrom(msg2)
822    self.assertEqual(1, nested.bb)
823
824    # Test repeated field.
825    self.assertEqual(msg1.payload.repeated_nested_message,
826                     msg1.payload.repeated_nested_message)
827    nested = msg2.payload.repeated_nested_message.add()
828    nested.bb = 1
829    msg1.MergeFrom(msg2)
830    self.assertEqual(1, len(msg1.payload.repeated_nested_message))
831    self.assertEqual(1, nested.bb)
832
  def testMergeFromString(self, message_module):
    """MergeFromString updates a submessage that was previously only read."""
    m1 = message_module.TestAllTypes()
    m2 = message_module.TestAllTypes()
    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, m1.optional_nested_message.bb)
    m2.optional_nested_message.bb = 1
    # Make sure cmessage pointing to a mutable message after merge instead of
    # the lazily created message.
    m1.MergeFromString(m2.SerializeToString())
    self.assertEqual(1, m1.optional_nested_message.bb)
843
844  @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2')
845  def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module):
846    m2 = message_module.TestAllTypes()
847    m2.optional_string = 'scalar string'
848    m2.repeated_string.append('repeated string')
849    m2.optional_bytes = b'scalar bytes'
850    m2.repeated_bytes.append(b'repeated bytes')
851
852    serialized = m2.SerializeToString()
853    memview = memoryview(serialized)
854    m1 = message_module.TestAllTypes.FromString(memview)
855
856    self.assertEqual(m1.optional_bytes, b'scalar bytes')
857    self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
858    self.assertEqual(m1.optional_string, 'scalar string')
859    self.assertEqual(m1.repeated_string, ['repeated string'])
860    # Make sure that the memoryview was correctly converted to bytes, and
861    # that a sub-sliced memoryview is not being used.
862    self.assertIsInstance(m1.optional_bytes, bytes)
863    self.assertIsInstance(m1.repeated_bytes[0], bytes)
864    self.assertIsInstance(m1.optional_string, six.text_type)
865    self.assertIsInstance(m1.repeated_string[0], six.text_type)
866
867  @unittest.skipIf(six.PY3, 'memoryview is supported by py3')
868  def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module):
869    memview = memoryview(b'')
870    with self.assertRaises(TypeError):
871      message_module.TestAllTypes.FromString(memview)
872
  def testMergeFromEmpty(self, message_module):
    """Merging empty bytes does not mark a merely-read submessage as set."""
    m1 = message_module.TestAllTypes()
    # Cpp extension will lazily create a sub message which is immutable.
    self.assertEqual(0, m1.optional_nested_message.bb)
    self.assertFalse(m1.HasField('optional_nested_message'))
    # Make sure the sub message is still immutable after merge from empty.
    m1.MergeFromString(b'')  # field state should not change
    self.assertFalse(m1.HasField('optional_nested_message'))
881
  def ensureNestedMessageExists(self, msg, attribute):
    """Make sure that a nested message object exists.

    As soon as a nested message attribute is accessed, it will be present in the
    _fields dict, without being marked as actually being set.

    Args:
      msg: message whose submessage attribute is read.
      attribute: name of the submessage field to touch.

    Raises:
      AssertionError: if reading the attribute marked the field as set.
    """
    getattr(msg, attribute)
    self.assertFalse(msg.HasField(attribute))
890
891  def testOneofGetCaseNonexistingField(self, message_module):
892    m = message_module.TestAllTypes()
893    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
894    self.assertRaises(Exception, m.WhichOneof, 0)
895
  def testOneofDefaultValues(self, message_module):
    """Setting a oneof member to its default value still selects it."""
    m = message_module.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))

    # Oneof is set even when setting it to a default value.
    m.oneof_uint32 = 0
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))
    self.assertFalse(m.HasField('oneof_string'))

    # Switching to another member (again with a default value) deselects the
    # previous one.
    m.oneof_string = ""
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_string'))
    self.assertFalse(m.HasField('oneof_uint32'))
913
  def testOneofSemantics(self, message_module):
    """Setting a oneof member deselects the previous one; reads do not.

    Reading the submessage member (or one of its fields) must leave the
    current selection untouched; only writing into it selects it.
    """
    m = message_module.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))

    m.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))

    m.oneof_string = u'foo'
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertTrue(m.HasField('oneof_string'))

    # Read nested message accessor without accessing submessage.
    m.oneof_nested_message
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_string'))
    self.assertFalse(m.HasField('oneof_nested_message'))

    # Read accessor of nested message without accessing submessage.
    m.oneof_nested_message.bb
    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField('oneof_string'))
    self.assertFalse(m.HasField('oneof_nested_message'))

    # Writing into the submessage selects it.
    m.oneof_nested_message.bb = 11
    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_string'))
    self.assertTrue(m.HasField('oneof_nested_message'))

    m.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
    self.assertFalse(m.HasField('oneof_nested_message'))
    self.assertTrue(m.HasField('oneof_bytes'))
948
  def testOneofCompositeFieldReadAccess(self, message_module):
    """Reading an unset oneof submessage keeps the scalar member selected."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11

    self.ensureNestedMessageExists(m, 'oneof_nested_message')
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    self.assertEqual(11, m.oneof_uint32)
956
  def testOneofWhichOneof(self, message_module):
    """WhichOneof tracks set, switched and cleared members.

    HasField on the oneof name itself is only checked when message_module is
    the proto2 module (unittest_pb2).
    """
    m = message_module.TestAllTypes()
    self.assertIs(None, m.WhichOneof('oneof_field'))
    if message_module is unittest_pb2:
      self.assertFalse(m.HasField('oneof_field'))

    m.oneof_uint32 = 11
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    if message_module is unittest_pb2:
      self.assertTrue(m.HasField('oneof_field'))

    m.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))

    m.ClearField('oneof_bytes')
    self.assertIs(None, m.WhichOneof('oneof_field'))
    if message_module is unittest_pb2:
      self.assertFalse(m.HasField('oneof_field'))
975
  def testOneofClearField(self, message_module):
    """ClearField on the oneof name clears whichever member is selected."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11
    m.ClearField('oneof_field')
    if message_module is unittest_pb2:
      self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertIs(None, m.WhichOneof('oneof_field'))
984
  def testOneofClearSetField(self, message_module):
    """ClearField on the selected member leaves the oneof unset."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11
    m.ClearField('oneof_uint32')
    if message_module is unittest_pb2:
      self.assertFalse(m.HasField('oneof_field'))
    self.assertFalse(m.HasField('oneof_uint32'))
    self.assertIs(None, m.WhichOneof('oneof_field'))
993
  def testOneofClearUnsetField(self, message_module):
    """ClearField on a non-selected member leaves the selected one intact."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11
    # Materialize the (unselected) submessage so that ClearField has an
    # existing object to act on.
    self.ensureNestedMessageExists(m, 'oneof_nested_message')
    m.ClearField('oneof_nested_message')
    self.assertEqual(11, m.oneof_uint32)
    if message_module is unittest_pb2:
      self.assertTrue(m.HasField('oneof_field'))
    self.assertTrue(m.HasField('oneof_uint32'))
    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
1004
1005  def testOneofDeserialize(self, message_module):
1006    m = message_module.TestAllTypes()
1007    m.oneof_uint32 = 11
1008    m2 = message_module.TestAllTypes()
1009    m2.ParseFromString(m.SerializeToString())
1010    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
1011
1012  def testOneofCopyFrom(self, message_module):
1013    m = message_module.TestAllTypes()
1014    m.oneof_uint32 = 11
1015    m2 = message_module.TestAllTypes()
1016    m2.CopyFrom(m)
1017    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
1018
  def testOneofNestedMergeFrom(self, message_module):
    """MergeFrom overrides a oneof selection only where the source sets one.

    m only sets payload, so m2.payload switches to oneof_uint32 while
    m2.child.payload keeps oneof_bytes.
    """
    m = message_module.NestedTestAllTypes()
    m.payload.oneof_uint32 = 11
    m2 = message_module.NestedTestAllTypes()
    m2.payload.oneof_bytes = b'bb'
    m2.child.payload.oneof_bytes = b'bb'
    m2.MergeFrom(m)
    self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
    self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
1028
  def testOneofMessageMergeFrom(self, message_module):
    """Merging a oneof submessage member selects it in the target.

    This holds both where the target had a different member selected
    (payload) and where it had nothing set (child.payload).
    """
    m = message_module.NestedTestAllTypes()
    m.payload.oneof_nested_message.bb = 11
    m.child.payload.oneof_nested_message.bb = 12
    m2 = message_module.NestedTestAllTypes()
    m2.payload.oneof_uint32 = 13
    m2.MergeFrom(m)
    self.assertEqual('oneof_nested_message',
                     m2.payload.WhichOneof('oneof_field'))
    self.assertEqual('oneof_nested_message',
                     m2.child.payload.WhichOneof('oneof_field'))
1040
1041  def testOneofNestedMessageInit(self, message_module):
1042    m = message_module.TestAllTypes(
1043        oneof_nested_message=message_module.TestAllTypes.NestedMessage())
1044    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
1045
  def testOneofClear(self, message_module):
    """Clear() resets the oneof, and a member can be selected again after."""
    m = message_module.TestAllTypes()
    m.oneof_uint32 = 11
    m.Clear()
    self.assertIsNone(m.WhichOneof('oneof_field'))
    m.oneof_bytes = b'bb'
    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
1053
  def testAssignByteStringToUnicodeField(self, message_module):
    """Assigning a native str to a string field yields a text value.

    On Python 2 `str('')` is a byte string, so this checks that it is
    converted to unicode; on Python 3 it is already text.
    """
    m = message_module.TestAllTypes()
    m.optional_string = str('')
    self.assertIsInstance(m.optional_string, six.text_type)
1060
1061  def testLongValuedSlice(self, message_module):
1062    """It should be possible to use long-valued indices in slices.
1063
1064    This didn't used to work in the v2 C++ implementation.
1065    """
1066    m = message_module.TestAllTypes()
1067
1068    # Repeated scalar
1069    m.repeated_int32.append(1)
1070    sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
1071    self.assertEqual(len(m.repeated_int32), len(sl))
1072
1073    # Repeated composite
1074    m.repeated_nested_message.add().bb = 3
1075    sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
1076    self.assertEqual(len(m.repeated_nested_message), len(sl))
1077
  def testExtendShouldNotSwallowExceptions(self, message_module):
    """This didn't use to work in the v2 C++ implementation.

    The generators reference the undefined name `a` on purpose: the NameError
    raised while extend() iterates them must propagate to the caller.
    """
    m = message_module.TestAllTypes()
    with self.assertRaises(NameError) as _:
      m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
    with self.assertRaises(NameError) as _:
      m.repeated_nested_enum.extend(
          a for i in range(10))  # pylint: disable=undefined-variable
1086
  # Falsy inputs used to check that extend() on a repeated scalar field is a
  # no-op for them: None, falsy non-iterable scalars, and empty iterables.
  FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
1088
1089  def testExtendInt32WithNothing(self, message_module):
1090    """Test no-ops extending repeated int32 fields."""
1091    m = message_module.TestAllTypes()
1092    self.assertSequenceEqual([], m.repeated_int32)
1093
1094    # TODO(ptucker): Deprecate this behavior. b/18413862
1095    for falsy_value in MessageTest.FALSY_VALUES:
1096      m.repeated_int32.extend(falsy_value)
1097      self.assertSequenceEqual([], m.repeated_int32)
1098
1099    m.repeated_int32.extend([])
1100    self.assertSequenceEqual([], m.repeated_int32)
1101
1102  def testExtendFloatWithNothing(self, message_module):
1103    """Test no-ops extending repeated float fields."""
1104    m = message_module.TestAllTypes()
1105    self.assertSequenceEqual([], m.repeated_float)
1106
1107    # TODO(ptucker): Deprecate this behavior. b/18413862
1108    for falsy_value in MessageTest.FALSY_VALUES:
1109      m.repeated_float.extend(falsy_value)
1110      self.assertSequenceEqual([], m.repeated_float)
1111
1112    m.repeated_float.extend([])
1113    self.assertSequenceEqual([], m.repeated_float)
1114
1115  def testExtendStringWithNothing(self, message_module):
1116    """Test no-ops extending repeated string fields."""
1117    m = message_module.TestAllTypes()
1118    self.assertSequenceEqual([], m.repeated_string)
1119
1120    # TODO(ptucker): Deprecate this behavior. b/18413862
1121    for falsy_value in MessageTest.FALSY_VALUES:
1122      m.repeated_string.extend(falsy_value)
1123      self.assertSequenceEqual([], m.repeated_string)
1124
1125    m.repeated_string.extend([])
1126    self.assertSequenceEqual([], m.repeated_string)
1127
1128  def testExtendInt32WithPythonList(self, message_module):
1129    """Test extending repeated int32 fields with python lists."""
1130    m = message_module.TestAllTypes()
1131    self.assertSequenceEqual([], m.repeated_int32)
1132    m.repeated_int32.extend([0])
1133    self.assertSequenceEqual([0], m.repeated_int32)
1134    m.repeated_int32.extend([1, 2])
1135    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1136    m.repeated_int32.extend([3, 4])
1137    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1138
1139  def testExtendFloatWithPythonList(self, message_module):
1140    """Test extending repeated float fields with python lists."""
1141    m = message_module.TestAllTypes()
1142    self.assertSequenceEqual([], m.repeated_float)
1143    m.repeated_float.extend([0.0])
1144    self.assertSequenceEqual([0.0], m.repeated_float)
1145    m.repeated_float.extend([1.0, 2.0])
1146    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1147    m.repeated_float.extend([3.0, 4.0])
1148    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1149
1150  def testExtendStringWithPythonList(self, message_module):
1151    """Test extending repeated string fields with python lists."""
1152    m = message_module.TestAllTypes()
1153    self.assertSequenceEqual([], m.repeated_string)
1154    m.repeated_string.extend([''])
1155    self.assertSequenceEqual([''], m.repeated_string)
1156    m.repeated_string.extend(['11', '22'])
1157    self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
1158    m.repeated_string.extend(['33', '44'])
1159    self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
1160
1161  def testExtendStringWithString(self, message_module):
1162    """Test extending repeated string fields with characters from a string."""
1163    m = message_module.TestAllTypes()
1164    self.assertSequenceEqual([], m.repeated_string)
1165    m.repeated_string.extend('abc')
1166    self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
1167
1168  class TestIterable(object):
1169    """This iterable object mimics the behavior of numpy.array.
1170
1171    __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
1172
1173    """
1174
1175    def __init__(self, values=None):
1176      self._list = values or []
1177
1178    def __nonzero__(self):
1179      size = len(self._list)
1180      if size == 0:
1181        return False
1182      if size == 1:
1183        return bool(self._list[0])
1184      raise ValueError('Truth value is ambiguous.')
1185
1186    def __len__(self):
1187      return len(self._list)
1188
1189    def __iter__(self):
1190      return self._list.__iter__()
1191
1192  def testExtendInt32WithIterable(self, message_module):
1193    """Test extending repeated int32 fields with iterable."""
1194    m = message_module.TestAllTypes()
1195    self.assertSequenceEqual([], m.repeated_int32)
1196    m.repeated_int32.extend(MessageTest.TestIterable([]))
1197    self.assertSequenceEqual([], m.repeated_int32)
1198    m.repeated_int32.extend(MessageTest.TestIterable([0]))
1199    self.assertSequenceEqual([0], m.repeated_int32)
1200    m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
1201    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1202    m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
1203    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1204
1205  def testExtendFloatWithIterable(self, message_module):
1206    """Test extending repeated float fields with iterable."""
1207    m = message_module.TestAllTypes()
1208    self.assertSequenceEqual([], m.repeated_float)
1209    m.repeated_float.extend(MessageTest.TestIterable([]))
1210    self.assertSequenceEqual([], m.repeated_float)
1211    m.repeated_float.extend(MessageTest.TestIterable([0.0]))
1212    self.assertSequenceEqual([0.0], m.repeated_float)
1213    m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
1214    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1215    m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
1216    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1217
1218  def testExtendStringWithIterable(self, message_module):
1219    """Test extending repeated string fields with iterable."""
1220    m = message_module.TestAllTypes()
1221    self.assertSequenceEqual([], m.repeated_string)
1222    m.repeated_string.extend(MessageTest.TestIterable([]))
1223    self.assertSequenceEqual([], m.repeated_string)
1224    m.repeated_string.extend(MessageTest.TestIterable(['']))
1225    self.assertSequenceEqual([''], m.repeated_string)
1226    m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
1227    self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
1228    m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
1229    self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
1230
  def testPickleRepeatedScalarContainer(self, message_module):
    """Pickling a bare repeated scalar container raises PickleError.

    Only checked for the 'cpp' implementation when its Version() is not 2;
    every other configuration returns early without assertions.
    """
    # TODO(tibell): The pure-Python implementation support pickling of
    #   scalar containers in *some* cases. For now the cpp2 version
    #   throws an exception to avoid a segfault. Investigate if we
    #   want to support pickling of these fields.
    #
    # For more information see: https://b2.corp.google.com/u/0/issues/18677897
    if (api_implementation.Type() != 'cpp' or
        api_implementation.Version() == 2):
      return
    m = message_module.TestAllTypes()
    with self.assertRaises(pickle.PickleError) as _:
      pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
1244
1245  def testSortEmptyRepeatedCompositeContainer(self, message_module):
1246    """Exercise a scenario that has led to segfaults in the past.
1247    """
1248    m = message_module.TestAllTypes()
1249    m.repeated_nested_message.sort()
1250
1251  def testHasFieldOnRepeatedField(self, message_module):
1252    """Using HasField on a repeated field should raise an exception.
1253    """
1254    m = message_module.TestAllTypes()
1255    with self.assertRaises(ValueError) as _:
1256      m.HasField('repeated_int32')
1257
  def testRepeatedScalarFieldPop(self, message_module):
    """pop() on a repeated scalar field behaves like list.pop()."""
    m = message_module.TestAllTypes()
    with self.assertRaises(IndexError) as _:
      m.repeated_int32.pop()
    m.repeated_int32.extend(range(5))                 # [0, 1, 2, 3, 4]
    self.assertEqual(4, m.repeated_int32.pop())       # -> [0, 1, 2, 3]
    self.assertEqual(0, m.repeated_int32.pop(0))      # -> [1, 2, 3]
    self.assertEqual(2, m.repeated_int32.pop(1))      # -> [1, 3]
    self.assertEqual([1, 3], m.repeated_int32)
1267
  def testRepeatedCompositeFieldPop(self, message_module):
    """pop() on a repeated composite field behaves like list.pop().

    Popping from an empty field raises IndexError, and a non-integer index
    raises TypeError.
    """
    m = message_module.TestAllTypes()
    with self.assertRaises(IndexError) as _:
      m.repeated_nested_message.pop()
    with self.assertRaises(TypeError) as _:
      m.repeated_nested_message.pop('0')
    for i in range(5):
      n = m.repeated_nested_message.add()
      n.bb = i
    # bb values go [0..4] -> [0..3] -> [1, 2, 3] -> [1, 3].
    self.assertEqual(4, m.repeated_nested_message.pop().bb)
    self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
    self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
    self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
1281
1282  def testRepeatedCompareWithSelf(self, message_module):
1283    m = message_module.TestAllTypes()
1284    for i in range(5):
1285      m.repeated_int32.insert(i, i)
1286      n = m.repeated_nested_message.add()
1287      n.bb = i
1288    self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
1289    self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
1290
  def testReleasedNestedMessages(self, message_module):
    """A case that led to a segfault when a message detached from its parent
    container has itself a child container.
    """
    # Each rebinding of `m` drops the only local reference to the previous
    # (parent) message, leaving only the innermost element reachable here.
    m = message_module.NestedTestAllTypes()
    m = m.repeated_child.add()
    m = m.child
    m = m.repeated_child.add()
    self.assertEqual(m.payload.optional_int32, 0)
1300
1301  def testSetRepeatedComposite(self, message_module):
1302    m = message_module.TestAllTypes()
1303    with self.assertRaises(AttributeError):
1304      m.repeated_int32 = []
1305    m.repeated_int32.append(1)
1306    with self.assertRaises(AttributeError):
1307      m.repeated_int32 = []
1308
  def testReturningType(self, message_module):
    """Float, double and bool fields return Python float/bool values.

    This holds for the defaults, and for values assigned as the int 1 once
    they have gone through a serialize/parse round trip.
    """
    m = message_module.TestAllTypes()
    self.assertEqual(float, type(m.optional_float))
    self.assertEqual(float, type(m.optional_double))
    self.assertEqual(bool, type(m.optional_bool))
    m.optional_float = 1
    m.optional_double = 1
    m.optional_bool = 1
    m.repeated_float.append(1)
    m.repeated_double.append(1)
    m.repeated_bool.append(1)
    m.ParseFromString(m.SerializeToString())
    self.assertEqual(float, type(m.optional_float))
    self.assertEqual(float, type(m.optional_double))
    self.assertEqual('1.0', str(m.optional_double))
    self.assertEqual(bool, type(m.optional_bool))
    self.assertEqual(float, type(m.repeated_float[0]))
    self.assertEqual(float, type(m.repeated_double[0]))
    self.assertEqual(bool, type(m.repeated_bool[0]))
    self.assertEqual(True, m.repeated_bool[0])
1329
1330
1331# Class to test proto2-only features (required, extensions, etc.)
1332@testing_refleaks.TestCase
1333class Proto2Test(unittest.TestCase):
1334
  def testFieldPresence(self):
    """Check proto2 HasField/ClearField semantics for singular fields.

    Assigning a field, even to its default value, marks it present, and
    ClearField unsets it again. HasField rejects unknown and repeated field
    names with ValueError. Both str and unicode field names are accepted.
    """
    message = unittest_pb2.TestAllTypes()

    self.assertFalse(message.HasField("optional_int32"))
    self.assertFalse(message.HasField("optional_bool"))
    self.assertFalse(message.HasField("optional_nested_message"))

    with self.assertRaises(ValueError):
      message.HasField("field_doesnt_exist")

    with self.assertRaises(ValueError):
      message.HasField("repeated_int32")
    with self.assertRaises(ValueError):
      message.HasField("repeated_nested_message")

    self.assertEqual(0, message.optional_int32)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)

    # Fields are set even when setting the values to default values.
    message.optional_int32 = 0
    message.optional_bool = False
    message.optional_nested_message.bb = 0
    self.assertTrue(message.HasField("optional_int32"))
    self.assertTrue(message.HasField("optional_bool"))
    self.assertTrue(message.HasField("optional_nested_message"))

    # Set the fields to non-default values.
    message.optional_int32 = 5
    message.optional_bool = True
    message.optional_nested_message.bb = 15

    self.assertTrue(message.HasField(u"optional_int32"))
    self.assertTrue(message.HasField("optional_bool"))
    self.assertTrue(message.HasField("optional_nested_message"))

    # Clearing the fields unsets them and resets their value to default.
    message.ClearField("optional_int32")
    message.ClearField(u"optional_bool")
    message.ClearField("optional_nested_message")

    self.assertFalse(message.HasField("optional_int32"))
    self.assertFalse(message.HasField("optional_bool"))
    self.assertFalse(message.HasField("optional_nested_message"))
    self.assertEqual(0, message.optional_int32)
    self.assertEqual(False, message.optional_bool)
    self.assertEqual(0, message.optional_nested_message.bb)
1382
  def testAssignInvalidEnum(self):
    """Assigning an invalid enum number is not allowed in proto2.

    Unknown values produced by a proto3 message still parse into the proto2
    message without being visible. They survive re-serialization, so parsing
    the proto2 bytes back into proto3 restores them.
    """
    m = unittest_pb2.TestAllTypes()

    # Proto2 can not assign unknown enum.
    with self.assertRaises(ValueError) as _:
      m.optional_nested_enum = 1234567
    self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
    # Assignment is a different code path than append for the C++ impl.
    m.repeated_nested_enum.append(2)
    m.repeated_nested_enum[0] = 2
    with self.assertRaises(ValueError):
      m.repeated_nested_enum[0] = 123456

    # Unknown enum value can be parsed but is ignored.
    m2 = unittest_proto3_arena_pb2.TestAllTypes()
    m2.optional_nested_enum = 1234567
    m2.repeated_nested_enum.append(7654321)
    serialized = m2.SerializeToString()

    m3 = unittest_pb2.TestAllTypes()
    m3.ParseFromString(serialized)
    self.assertFalse(m3.HasField('optional_nested_enum'))
    # 1 is the default value for optional_nested_enum.
    self.assertEqual(1, m3.optional_nested_enum)
    self.assertEqual(0, len(m3.repeated_nested_enum))
    # Round-trip back through proto3: the ignored values were retained.
    m2.Clear()
    m2.ParseFromString(m3.SerializeToString())
    self.assertEqual(1234567, m2.optional_nested_enum)
    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1413
1414  def testUnknownEnumMap(self):
1415    m = map_proto2_unittest_pb2.TestEnumMap()
1416    m.known_map_field[123] = 0
1417    with self.assertRaises(ValueError):
1418      m.unknown_map_field[1] = 123
1419
1420  def testExtensionsErrors(self):
1421    msg = unittest_pb2.TestAllTypes()
1422    self.assertRaises(AttributeError, getattr, msg, 'Extensions')
1423
1424  def testMergeFromExtensions(self):
1425    msg1 = more_extensions_pb2.TopLevelMessage()
1426    msg2 = more_extensions_pb2.TopLevelMessage()
1427    # Cpp extension will lazily create a sub message which is immutable.
1428    self.assertEqual(0, msg1.submessage.Extensions[
1429        more_extensions_pb2.optional_int_extension])
1430    self.assertFalse(msg1.HasField('submessage'))
1431    msg2.submessage.Extensions[
1432        more_extensions_pb2.optional_int_extension] = 123
1433    # Make sure cmessage and extensions pointing to a mutable message
1434    # after merge instead of the lazily created message.
1435    msg1.MergeFrom(msg2)
1436    self.assertEqual(123, msg1.submessage.Extensions[
1437        more_extensions_pb2.optional_int_extension])
1438
1439  def testGoldenExtensions(self):
1440    golden_data = test_util.GoldenFileData('golden_message')
1441    golden_message = unittest_pb2.TestAllExtensions()
1442    golden_message.ParseFromString(golden_data)
1443    all_set = unittest_pb2.TestAllExtensions()
1444    test_util.SetAllExtensions(all_set)
1445    self.assertEqual(all_set, golden_message)
1446    self.assertEqual(golden_data, golden_message.SerializeToString())
1447    golden_copy = copy.deepcopy(golden_message)
1448    self.assertEqual(golden_data, golden_copy.SerializeToString())
1449
1450  def testGoldenPackedExtensions(self):
1451    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
1452    golden_message = unittest_pb2.TestPackedExtensions()
1453    golden_message.ParseFromString(golden_data)
1454    all_set = unittest_pb2.TestPackedExtensions()
1455    test_util.SetAllPackedExtensions(all_set)
1456    self.assertEqual(all_set, golden_message)
1457    self.assertEqual(golden_data, all_set.SerializeToString())
1458    golden_copy = copy.deepcopy(golden_message)
1459    self.assertEqual(golden_data, golden_copy.SerializeToString())
1460
1461  def testPickleIncompleteProto(self):
1462    golden_message = unittest_pb2.TestRequired(a=1)
1463    pickled_message = pickle.dumps(golden_message)
1464
1465    unpickled_message = pickle.loads(pickled_message)
1466    self.assertEqual(unpickled_message, golden_message)
1467    self.assertEqual(unpickled_message.a, 1)
1468    # This is still an incomplete proto - so serializing should fail
1469    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
1470
1471
1472  # TODO(haberman): this isn't really a proto2-specific test except that this
1473  # message has a required field in it.  Should probably be factored out so
1474  # that we can test the other parts with proto3.
1475  def testParsingMerge(self):
1476    """Check the merge behavior when a required or optional field appears
1477    multiple times in the input."""
1478    messages = [
1479        unittest_pb2.TestAllTypes(),
1480        unittest_pb2.TestAllTypes(),
1481        unittest_pb2.TestAllTypes() ]
1482    messages[0].optional_int32 = 1
1483    messages[1].optional_int64 = 2
1484    messages[2].optional_int32 = 3
1485    messages[2].optional_string = 'hello'
1486
1487    merged_message = unittest_pb2.TestAllTypes()
1488    merged_message.optional_int32 = 3
1489    merged_message.optional_int64 = 2
1490    merged_message.optional_string = 'hello'
1491
1492    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
1493    generator.field1.extend(messages)
1494    generator.field2.extend(messages)
1495    generator.field3.extend(messages)
1496    generator.ext1.extend(messages)
1497    generator.ext2.extend(messages)
1498    generator.group1.add().field1.MergeFrom(messages[0])
1499    generator.group1.add().field1.MergeFrom(messages[1])
1500    generator.group1.add().field1.MergeFrom(messages[2])
1501    generator.group2.add().field1.MergeFrom(messages[0])
1502    generator.group2.add().field1.MergeFrom(messages[1])
1503    generator.group2.add().field1.MergeFrom(messages[2])
1504
1505    data = generator.SerializeToString()
1506    parsing_merge = unittest_pb2.TestParsingMerge()
1507    parsing_merge.ParseFromString(data)
1508
1509    # Required and optional fields should be merged.
1510    self.assertEqual(parsing_merge.required_all_types, merged_message)
1511    self.assertEqual(parsing_merge.optional_all_types, merged_message)
1512    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
1513                     merged_message)
1514    self.assertEqual(parsing_merge.Extensions[
1515                     unittest_pb2.TestParsingMerge.optional_ext],
1516                     merged_message)
1517
1518    # Repeated fields should not be merged.
1519    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
1520    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
1521    self.assertEqual(len(parsing_merge.Extensions[
1522        unittest_pb2.TestParsingMerge.repeated_ext]), 3)
1523
1524  def testPythonicInit(self):
1525    message = unittest_pb2.TestAllTypes(
1526        optional_int32=100,
1527        optional_fixed32=200,
1528        optional_float=300.5,
1529        optional_bytes=b'x',
1530        optionalgroup={'a': 400},
1531        optional_nested_message={'bb': 500},
1532        optional_foreign_message={},
1533        optional_nested_enum='BAZ',
1534        repeatedgroup=[{'a': 600},
1535                       {'a': 700}],
1536        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
1537        default_int32=800,
1538        oneof_string='y')
1539    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
1540    self.assertEqual(100, message.optional_int32)
1541    self.assertEqual(200, message.optional_fixed32)
1542    self.assertEqual(300.5, message.optional_float)
1543    self.assertEqual(b'x', message.optional_bytes)
1544    self.assertEqual(400, message.optionalgroup.a)
1545    self.assertIsInstance(message.optional_nested_message,
1546                          unittest_pb2.TestAllTypes.NestedMessage)
1547    self.assertEqual(500, message.optional_nested_message.bb)
1548    self.assertTrue(message.HasField('optional_foreign_message'))
1549    self.assertEqual(message.optional_foreign_message,
1550                     unittest_pb2.ForeignMessage())
1551    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1552                     message.optional_nested_enum)
1553    self.assertEqual(2, len(message.repeatedgroup))
1554    self.assertEqual(600, message.repeatedgroup[0].a)
1555    self.assertEqual(700, message.repeatedgroup[1].a)
1556    self.assertEqual(2, len(message.repeated_nested_enum))
1557    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
1558                     message.repeated_nested_enum[0])
1559    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
1560                     message.repeated_nested_enum[1])
1561    self.assertEqual(800, message.default_int32)
1562    self.assertEqual('y', message.oneof_string)
1563    self.assertFalse(message.HasField('optional_int64'))
1564    self.assertEqual(0, len(message.repeated_float))
1565    self.assertEqual(42, message.default_int64)
1566
1567    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
1568    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1569                     message.optional_nested_enum)
1570
1571    with self.assertRaises(ValueError):
1572      unittest_pb2.TestAllTypes(
1573          optional_nested_message={'INVALID_NESTED_FIELD': 17})
1574
1575    with self.assertRaises(TypeError):
1576      unittest_pb2.TestAllTypes(
1577          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
1578
1579    with self.assertRaises(ValueError):
1580      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
1581
1582    with self.assertRaises(ValueError):
1583      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1584
1585  def testPythonicInitWithDict(self):
1586    # Both string/unicode field name keys should work.
1587    kwargs = {
1588        'optional_int32': 100,
1589        u'optional_fixed32': 200,
1590    }
1591    msg = unittest_pb2.TestAllTypes(**kwargs)
1592    self.assertEqual(100, msg.optional_int32)
1593    self.assertEqual(200, msg.optional_fixed32)
1594
1595
1596  def test_documentation(self):
1597    # Also used by the interactive help() function.
1598    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
1599    self.assertIn('class TestAllTypes', doc)
1600    self.assertIn('SerializePartialToString', doc)
1601    self.assertIn('repeated_float', doc)
1602    base = unittest_pb2.TestAllTypes.__bases__[0]
1603    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
1604
1605
1606# Class to test proto3-only features/behavior (updated field presence & enums)
1607@testing_refleaks.TestCase
1608class Proto3Test(unittest.TestCase):
1609
1610  # Utility method for comparing equality with a map.
1611  def assertMapIterEquals(self, map_iter, dict_value):
1612    # Avoid mutating caller's copy.
1613    dict_value = dict(dict_value)
1614
1615    for k, v in map_iter:
1616      self.assertEqual(v, dict_value[k])
1617      del dict_value[k]
1618
1619    self.assertEqual({}, dict_value)
1620
1621  def testFieldPresence(self):
1622    message = unittest_proto3_arena_pb2.TestAllTypes()
1623
1624    # We can't test presence of non-repeated, non-submessage fields.
1625    with self.assertRaises(ValueError):
1626      message.HasField('optional_int32')
1627    with self.assertRaises(ValueError):
1628      message.HasField('optional_float')
1629    with self.assertRaises(ValueError):
1630      message.HasField('optional_string')
1631    with self.assertRaises(ValueError):
1632      message.HasField('optional_bool')
1633
1634    # But we can still test presence of submessage fields.
1635    self.assertFalse(message.HasField('optional_nested_message'))
1636
1637    # As with proto2, we can't test presence of fields that don't exist, or
1638    # repeated fields.
1639    with self.assertRaises(ValueError):
1640      message.HasField('field_doesnt_exist')
1641
1642    with self.assertRaises(ValueError):
1643      message.HasField('repeated_int32')
1644    with self.assertRaises(ValueError):
1645      message.HasField('repeated_nested_message')
1646
1647    # Fields should default to their type-specific default.
1648    self.assertEqual(0, message.optional_int32)
1649    self.assertEqual(0, message.optional_float)
1650    self.assertEqual('', message.optional_string)
1651    self.assertEqual(False, message.optional_bool)
1652    self.assertEqual(0, message.optional_nested_message.bb)
1653
1654    # Setting a submessage should still return proper presence information.
1655    message.optional_nested_message.bb = 0
1656    self.assertTrue(message.HasField('optional_nested_message'))
1657
1658    # Set the fields to non-default values.
1659    message.optional_int32 = 5
1660    message.optional_float = 1.1
1661    message.optional_string = 'abc'
1662    message.optional_bool = True
1663    message.optional_nested_message.bb = 15
1664
1665    # Clearing the fields unsets them and resets their value to default.
1666    message.ClearField('optional_int32')
1667    message.ClearField('optional_float')
1668    message.ClearField('optional_string')
1669    message.ClearField('optional_bool')
1670    message.ClearField('optional_nested_message')
1671
1672    self.assertEqual(0, message.optional_int32)
1673    self.assertEqual(0, message.optional_float)
1674    self.assertEqual('', message.optional_string)
1675    self.assertEqual(False, message.optional_bool)
1676    self.assertEqual(0, message.optional_nested_message.bb)
1677
1678  def testProto3ParserDropDefaultScalar(self):
1679    message_proto2 = unittest_pb2.TestAllTypes()
1680    message_proto2.optional_int32 = 0
1681    message_proto2.optional_string = ''
1682    message_proto2.optional_bytes = b''
1683    self.assertEqual(len(message_proto2.ListFields()), 3)
1684
1685    message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
1686    message_proto3.ParseFromString(message_proto2.SerializeToString())
1687    self.assertEqual(len(message_proto3.ListFields()), 0)
1688
1689  def testProto3Optional(self):
1690    msg = test_proto3_optional_pb2.TestProto3Optional()
1691    self.assertFalse(msg.HasField('optional_int32'))
1692    self.assertFalse(msg.HasField('optional_float'))
1693    self.assertFalse(msg.HasField('optional_string'))
1694    self.assertFalse(msg.HasField('optional_nested_message'))
1695    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1696
1697    # Set fields.
1698    msg.optional_int32 = 1
1699    msg.optional_float = 1.0
1700    msg.optional_string = '123'
1701    msg.optional_nested_message.bb = 1
1702    self.assertTrue(msg.HasField('optional_int32'))
1703    self.assertTrue(msg.HasField('optional_float'))
1704    self.assertTrue(msg.HasField('optional_string'))
1705    self.assertTrue(msg.HasField('optional_nested_message'))
1706    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1707    # Set to default value does not clear the fields
1708    msg.optional_int32 = 0
1709    msg.optional_float = 0.0
1710    msg.optional_string = ''
1711    msg.optional_nested_message.bb = 0
1712    self.assertTrue(msg.HasField('optional_int32'))
1713    self.assertTrue(msg.HasField('optional_float'))
1714    self.assertTrue(msg.HasField('optional_string'))
1715    self.assertTrue(msg.HasField('optional_nested_message'))
1716    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1717
1718    # Test serialize
1719    msg2 = test_proto3_optional_pb2.TestProto3Optional()
1720    msg2.ParseFromString(msg.SerializeToString())
1721    self.assertTrue(msg2.HasField('optional_int32'))
1722    self.assertTrue(msg2.HasField('optional_float'))
1723    self.assertTrue(msg2.HasField('optional_string'))
1724    self.assertTrue(msg2.HasField('optional_nested_message'))
1725    self.assertTrue(msg2.optional_nested_message.HasField('bb'))
1726
1727    self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')
1728
1729    # Clear these fields.
1730    msg.ClearField('optional_int32')
1731    msg.ClearField('optional_float')
1732    msg.ClearField('optional_string')
1733    msg.ClearField('optional_nested_message')
1734    self.assertFalse(msg.HasField('optional_int32'))
1735    self.assertFalse(msg.HasField('optional_float'))
1736    self.assertFalse(msg.HasField('optional_string'))
1737    self.assertFalse(msg.HasField('optional_nested_message'))
1738    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1739
1740    self.assertEqual(msg.WhichOneof('_optional_int32'), None)
1741
1742  def testAssignUnknownEnum(self):
1743    """Assigning an unknown enum value is allowed and preserves the value."""
1744    m = unittest_proto3_arena_pb2.TestAllTypes()
1745
1746    # Proto3 can assign unknown enums.
1747    m.optional_nested_enum = 1234567
1748    self.assertEqual(1234567, m.optional_nested_enum)
1749    m.repeated_nested_enum.append(22334455)
1750    self.assertEqual(22334455, m.repeated_nested_enum[0])
1751    # Assignment is a different code path than append for the C++ impl.
1752    m.repeated_nested_enum[0] = 7654321
1753    self.assertEqual(7654321, m.repeated_nested_enum[0])
1754    serialized = m.SerializeToString()
1755
1756    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1757    m2.ParseFromString(serialized)
1758    self.assertEqual(1234567, m2.optional_nested_enum)
1759    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1760
1761  # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1762  # of google/protobuf/map_unittest.proto right now, so it's not easy to
1763  # test both with the same test like we do for the other proto2/proto3 tests.
1764  # (google/protobuf/map_proto2_unittest.proto is very different in the set
1765  # of messages and fields it contains).
1766  def testScalarMapDefaults(self):
1767    msg = map_unittest_pb2.TestMap()
1768
1769    # Scalars start out unset.
1770    self.assertFalse(-123 in msg.map_int32_int32)
1771    self.assertFalse(-2**33 in msg.map_int64_int64)
1772    self.assertFalse(123 in msg.map_uint32_uint32)
1773    self.assertFalse(2**33 in msg.map_uint64_uint64)
1774    self.assertFalse(123 in msg.map_int32_double)
1775    self.assertFalse(False in msg.map_bool_bool)
1776    self.assertFalse('abc' in msg.map_string_string)
1777    self.assertFalse(111 in msg.map_int32_bytes)
1778    self.assertFalse(888 in msg.map_int32_enum)
1779
1780    # Accessing an unset key returns the default.
1781    self.assertEqual(0, msg.map_int32_int32[-123])
1782    self.assertEqual(0, msg.map_int64_int64[-2**33])
1783    self.assertEqual(0, msg.map_uint32_uint32[123])
1784    self.assertEqual(0, msg.map_uint64_uint64[2**33])
1785    self.assertEqual(0.0, msg.map_int32_double[123])
1786    self.assertTrue(isinstance(msg.map_int32_double[123], float))
1787    self.assertEqual(False, msg.map_bool_bool[False])
1788    self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
1789    self.assertEqual('', msg.map_string_string['abc'])
1790    self.assertEqual(b'', msg.map_int32_bytes[111])
1791    self.assertEqual(0, msg.map_int32_enum[888])
1792
1793    # It also sets the value in the map
1794    self.assertTrue(-123 in msg.map_int32_int32)
1795    self.assertTrue(-2**33 in msg.map_int64_int64)
1796    self.assertTrue(123 in msg.map_uint32_uint32)
1797    self.assertTrue(2**33 in msg.map_uint64_uint64)
1798    self.assertTrue(123 in msg.map_int32_double)
1799    self.assertTrue(False in msg.map_bool_bool)
1800    self.assertTrue('abc' in msg.map_string_string)
1801    self.assertTrue(111 in msg.map_int32_bytes)
1802    self.assertTrue(888 in msg.map_int32_enum)
1803
1804    self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
1805
1806    # Accessing an unset key still throws TypeError if the type of the key
1807    # is incorrect.
1808    with self.assertRaises(TypeError):
1809      msg.map_string_string[123]
1810
1811    with self.assertRaises(TypeError):
1812      123 in msg.map_string_string
1813
1814  def testMapGet(self):
1815    # Need to test that get() properly returns the default, even though the dict
1816    # has defaultdict-like semantics.
1817    msg = map_unittest_pb2.TestMap()
1818
1819    self.assertIsNone(msg.map_int32_int32.get(5))
1820    self.assertEqual(10, msg.map_int32_int32.get(5, 10))
1821    self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10))
1822    self.assertIsNone(msg.map_int32_int32.get(5))
1823
1824    msg.map_int32_int32[5] = 15
1825    self.assertEqual(15, msg.map_int32_int32.get(5))
1826    self.assertEqual(15, msg.map_int32_int32.get(5))
1827    with self.assertRaises(TypeError):
1828      msg.map_int32_int32.get('')
1829
1830    self.assertIsNone(msg.map_int32_foreign_message.get(5))
1831    self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
1832    self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10))
1833
1834    submsg = msg.map_int32_foreign_message[5]
1835    self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
1836    with self.assertRaises(TypeError):
1837      msg.map_int32_foreign_message.get('')
1838
1839  def testScalarMap(self):
1840    msg = map_unittest_pb2.TestMap()
1841
1842    self.assertEqual(0, len(msg.map_int32_int32))
1843    self.assertFalse(5 in msg.map_int32_int32)
1844
1845    msg.map_int32_int32[-123] = -456
1846    msg.map_int64_int64[-2**33] = -2**34
1847    msg.map_uint32_uint32[123] = 456
1848    msg.map_uint64_uint64[2**33] = 2**34
1849    msg.map_int32_float[2] = 1.2
1850    msg.map_int32_double[1] = 3.3
1851    msg.map_string_string['abc'] = '123'
1852    msg.map_bool_bool[True] = True
1853    msg.map_int32_enum[888] = 2
1854    # Unknown numeric enum is supported in proto3.
1855    msg.map_int32_enum[123] = 456
1856
1857    self.assertEqual([], msg.FindInitializationErrors())
1858
1859    self.assertEqual(1, len(msg.map_string_string))
1860
1861    # Bad key.
1862    with self.assertRaises(TypeError):
1863      msg.map_string_string[123] = '123'
1864
1865    # Verify that trying to assign a bad key doesn't actually add a member to
1866    # the map.
1867    self.assertEqual(1, len(msg.map_string_string))
1868
1869    # Bad value.
1870    with self.assertRaises(TypeError):
1871      msg.map_string_string['123'] = 123
1872
1873    serialized = msg.SerializeToString()
1874    msg2 = map_unittest_pb2.TestMap()
1875    msg2.ParseFromString(serialized)
1876
1877    # Bad key.
1878    with self.assertRaises(TypeError):
1879      msg2.map_string_string[123] = '123'
1880
1881    # Bad value.
1882    with self.assertRaises(TypeError):
1883      msg2.map_string_string['123'] = 123
1884
1885    self.assertEqual(-456, msg2.map_int32_int32[-123])
1886    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
1887    self.assertEqual(456, msg2.map_uint32_uint32[123])
1888    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
1889    self.assertAlmostEqual(1.2, msg.map_int32_float[2])
1890    self.assertEqual(3.3, msg.map_int32_double[1])
1891    self.assertEqual('123', msg2.map_string_string['abc'])
1892    self.assertEqual(True, msg2.map_bool_bool[True])
1893    self.assertEqual(2, msg2.map_int32_enum[888])
1894    self.assertEqual(456, msg2.map_int32_enum[123])
1895    self.assertEqual('{-123: -456}',
1896                     str(msg2.map_int32_int32))
1897
1898  def testMapEntryAlwaysSerialized(self):
1899    msg = map_unittest_pb2.TestMap()
1900    msg.map_int32_int32[0] = 0
1901    msg.map_string_string[''] = ''
1902    self.assertEqual(msg.ByteSize(), 12)
1903    self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
1904                     msg.SerializeToString())
1905
1906  def testStringUnicodeConversionInMap(self):
1907    msg = map_unittest_pb2.TestMap()
1908
1909    unicode_obj = u'\u1234'
1910    bytes_obj = unicode_obj.encode('utf8')
1911
1912    msg.map_string_string[bytes_obj] = bytes_obj
1913
1914    (key, value) = list(msg.map_string_string.items())[0]
1915
1916    self.assertEqual(key, unicode_obj)
1917    self.assertEqual(value, unicode_obj)
1918
1919    self.assertIsInstance(key, six.text_type)
1920    self.assertIsInstance(value, six.text_type)
1921
1922  def testMessageMap(self):
1923    msg = map_unittest_pb2.TestMap()
1924
1925    self.assertEqual(0, len(msg.map_int32_foreign_message))
1926    self.assertFalse(5 in msg.map_int32_foreign_message)
1927
1928    msg.map_int32_foreign_message[123]
1929    # get_or_create() is an alias for getitem.
1930    msg.map_int32_foreign_message.get_or_create(-456)
1931
1932    self.assertEqual(2, len(msg.map_int32_foreign_message))
1933    self.assertIn(123, msg.map_int32_foreign_message)
1934    self.assertIn(-456, msg.map_int32_foreign_message)
1935    self.assertEqual(2, len(msg.map_int32_foreign_message))
1936
1937    # Bad key.
1938    with self.assertRaises(TypeError):
1939      msg.map_int32_foreign_message['123']
1940
1941    # Can't assign directly to submessage.
1942    with self.assertRaises(ValueError):
1943      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
1944
1945    # Verify that trying to assign a bad key doesn't actually add a member to
1946    # the map.
1947    self.assertEqual(2, len(msg.map_int32_foreign_message))
1948
1949    serialized = msg.SerializeToString()
1950    msg2 = map_unittest_pb2.TestMap()
1951    msg2.ParseFromString(serialized)
1952
1953    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1954    self.assertIn(123, msg2.map_int32_foreign_message)
1955    self.assertIn(-456, msg2.map_int32_foreign_message)
1956    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1957    msg2.map_int32_foreign_message[123].c = 1
1958    # TODO(jieluo): Fix text format for message map.
1959    self.assertIn(str(msg2.map_int32_foreign_message),
1960                  ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
1961
1962  def testNestedMessageMapItemDelete(self):
1963    msg = map_unittest_pb2.TestMap()
1964    msg.map_int32_all_types[1].optional_nested_message.bb = 1
1965    del msg.map_int32_all_types[1]
1966    msg.map_int32_all_types[2].optional_nested_message.bb = 2
1967    self.assertEqual(1, len(msg.map_int32_all_types))
1968    msg.map_int32_all_types[1].optional_nested_message.bb = 1
1969    self.assertEqual(2, len(msg.map_int32_all_types))
1970
1971    serialized = msg.SerializeToString()
1972    msg2 = map_unittest_pb2.TestMap()
1973    msg2.ParseFromString(serialized)
1974    keys = [1, 2]
1975    # The loop triggers PyErr_Occurred() in c extension.
1976    for key in keys:
1977      del msg2.map_int32_all_types[key]
1978
1979  def testMapByteSize(self):
1980    msg = map_unittest_pb2.TestMap()
1981    msg.map_int32_int32[1] = 1
1982    size = msg.ByteSize()
1983    msg.map_int32_int32[1] = 128
1984    self.assertEqual(msg.ByteSize(), size + 1)
1985
1986    msg.map_int32_foreign_message[19].c = 1
1987    size = msg.ByteSize()
1988    msg.map_int32_foreign_message[19].c = 128
1989    self.assertEqual(msg.ByteSize(), size + 1)
1990
1991  def testMergeFrom(self):
1992    msg = map_unittest_pb2.TestMap()
1993    msg.map_int32_int32[12] = 34
1994    msg.map_int32_int32[56] = 78
1995    msg.map_int64_int64[22] = 33
1996    msg.map_int32_foreign_message[111].c = 5
1997    msg.map_int32_foreign_message[222].c = 10
1998
1999    msg2 = map_unittest_pb2.TestMap()
2000    msg2.map_int32_int32[12] = 55
2001    msg2.map_int64_int64[88] = 99
2002    msg2.map_int32_foreign_message[222].c = 15
2003    msg2.map_int32_foreign_message[222].d = 20
2004    old_map_value = msg2.map_int32_foreign_message[222]
2005
2006    msg2.MergeFrom(msg)
2007    # Compare with expected message instead of call
2008    # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not
2009    # sync with repeated field and there is no duplicated keys.
2010    expected_msg = map_unittest_pb2.TestMap()
2011    expected_msg.CopyFrom(msg)
2012    expected_msg.map_int64_int64[88] = 99
2013    self.assertEqual(msg2, expected_msg)
2014
2015    self.assertEqual(34, msg2.map_int32_int32[12])
2016    self.assertEqual(78, msg2.map_int32_int32[56])
2017    self.assertEqual(33, msg2.map_int64_int64[22])
2018    self.assertEqual(99, msg2.map_int64_int64[88])
2019    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
2020    self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
2021    self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
2022    if api_implementation.Type() != 'cpp':
2023      # During the call to MergeFrom(), the C++ implementation will have
2024      # deallocated the underlying message, but this is very difficult to detect
2025      # properly. The line below is likely to cause a segmentation fault.
2026      # With the Python implementation, old_map_value is just 'detached' from
2027      # the main message. Using it will not crash of course, but since it still
2028      # have a reference to the parent message I'm sure we can find interesting
2029      # ways to cause inconsistencies.
2030      self.assertEqual(15, old_map_value.c)
2031
2032    # Verify that there is only one entry per key, even though the MergeFrom
2033    # may have internally created multiple entries for a single key in the
2034    # list representation.
2035    as_dict = {}
2036    for key in msg2.map_int32_foreign_message:
2037      self.assertFalse(key in as_dict)
2038      as_dict[key] = msg2.map_int32_foreign_message[key].c
2039
2040    self.assertEqual({111: 5, 222: 10}, as_dict)
2041
2042    # Special case: test that delete of item really removes the item, even if
2043    # there might have physically been duplicate keys due to the previous merge.
2044    # This is only a special case for the C++ implementation which stores the
2045    # map as an array.
2046    del msg2.map_int32_int32[12]
2047    self.assertFalse(12 in msg2.map_int32_int32)
2048
2049    del msg2.map_int32_foreign_message[222]
2050    self.assertFalse(222 in msg2.map_int32_foreign_message)
2051    with self.assertRaises(TypeError):
2052      del msg2.map_int32_foreign_message['']
2053
2054  def testMapMergeFrom(self):
2055    msg = map_unittest_pb2.TestMap()
2056    msg.map_int32_int32[12] = 34
2057    msg.map_int32_int32[56] = 78
2058    msg.map_int64_int64[22] = 33
2059    msg.map_int32_foreign_message[111].c = 5
2060    msg.map_int32_foreign_message[222].c = 10
2061
2062    msg2 = map_unittest_pb2.TestMap()
2063    msg2.map_int32_int32[12] = 55
2064    msg2.map_int64_int64[88] = 99
2065    msg2.map_int32_foreign_message[222].c = 15
2066    msg2.map_int32_foreign_message[222].d = 20
2067
2068    msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
2069    self.assertEqual(34, msg2.map_int32_int32[12])
2070    self.assertEqual(78, msg2.map_int32_int32[56])
2071
2072    msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
2073    self.assertEqual(33, msg2.map_int64_int64[22])
2074    self.assertEqual(99, msg2.map_int64_int64[88])
2075
2076    msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
2077    # Compare with expected message instead of call
2078    # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
2079    # sync with repeated field and no duplicated keys.
2080    expected_msg = map_unittest_pb2.TestMap()
2081    expected_msg.CopyFrom(msg)
2082    expected_msg.map_int64_int64[88] = 99
2083    self.assertEqual(msg2, expected_msg)
2084
2085    # Test when cpp extension cache a map.
2086    m1 = map_unittest_pb2.TestMap()
2087    m2 = map_unittest_pb2.TestMap()
2088    self.assertEqual(m1.map_int32_foreign_message,
2089                     m1.map_int32_foreign_message)
2090    m2.map_int32_foreign_message[123].c = 10
2091    m1.MergeFrom(m2)
2092    self.assertEqual(10, m2.map_int32_foreign_message[123].c)
2093
2094    # Test merge maps within different message types.
2095    m1 = map_unittest_pb2.TestMap()
2096    m2 = map_unittest_pb2.TestMessageMap()
2097    m2.map_int32_message[123].optional_int32 = 10
2098    m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
2099    self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)
2100
2101    # Test overwrite message value map
2102    msg = map_unittest_pb2.TestMap()
2103    msg.map_int32_foreign_message[222].c = 123
2104    msg2 = map_unittest_pb2.TestMap()
2105    msg2.map_int32_foreign_message[222].d = 20
2106    msg.MergeFromString(msg2.SerializeToString())
2107    self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
2108    self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)
2109
2110    # Merge a dict to map field is not accepted
2111    with self.assertRaises(AttributeError):
2112      m1.map_int32_all_types.MergeFrom(
2113          {1: unittest_proto3_arena_pb2.TestAllTypes()})
2114
2115  def testMergeFromBadType(self):
2116    msg = map_unittest_pb2.TestMap()
2117    with self.assertRaisesRegexp(
2118        TypeError,
2119        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
2120        r'.*TestMap got int\.'):
2121      msg.MergeFrom(1)
2122
2123  def testCopyFromBadType(self):
2124    msg = map_unittest_pb2.TestMap()
2125    with self.assertRaisesRegexp(
2126        TypeError,
2127        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
2128        r'expected .*TestMap got int\.'):
2129      msg.CopyFrom(1)
2130
2131  def testIntegerMapWithLongs(self):
2132    msg = map_unittest_pb2.TestMap()
2133    msg.map_int32_int32[long(-123)] = long(-456)
2134    msg.map_int64_int64[long(-2**33)] = long(-2**34)
2135    msg.map_uint32_uint32[long(123)] = long(456)
2136    msg.map_uint64_uint64[long(2**33)] = long(2**34)
2137
2138    serialized = msg.SerializeToString()
2139    msg2 = map_unittest_pb2.TestMap()
2140    msg2.ParseFromString(serialized)
2141
2142    self.assertEqual(-456, msg2.map_int32_int32[-123])
2143    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
2144    self.assertEqual(456, msg2.map_uint32_uint32[123])
2145    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
2146
2147  def testMapAssignmentCausesPresence(self):
2148    msg = map_unittest_pb2.TestMapSubmessage()
2149    msg.test_map.map_int32_int32[123] = 456
2150
2151    serialized = msg.SerializeToString()
2152    msg2 = map_unittest_pb2.TestMapSubmessage()
2153    msg2.ParseFromString(serialized)
2154
2155    self.assertEqual(msg, msg2)
2156
2157    # Now test that various mutations of the map properly invalidate the
2158    # cached size of the submessage.
2159    msg.test_map.map_int32_int32[888] = 999
2160    serialized = msg.SerializeToString()
2161    msg2.ParseFromString(serialized)
2162    self.assertEqual(msg, msg2)
2163
2164    msg.test_map.map_int32_int32.clear()
2165    serialized = msg.SerializeToString()
2166    msg2.ParseFromString(serialized)
2167    self.assertEqual(msg, msg2)
2168
2169  def testMapAssignmentCausesPresenceForSubmessages(self):
2170    msg = map_unittest_pb2.TestMapSubmessage()
2171    msg.test_map.map_int32_foreign_message[123].c = 5
2172
2173    serialized = msg.SerializeToString()
2174    msg2 = map_unittest_pb2.TestMapSubmessage()
2175    msg2.ParseFromString(serialized)
2176
2177    self.assertEqual(msg, msg2)
2178
2179    # Now test that various mutations of the map properly invalidate the
2180    # cached size of the submessage.
2181    msg.test_map.map_int32_foreign_message[888].c = 7
2182    serialized = msg.SerializeToString()
2183    msg2.ParseFromString(serialized)
2184    self.assertEqual(msg, msg2)
2185
2186    msg.test_map.map_int32_foreign_message[888].MergeFrom(
2187        msg.test_map.map_int32_foreign_message[123])
2188    serialized = msg.SerializeToString()
2189    msg2.ParseFromString(serialized)
2190    self.assertEqual(msg, msg2)
2191
2192    msg.test_map.map_int32_foreign_message.clear()
2193    serialized = msg.SerializeToString()
2194    msg2.ParseFromString(serialized)
2195    self.assertEqual(msg, msg2)
2196
  def testModifyMapWhileIterating(self):
    """Mutating a map invalidates iterators created before the mutation.

    Advancing such an iterator must raise RuntimeError, for both a scalar map
    and a message-valued map.
    """
    msg = map_unittest_pb2.TestMap()

    # Both iterators are created while the maps are still empty.
    string_string_iter = iter(msg.map_string_string)
    int32_foreign_iter = iter(msg.map_int32_foreign_message)

    # Each map then gains an entry after its iterator was created.
    msg.map_string_string['abc'] = '123'
    msg.map_int32_foreign_message[5].c = 5

    with self.assertRaises(RuntimeError):
      for key in string_string_iter:
        pass

    with self.assertRaises(RuntimeError):
      for key in int32_foreign_iter:
        pass
2213
2214  def testSubmessageMap(self):
2215    msg = map_unittest_pb2.TestMap()
2216
2217    submsg = msg.map_int32_foreign_message[111]
2218    self.assertIs(submsg, msg.map_int32_foreign_message[111])
2219    self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
2220
2221    submsg.c = 5
2222
2223    serialized = msg.SerializeToString()
2224    msg2 = map_unittest_pb2.TestMap()
2225    msg2.ParseFromString(serialized)
2226
2227    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
2228
2229    # Doesn't allow direct submessage assignment.
2230    with self.assertRaises(ValueError):
2231      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
2232
2233  def testMapIteration(self):
2234    msg = map_unittest_pb2.TestMap()
2235
2236    for k, v in msg.map_int32_int32.items():
2237      # Should not be reached.
2238      self.assertTrue(False)
2239
2240    msg.map_int32_int32[2] = 4
2241    msg.map_int32_int32[3] = 6
2242    msg.map_int32_int32[4] = 8
2243    self.assertEqual(3, len(msg.map_int32_int32))
2244
2245    matching_dict = {2: 4, 3: 6, 4: 8}
2246    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
2247
2248  def testPython2Map(self):
2249    if sys.version_info < (3,):
2250      msg = map_unittest_pb2.TestMap()
2251      msg.map_int32_int32[2] = 4
2252      msg.map_int32_int32[3] = 6
2253      msg.map_int32_int32[4] = 8
2254      msg.map_int32_int32[5] = 10
2255      map_int32 = msg.map_int32_int32
2256      self.assertEqual(4, len(map_int32))
2257      msg2 = map_unittest_pb2.TestMap()
2258      msg2.ParseFromString(msg.SerializeToString())
2259
2260      def CheckItems(seq, iterator):
2261        self.assertEqual(next(iterator), seq[0])
2262        self.assertEqual(list(iterator), seq[1:])
2263
2264      CheckItems(map_int32.items(), map_int32.iteritems())
2265      CheckItems(map_int32.keys(), map_int32.iterkeys())
2266      CheckItems(map_int32.values(), map_int32.itervalues())
2267
2268      self.assertEqual(6, map_int32.get(3))
2269      self.assertEqual(None, map_int32.get(999))
2270      self.assertEqual(6, map_int32.pop(3))
2271      self.assertEqual(0, map_int32.pop(3))
2272      self.assertEqual(3, len(map_int32))
2273      key, value = map_int32.popitem()
2274      self.assertEqual(2 * key, value)
2275      self.assertEqual(2, len(map_int32))
2276      map_int32.clear()
2277      self.assertEqual(0, len(map_int32))
2278
2279      with self.assertRaises(KeyError):
2280        map_int32.popitem()
2281
2282      self.assertEqual(0, map_int32.setdefault(2))
2283      self.assertEqual(1, len(map_int32))
2284
2285      map_int32.update(msg2.map_int32_int32)
2286      self.assertEqual(4, len(map_int32))
2287
2288      with self.assertRaises(TypeError):
2289        map_int32.update(msg2.map_int32_int32,
2290                         msg2.map_int32_int32)
2291      with self.assertRaises(TypeError):
2292        map_int32.update(0)
2293      with self.assertRaises(TypeError):
2294        map_int32.update(value=12)
2295
2296  def testMapItems(self):
2297    # Map items used to have strange behaviors when use c extension. Because
2298    # [] may reorder the map and invalidate any exsting iterators.
2299    # TODO(jieluo): Check if [] reordering the map is a bug or intended
2300    # behavior.
2301    msg = map_unittest_pb2.TestMap()
2302    msg.map_string_string['local_init_op'] = ''
2303    msg.map_string_string['trainable_variables'] = ''
2304    msg.map_string_string['variables'] = ''
2305    msg.map_string_string['init_op'] = ''
2306    msg.map_string_string['summaries'] = ''
2307    items1 = msg.map_string_string.items()
2308    items2 = msg.map_string_string.items()
2309    self.assertEqual(items1, items2)
2310
2311  def testMapDeterministicSerialization(self):
2312    golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
2313                   b'r\n\n\x05item1\x12\x01e'
2314                   b'r\n\n\x05item2\x12\x01f'
2315                   b'r\n\n\x05item3\x12\x01g'
2316                   b'r\x0b\n\x05item4\x12\x02QQ'
2317                   b'r\x12\n\rlocal_init_op\x12\x01a'
2318                   b'r\x0e\n\tsummaries\x12\x01e'
2319                   b'r\x18\n\x13trainable_variables\x12\x01b'
2320                   b'r\x0e\n\tvariables\x12\x01c')
2321    msg = map_unittest_pb2.TestMap()
2322    msg.map_string_string['local_init_op'] = 'a'
2323    msg.map_string_string['trainable_variables'] = 'b'
2324    msg.map_string_string['variables'] = 'c'
2325    msg.map_string_string['init_op'] = 'd'
2326    msg.map_string_string['summaries'] = 'e'
2327    msg.map_string_string['item1'] = 'e'
2328    msg.map_string_string['item2'] = 'f'
2329    msg.map_string_string['item3'] = 'g'
2330    msg.map_string_string['item4'] = 'QQ'
2331
2332    # If deterministic serialization is not working correctly, this will be
2333    # "flaky" depending on the exact python dict hash seed.
2334    #
2335    # Fortunately, there are enough items in this map that it is extremely
2336    # unlikely to ever hit the "right" in-order combination, so the test
2337    # itself should fail reliably.
2338    self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
2339
  def testMapIterationClearMessage(self):
    """An items() result stays usable after its owning message is deleted."""
    # Iterator needs to work even if message and map are deleted.
    msg = map_unittest_pb2.TestMap()

    msg.map_int32_int32[2] = 4
    msg.map_int32_int32[3] = 6
    msg.map_int32_int32[4] = 8

    it = msg.map_int32_int32.items()
    # Drop the only local name bound to the message; `it` must still yield
    # the entries set above.
    del msg

    matching_dict = {2: 4, 3: 6, 4: 8}
    self.assertMapIterEquals(it, matching_dict)
2353
2354  def testMapConstruction(self):
2355    msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
2356    self.assertEqual(2, msg.map_int32_int32[1])
2357    self.assertEqual(4, msg.map_int32_int32[3])
2358
2359    msg = map_unittest_pb2.TestMap(
2360        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
2361    self.assertEqual(5, msg.map_int32_foreign_message[3].c)
2362
2363  def testMapScalarFieldConstruction(self):
2364    msg1 = map_unittest_pb2.TestMap()
2365    msg1.map_int32_int32[1] = 42
2366    msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32)
2367    self.assertEqual(42, msg2.map_int32_int32[1])
2368
2369  def testMapMessageFieldConstruction(self):
2370    msg1 = map_unittest_pb2.TestMap()
2371    msg1.map_string_foreign_message['test'].c = 42
2372    msg2 = map_unittest_pb2.TestMap(
2373      map_string_foreign_message=msg1.map_string_foreign_message)
2374    self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
2375
2376  def testMapFieldRaisesCorrectError(self):
2377    # Should raise a TypeError when given a non-iterable.
2378    with self.assertRaises(TypeError):
2379      map_unittest_pb2.TestMap(map_string_foreign_message=1)
2380
2381  def testMapValidAfterFieldCleared(self):
2382    # Map needs to work even if field is cleared.
2383    # For the C++ implementation this tests the correctness of
2384    # MapContainer::Release()
2385    msg = map_unittest_pb2.TestMap()
2386    int32_map = msg.map_int32_int32
2387
2388    int32_map[2] = 4
2389    int32_map[3] = 6
2390    int32_map[4] = 8
2391
2392    msg.ClearField('map_int32_int32')
2393    self.assertEqual(b'', msg.SerializeToString())
2394    matching_dict = {2: 4, 3: 6, 4: 8}
2395    self.assertMapIterEquals(int32_map.items(), matching_dict)
2396
2397  def testMessageMapValidAfterFieldCleared(self):
2398    # Map needs to work even if field is cleared.
2399    # For the C++ implementation this tests the correctness of
2400    # MapContainer::Release()
2401    msg = map_unittest_pb2.TestMap()
2402    int32_foreign_message = msg.map_int32_foreign_message
2403
2404    int32_foreign_message[2].c = 5
2405
2406    msg.ClearField('map_int32_foreign_message')
2407    self.assertEqual(b'', msg.SerializeToString())
2408    self.assertTrue(2 in int32_foreign_message.keys())
2409
  def testMessageMapItemValidAfterTopMessageCleared(self):
    """A message-map value obtained before Clear() keeps its contents."""
    # Message map item needs to work even if it is cleared.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[2].optional_string = 'bar'

    if api_implementation.Type() == 'cpp':
      # Need to keep the map reference because of b/27942626.
      # TODO(jieluo): Remove it.
      unused_map = msg.map_int32_all_types  # pylint: disable=unused-variable
    # Capture the entry before clearing the whole message.
    msg_value = msg.map_int32_all_types[2]
    msg.Clear()

    # Reset to trigger sync between repeated field and map in c++.
    msg.map_int32_all_types[3].optional_string = 'foo'
    # The entry captured before Clear() must still hold its old value.
    self.assertEqual(msg_value.optional_string, 'bar')
2427
  def testMapIterInvalidatedByClearField(self):
    """Clearing a map field invalidates iterators over that map.

    Advancing such an iterator must raise RuntimeError, for both a scalar map
    and a message-valued map.
    """
    # Map iterator is invalidated when field is cleared.
    # This case must still not crash the interpreter.
    # For the C++ implementation this tests the correctness of
    # ScalarMapContainer::Release()
    msg = map_unittest_pb2.TestMap()

    # Scalar map: the iterator is created before ClearField().
    it = iter(msg.map_int32_int32)

    msg.ClearField('map_int32_int32')
    with self.assertRaises(RuntimeError):
      for _ in it:
        pass

    # Message-valued map: same check.
    it = iter(msg.map_int32_foreign_message)
    msg.ClearField('map_int32_foreign_message')
    with self.assertRaises(RuntimeError):
      for _ in it:
        pass
2447
2448  def testMapDelete(self):
2449    msg = map_unittest_pb2.TestMap()
2450
2451    self.assertEqual(0, len(msg.map_int32_int32))
2452
2453    msg.map_int32_int32[4] = 6
2454    self.assertEqual(1, len(msg.map_int32_int32))
2455
2456    with self.assertRaises(KeyError):
2457      del msg.map_int32_int32[88]
2458
2459    del msg.map_int32_int32[4]
2460    self.assertEqual(0, len(msg.map_int32_int32))
2461
2462    with self.assertRaises(KeyError):
2463      del msg.map_int32_all_types[32]
2464
2465  def testMapsAreMapping(self):
2466    msg = map_unittest_pb2.TestMap()
2467    self.assertIsInstance(msg.map_int32_int32, collections_abc.Mapping)
2468    self.assertIsInstance(msg.map_int32_int32, collections_abc.MutableMapping)
2469    self.assertIsInstance(msg.map_int32_foreign_message, collections_abc.Mapping)
2470    self.assertIsInstance(msg.map_int32_foreign_message,
2471                          collections_abc.MutableMapping)
2472
2473  def testMapsCompare(self):
2474    msg = map_unittest_pb2.TestMap()
2475    msg.map_int32_int32[-123] = -456
2476    self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
2477    self.assertEqual(msg.map_int32_foreign_message,
2478                     msg.map_int32_foreign_message)
2479    self.assertNotEqual(msg.map_int32_int32, 0)
2480
2481  def testMapFindInitializationErrorsSmokeTest(self):
2482    msg = map_unittest_pb2.TestMap()
2483    msg.map_string_string['abc'] = '123'
2484    msg.map_int32_int32[35] = 64
2485    msg.map_string_foreign_message['foo'].c = 5
2486    self.assertEqual(0, len(msg.FindInitializationErrors()))
2487
2488  @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
2489  def testStrictUtf8Check(self):
2490    # Test u'\ud801' is rejected at parser in both python2 and python3.
2491    serialized = (b'r\x03\xed\xa0\x81')
2492    msg = unittest_proto3_arena_pb2.TestAllTypes()
2493    with self.assertRaises(Exception) as context:
2494      msg.MergeFromString(serialized)
2495    if api_implementation.Type() == 'python':
2496      self.assertIn('optional_string', str(context.exception))
2497    else:
2498      self.assertIn('Error parsing message', str(context.exception))
2499
2500    # Test optional_string=u'��' is accepted.
2501    serialized = unittest_proto3_arena_pb2.TestAllTypes(
2502        optional_string=u'��').SerializeToString()
2503    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2504    msg2.MergeFromString(serialized)
2505    self.assertEqual(msg2.optional_string, u'��')
2506
2507    msg = unittest_proto3_arena_pb2.TestAllTypes(
2508        optional_string=u'\ud001')
2509    self.assertEqual(msg.optional_string, u'\ud001')
2510
2511  @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
2512  def testSurrogatesInPython3(self):
2513    # Surrogates like U+D83D is an invalid unicode character, it is
2514    # supported by Python2 only because in some builds, unicode strings
2515    # use 2-bytes code units. Since Python 3.3, we don't have this problem.
2516    #
2517    # Surrogates are utf16 code units, in a unicode string they are invalid
2518    # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
2519    # Python3 reject such cases at setters and parsers. Python2 accpect it
2520    # to keep same features with the language itself. 'Unpaired pairs'
2521    # like u'\ud801' are rejected at parsers when strict utf8 check is enabled
2522    # in proto3 to keep same behavior with c extension.
2523
2524    # Surrogates are rejected at setters in Python3.
2525    with self.assertRaises(ValueError):
2526      unittest_proto3_arena_pb2.TestAllTypes(
2527          optional_string=u'\ud801\udc01')
2528    with self.assertRaises(ValueError):
2529      unittest_proto3_arena_pb2.TestAllTypes(
2530          optional_string=b'\xed\xa0\x81')
2531    with self.assertRaises(ValueError):
2532      unittest_proto3_arena_pb2.TestAllTypes(
2533          optional_string=u'\ud801')
2534    with self.assertRaises(ValueError):
2535      unittest_proto3_arena_pb2.TestAllTypes(
2536          optional_string=u'\ud801\ud801')
2537
2538  @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE,
2539                   'Surrogates are rejected at setters in Python3')
2540  def testSurrogatesInPython2(self):
2541    # Test optional_string=u'\ud801\udc01'.
2542    # surrogate pair is acceptable in python2.
2543    msg = unittest_proto3_arena_pb2.TestAllTypes(
2544        optional_string=u'\ud801\udc01')
2545    # TODO(jieluo): Change pure python to have same behavior with c extension.
2546    # Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are
2547    # equal, some are not equal.
2548    if api_implementation.Type() == 'python':
2549      self.assertEqual(msg.optional_string, u'\ud801\udc01')
2550    else:
2551      self.assertEqual(msg.optional_string, u'\U00010401')
2552    serialized = msg.SerializeToString()
2553    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2554    msg2.MergeFromString(serialized)
2555    self.assertEqual(msg2.optional_string, u'\U00010401')
2556
2557    # Python2 does not reject surrogates at setters.
2558    msg = unittest_proto3_arena_pb2.TestAllTypes(
2559        optional_string=b'\xed\xa0\x81')
2560    unittest_proto3_arena_pb2.TestAllTypes(
2561        optional_string=u'\ud801')
2562    unittest_proto3_arena_pb2.TestAllTypes(
2563        optional_string=u'\ud801\ud801')
2564
2565
@testing_refleaks.TestCase
class ValidTypeNamesTest(unittest.TestCase):
  """Repeated-field container types must be importable by their type name."""

  def assertImportFromName(self, msg, base_name):
    """Checks a repeated container's type name and imports its module.

    Args:
      msg: a repeated field container, e.g. ``pb.repeated_int32``.
      base_name: ``'Scalar'`` or ``'Composite'``. The type name must end with
        ``Repeated<base_name>Container`` or
        ``Repeated<base_name>FieldContainer``.
    """
    # Parse "<type 'some.module.ClassName'>" (or "<class '...'>") to extract
    # the dotted name 'some.module.ClassName' as a string.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
2587
@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):
  """Golden wire-format checks for packed versus unpacked repeated fields."""

  def setMessage(self, message):
    """Appends exactly one value to each repeated field of `message`."""
    integer_fields = ('repeated_int32', 'repeated_int64',
                      'repeated_uint32', 'repeated_uint64',
                      'repeated_sint32', 'repeated_sint64',
                      'repeated_fixed32', 'repeated_fixed64',
                      'repeated_sfixed32', 'repeated_sfixed64')
    for field_name in integer_fields:
      getattr(message, field_name).append(1)
    for field_name in ('repeated_float', 'repeated_double'):
      getattr(message, field_name).append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    """Each packed field is a single length-delimited (wire type 2) record."""
    msg = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(msg)
    # sint32/sint64 encode 1 as 0x02 (ZigZag); fixed-width types use
    # little-endian payloads of 4 or 8 bytes.
    golden_data = (b'\x0A\x01\x01'
                   b'\x12\x01\x01'
                   b'\x1A\x01\x01'
                   b'\x22\x01\x01'
                   b'\x2A\x01\x02'
                   b'\x32\x01\x02'
                   b'\x3A\x04\x01\x00\x00\x00'
                   b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x4A\x04\x01\x00\x00\x00'
                   b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x5A\x04\x00\x00\x80\x3f'
                   b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                   b'\x6A\x01\x01'
                   b'\x72\x01\x01')
    self.assertEqual(golden_data, msg.SerializeToString())

  def testUnpackedFields(self):
    """Each unpacked element carries its own tag with its native wire type."""
    msg = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(msg)
    golden_data = (b'\x08\x01'
                   b'\x10\x01'
                   b'\x18\x01'
                   b'\x20\x01'
                   b'\x28\x02'
                   b'\x30\x02'
                   b'\x3D\x01\x00\x00\x00'
                   b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x4D\x01\x00\x00\x00'
                   b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x5D\x00\x00\x80\x3f'
                   b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                   b'\x68\x01'
                   b'\x70\x01')
    self.assertEqual(golden_data, msg.SerializeToString())
2644
2645
@unittest.skipIf(api_implementation.Type() != 'cpp' or
                 sys.version_info < (2, 7),
                 'explicit tests of the C++ implementation for PY27 and above')
@testing_refleaks.TestCase
class OversizeProtosTest(unittest.TestCase):
  """Parsing a message whose serialized form exceeds 64 MiB (C++ only)."""

  @classmethod
  def setUpClass(cls):
    # At the moment, reference cycles between DescriptorPool and Message classes
    # are not detected and these objects are never freed.
    # To avoid errors with ReferenceLeakChecker, we create the class only once.
    file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    pool = descriptor_pool.DescriptorPool()
    desc = descriptor_pb2.FileDescriptorProto()
    text_format.Parse(file_desc, desc)
    pool.Add(desc)
    cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
        pool.FindMessageTypeByName('f.msg2'))

  def setUp(self):
    # The payload is one byte over 64 MiB, so the serialized message is too.
    self.p = self.proto_cls()
    self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    """With oversize protos disallowed, parsing must raise DecodeError."""
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(False)
    q = self.proto_cls()
    # assertRaises makes the test fail if parsing unexpectedly succeeds; a
    # try/except around the call would let that case pass silently.
    with self.assertRaises(message.DecodeError) as context:
      q.ParseFromString(self.p_serialized)
    self.assertEqual(str(context.exception), 'Error parsing message')

  def testSucceedOversizeProto(self):
    """With oversize protos allowed, the large payload round-trips intact."""
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(True)
    # NOTE(review): this setting appears to be module-global and stays enabled
    # after the test; confirm no later test relies on the default.
    q = self.proto_cls()
    q.ParseFromString(self.p_serialized)
    self.assertEqual(self.p.field.payload, q.field.payload)
2707
if __name__ == '__main__':
  # Run this module's test cases when executed directly as a script.
  unittest.main()
2710