• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Tests python protocol buffers against the golden message.
35
36Note that the golden messages exercise every known field type, thus this
37test ends up exercising and verifying nearly all of the parsing and
38serialization code in the whole library.
39
40TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
41sense to call this a test of the "message" module, which only declares an
42abstract interface.
43"""
44
45__author__ = 'gps@google.com (Gregory P. Smith)'
46
47
48import copy
49import math
50import operator
51import pickle
52import pydoc
53import six
54import sys
55import warnings
56
57try:
58  # Since python 3
59  import collections.abc as collections_abc
60except ImportError:
  # Python 2 fallback; the ABC aliases in `collections` were removed in 3.10.
62  import collections as collections_abc
63
64try:
65  import unittest2 as unittest  # PY26
66except ImportError:
67  import unittest
try:
  cmp  # Builtin on Python 2.
except NameError:
  def cmp(x, y):
    """Python 3 stand-in for the removed cmp() builtin: -1, 0 or 1."""
    return (x > y) - (x < y)
72
73from google.protobuf import map_proto2_unittest_pb2
74from google.protobuf import map_unittest_pb2
75from google.protobuf import unittest_pb2
76from google.protobuf import unittest_proto3_arena_pb2
77from google.protobuf import descriptor_pb2
78from google.protobuf import descriptor_pool
79from google.protobuf import message_factory
80from google.protobuf import text_format
81from google.protobuf.internal import api_implementation
82from google.protobuf.internal import encoder
83from google.protobuf.internal import more_extensions_pb2
84from google.protobuf.internal import packed_field_test_pb2
85from google.protobuf.internal import test_util
86from google.protobuf.internal import test_proto3_optional_pb2
87from google.protobuf.internal import testing_refleaks
88from google.protobuf import message
89from google.protobuf.internal import _parameterized
90
# sys.maxunicode on a narrow (UCS-2) Python build.
UCS2_MAXUNICODE = 0xFFFF
if six.PY3:
  # Python 3 merged `long` into `int`; alias it for Python 2-era code.
  long = int
94
95
96# Python pre-2.6 does not have isinf() or isnan() functions, so we have
97# to provide our own.
98def isnan(val):
99  # NaN is never equal to itself.
100  return val != val
101def isinf(val):
102  # Infinity times zero equals NaN.
103  return not isnan(val) and isnan(val * 0)
104def IsPosInf(val):
105  return isinf(val) and (val > 0)
106def IsNegInf(val):
107  return isinf(val) and (val < 0)
108
109
110warnings.simplefilter('error', DeprecationWarning)
111
112
113@_parameterized.named_parameters(
114    ('_proto2', unittest_pb2),
115    ('_proto3', unittest_proto3_arena_pb2))
116@testing_refleaks.TestCase
117class MessageTest(unittest.TestCase):
118
119  def testBadUtf8String(self, message_module):
120    if api_implementation.Type() != 'python':
121      self.skipTest("Skipping testBadUtf8String, currently only the python "
122                    "api implementation raises UnicodeDecodeError when a "
123                    "string field contains bad utf-8.")
124    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
125    with self.assertRaises(UnicodeDecodeError) as context:
126      message_module.TestAllTypes.FromString(bad_utf8_data)
127    self.assertIn('TestAllTypes.optional_string', str(context.exception))
128
129  def testGoldenMessage(self, message_module):
130    # Proto3 doesn't have the "default_foo" members or foreign enums,
131    # and doesn't preserve unknown fields, so for proto3 we use a golden
132    # message that doesn't have these fields set.
133    if message_module is unittest_pb2:
134      golden_data = test_util.GoldenFileData(
135          'golden_message_oneof_implemented')
136    else:
137      golden_data = test_util.GoldenFileData('golden_message_proto3')
138
139    golden_message = message_module.TestAllTypes()
140    golden_message.ParseFromString(golden_data)
141    if message_module is unittest_pb2:
142      test_util.ExpectAllFieldsSet(self, golden_message)
143    self.assertEqual(golden_data, golden_message.SerializeToString())
144    golden_copy = copy.deepcopy(golden_message)
145    self.assertEqual(golden_data, golden_copy.SerializeToString())
146
147  def testGoldenPackedMessage(self, message_module):
148    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
149    golden_message = message_module.TestPackedTypes()
150    parsed_bytes = golden_message.ParseFromString(golden_data)
151    all_set = message_module.TestPackedTypes()
152    test_util.SetAllPackedFields(all_set)
153    self.assertEqual(parsed_bytes, len(golden_data))
154    self.assertEqual(all_set, golden_message)
155    self.assertEqual(golden_data, all_set.SerializeToString())
156    golden_copy = copy.deepcopy(golden_message)
157    self.assertEqual(golden_data, golden_copy.SerializeToString())
158
159  def testParseErrors(self, message_module):
160    msg = message_module.TestAllTypes()
161    self.assertRaises(TypeError, msg.FromString, 0)
162    self.assertRaises(Exception, msg.FromString, '0')
163    # TODO(jieluo): Fix cpp extension to raise error instead of warning.
164    # b/27494216
165    end_tag = encoder.TagBytes(1, 4)
166    if api_implementation.Type() == 'python':
167      with self.assertRaises(message.DecodeError) as context:
168        msg.FromString(end_tag)
169      self.assertEqual('Unexpected end-group tag.', str(context.exception))
170
171    # Field number 0 is illegal.
172    self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')
173
174  def testDeterminismParameters(self, message_module):
175    # This message is always deterministically serialized, even if determinism
176    # is disabled, so we can use it to verify that all the determinism
177    # parameters work correctly.
178    golden_data = (b'\xe2\x02\nOne string'
179                   b'\xe2\x02\nTwo string'
180                   b'\xe2\x02\nRed string'
181                   b'\xe2\x02\x0bBlue string')
182    golden_message = message_module.TestAllTypes()
183    golden_message.repeated_string.extend([
184        'One string',
185        'Two string',
186        'Red string',
187        'Blue string',
188    ])
189    self.assertEqual(golden_data,
190                     golden_message.SerializeToString(deterministic=None))
191    self.assertEqual(golden_data,
192                     golden_message.SerializeToString(deterministic=False))
193    self.assertEqual(golden_data,
194                     golden_message.SerializeToString(deterministic=True))
195
196    class BadArgError(Exception):
197      pass
198
199    class BadArg(object):
200
201      def __nonzero__(self):
202        raise BadArgError()
203
204      def __bool__(self):
205        raise BadArgError()
206
207    with self.assertRaises(BadArgError):
208      golden_message.SerializeToString(deterministic=BadArg())
209
210  def testPickleSupport(self, message_module):
211    golden_data = test_util.GoldenFileData('golden_message')
212    golden_message = message_module.TestAllTypes()
213    golden_message.ParseFromString(golden_data)
214    pickled_message = pickle.dumps(golden_message)
215
216    unpickled_message = pickle.loads(pickled_message)
217    self.assertEqual(unpickled_message, golden_message)
218
219  def testPickleNestedMessage(self, message_module):
220    golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1)
221    pickled_message = pickle.dumps(golden_message)
222    unpickled_message = pickle.loads(pickled_message)
223    self.assertEqual(unpickled_message, golden_message)
224
225  def testPickleNestedNestedMessage(self, message_module):
226    cls = message_module.TestPickleNestedMessage.NestedMessage
227    golden_message = cls.NestedNestedMessage(cc=1)
228    pickled_message = pickle.dumps(golden_message)
229    unpickled_message = pickle.loads(pickled_message)
230    self.assertEqual(unpickled_message, golden_message)
231
232  def testPositiveInfinity(self, message_module):
233    if message_module is unittest_pb2:
234      golden_data = (b'\x5D\x00\x00\x80\x7F'
235                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
236                     b'\xCD\x02\x00\x00\x80\x7F'
237                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
238    else:
239      golden_data = (b'\x5D\x00\x00\x80\x7F'
240                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
241                     b'\xCA\x02\x04\x00\x00\x80\x7F'
242                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
243
244    golden_message = message_module.TestAllTypes()
245    golden_message.ParseFromString(golden_data)
246    self.assertTrue(IsPosInf(golden_message.optional_float))
247    self.assertTrue(IsPosInf(golden_message.optional_double))
248    self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
249    self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
250    self.assertEqual(golden_data, golden_message.SerializeToString())
251
252  def testNegativeInfinity(self, message_module):
253    if message_module is unittest_pb2:
254      golden_data = (b'\x5D\x00\x00\x80\xFF'
255                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
256                     b'\xCD\x02\x00\x00\x80\xFF'
257                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
258    else:
259      golden_data = (b'\x5D\x00\x00\x80\xFF'
260                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
261                     b'\xCA\x02\x04\x00\x00\x80\xFF'
262                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
263
264    golden_message = message_module.TestAllTypes()
265    golden_message.ParseFromString(golden_data)
266    self.assertTrue(IsNegInf(golden_message.optional_float))
267    self.assertTrue(IsNegInf(golden_message.optional_double))
268    self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
269    self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
270    self.assertEqual(golden_data, golden_message.SerializeToString())
271
272  def testNotANumber(self, message_module):
273    golden_data = (b'\x5D\x00\x00\xC0\x7F'
274                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
275                   b'\xCD\x02\x00\x00\xC0\x7F'
276                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
277    golden_message = message_module.TestAllTypes()
278    golden_message.ParseFromString(golden_data)
279    self.assertTrue(isnan(golden_message.optional_float))
280    self.assertTrue(isnan(golden_message.optional_double))
281    self.assertTrue(isnan(golden_message.repeated_float[0]))
282    self.assertTrue(isnan(golden_message.repeated_double[0]))
283
284    # The protocol buffer may serialize to any one of multiple different
285    # representations of a NaN.  Rather than verify a specific representation,
286    # verify the serialized string can be converted into a correctly
287    # behaving protocol buffer.
288    serialized = golden_message.SerializeToString()
289    message = message_module.TestAllTypes()
290    message.ParseFromString(serialized)
291    self.assertTrue(isnan(message.optional_float))
292    self.assertTrue(isnan(message.optional_double))
293    self.assertTrue(isnan(message.repeated_float[0]))
294    self.assertTrue(isnan(message.repeated_double[0]))
295
296  def testPositiveInfinityPacked(self, message_module):
297    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
298                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
299    golden_message = message_module.TestPackedTypes()
300    golden_message.ParseFromString(golden_data)
301    self.assertTrue(IsPosInf(golden_message.packed_float[0]))
302    self.assertTrue(IsPosInf(golden_message.packed_double[0]))
303    self.assertEqual(golden_data, golden_message.SerializeToString())
304
305  def testNegativeInfinityPacked(self, message_module):
306    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
307                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
308    golden_message = message_module.TestPackedTypes()
309    golden_message.ParseFromString(golden_data)
310    self.assertTrue(IsNegInf(golden_message.packed_float[0]))
311    self.assertTrue(IsNegInf(golden_message.packed_double[0]))
312    self.assertEqual(golden_data, golden_message.SerializeToString())
313
314  def testNotANumberPacked(self, message_module):
315    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
316                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
317    golden_message = message_module.TestPackedTypes()
318    golden_message.ParseFromString(golden_data)
319    self.assertTrue(isnan(golden_message.packed_float[0]))
320    self.assertTrue(isnan(golden_message.packed_double[0]))
321
322    serialized = golden_message.SerializeToString()
323    message = message_module.TestPackedTypes()
324    message.ParseFromString(serialized)
325    self.assertTrue(isnan(message.packed_float[0]))
326    self.assertTrue(isnan(message.packed_double[0]))
327
328  def testExtremeFloatValues(self, message_module):
329    message = message_module.TestAllTypes()
330
331    # Most positive exponent, no significand bits set.
332    kMostPosExponentNoSigBits = math.pow(2, 127)
333    message.optional_float = kMostPosExponentNoSigBits
334    message.ParseFromString(message.SerializeToString())
335    self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
336
337    # Most positive exponent, one significand bit set.
338    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
339    message.optional_float = kMostPosExponentOneSigBit
340    message.ParseFromString(message.SerializeToString())
341    self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
342
343    # Repeat last two cases with values of same magnitude, but negative.
344    message.optional_float = -kMostPosExponentNoSigBits
345    message.ParseFromString(message.SerializeToString())
346    self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
347
348    message.optional_float = -kMostPosExponentOneSigBit
349    message.ParseFromString(message.SerializeToString())
350    self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
351
352    # Most negative exponent, no significand bits set.
353    kMostNegExponentNoSigBits = math.pow(2, -127)
354    message.optional_float = kMostNegExponentNoSigBits
355    message.ParseFromString(message.SerializeToString())
356    self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
357
358    # Most negative exponent, one significand bit set.
359    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
360    message.optional_float = kMostNegExponentOneSigBit
361    message.ParseFromString(message.SerializeToString())
362    self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
363
364    # Repeat last two cases with values of the same magnitude, but negative.
365    message.optional_float = -kMostNegExponentNoSigBits
366    message.ParseFromString(message.SerializeToString())
367    self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
368
369    message.optional_float = -kMostNegExponentOneSigBit
370    message.ParseFromString(message.SerializeToString())
371    self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
372
373    # Max 4 bytes float value
374    max_float = float.fromhex('0x1.fffffep+127')
375    message.optional_float = max_float
376    self.assertAlmostEqual(message.optional_float, max_float)
377    serialized_data = message.SerializeToString()
378    message.ParseFromString(serialized_data)
379    self.assertAlmostEqual(message.optional_float, max_float)
380
381    # Test set double to float field.
382    message.optional_float = 3.4028235e+39
383    self.assertEqual(message.optional_float, float('inf'))
384    serialized_data = message.SerializeToString()
385    message.ParseFromString(serialized_data)
386    self.assertEqual(message.optional_float, float('inf'))
387
388    message.optional_float = -3.4028235e+39
389    self.assertEqual(message.optional_float, float('-inf'))
390
391    message.optional_float = 1.4028235e-39
392    self.assertAlmostEqual(message.optional_float, 1.4028235e-39)
393
394  def testExtremeDoubleValues(self, message_module):
395    message = message_module.TestAllTypes()
396
397    # Most positive exponent, no significand bits set.
398    kMostPosExponentNoSigBits = math.pow(2, 1023)
399    message.optional_double = kMostPosExponentNoSigBits
400    message.ParseFromString(message.SerializeToString())
401    self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
402
403    # Most positive exponent, one significand bit set.
404    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
405    message.optional_double = kMostPosExponentOneSigBit
406    message.ParseFromString(message.SerializeToString())
407    self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
408
409    # Repeat last two cases with values of same magnitude, but negative.
410    message.optional_double = -kMostPosExponentNoSigBits
411    message.ParseFromString(message.SerializeToString())
412    self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
413
414    message.optional_double = -kMostPosExponentOneSigBit
415    message.ParseFromString(message.SerializeToString())
416    self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
417
418    # Most negative exponent, no significand bits set.
419    kMostNegExponentNoSigBits = math.pow(2, -1023)
420    message.optional_double = kMostNegExponentNoSigBits
421    message.ParseFromString(message.SerializeToString())
422    self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
423
424    # Most negative exponent, one significand bit set.
425    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
426    message.optional_double = kMostNegExponentOneSigBit
427    message.ParseFromString(message.SerializeToString())
428    self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
429
430    # Repeat last two cases with values of the same magnitude, but negative.
431    message.optional_double = -kMostNegExponentNoSigBits
432    message.ParseFromString(message.SerializeToString())
433    self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
434
435    message.optional_double = -kMostNegExponentOneSigBit
436    message.ParseFromString(message.SerializeToString())
437    self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
438
439  def testFloatPrinting(self, message_module):
440    message = message_module.TestAllTypes()
441    message.optional_float = 2.0
442    self.assertEqual(str(message), 'optional_float: 2.0\n')
443
444  def testHighPrecisionFloatPrinting(self, message_module):
445    msg = message_module.TestAllTypes()
446    msg.optional_float = 0.12345678912345678
447    old_float = msg.optional_float
448    msg.ParseFromString(msg.SerializeToString())
449    self.assertEqual(old_float, msg.optional_float)
450
451  def testHighPrecisionDoublePrinting(self, message_module):
452    msg = message_module.TestAllTypes()
453    msg.optional_double = 0.12345678912345678
454    if sys.version_info >= (3,):
455      self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')
456    else:
457      self.assertEqual(str(msg), 'optional_double: 0.123456789123\n')
458
459  def testUnknownFieldPrinting(self, message_module):
460    populated = message_module.TestAllTypes()
461    test_util.SetAllNonLazyFields(populated)
462    empty = message_module.TestEmptyMessage()
463    empty.ParseFromString(populated.SerializeToString())
464    self.assertEqual(str(empty), '')
465
466  def testAppendRepeatedCompositeField(self, message_module):
467    msg = message_module.TestAllTypes()
468    msg.repeated_nested_message.append(
469        message_module.TestAllTypes.NestedMessage(bb=1))
470    nested = message_module.TestAllTypes.NestedMessage(bb=2)
471    msg.repeated_nested_message.append(nested)
472    try:
473      msg.repeated_nested_message.append(1)
474    except TypeError:
475      pass
476    self.assertEqual(2, len(msg.repeated_nested_message))
477    self.assertEqual([1, 2],
478                     [m.bb for m in msg.repeated_nested_message])
479
480  def testInsertRepeatedCompositeField(self, message_module):
481    msg = message_module.TestAllTypes()
482    msg.repeated_nested_message.insert(
483        -1, message_module.TestAllTypes.NestedMessage(bb=1))
484    sub_msg = msg.repeated_nested_message[0]
485    msg.repeated_nested_message.insert(
486        0, message_module.TestAllTypes.NestedMessage(bb=2))
487    msg.repeated_nested_message.insert(
488        99, message_module.TestAllTypes.NestedMessage(bb=3))
489    msg.repeated_nested_message.insert(
490        -2, message_module.TestAllTypes.NestedMessage(bb=-1))
491    msg.repeated_nested_message.insert(
492        -1000, message_module.TestAllTypes.NestedMessage(bb=-1000))
493    try:
494      msg.repeated_nested_message.insert(1, 999)
495    except TypeError:
496      pass
497    self.assertEqual(5, len(msg.repeated_nested_message))
498    self.assertEqual([-1000, 2, -1, 1, 3],
499                     [m.bb for m in msg.repeated_nested_message])
500    self.assertEqual(str(msg),
501                     'repeated_nested_message {\n'
502                     '  bb: -1000\n'
503                     '}\n'
504                     'repeated_nested_message {\n'
505                     '  bb: 2\n'
506                     '}\n'
507                     'repeated_nested_message {\n'
508                     '  bb: -1\n'
509                     '}\n'
510                     'repeated_nested_message {\n'
511                     '  bb: 1\n'
512                     '}\n'
513                     'repeated_nested_message {\n'
514                     '  bb: 3\n'
515                     '}\n')
516    self.assertEqual(sub_msg.bb, 1)
517
518  def testMergeFromRepeatedField(self, message_module):
519    msg = message_module.TestAllTypes()
520    msg.repeated_int32.append(1)
521    msg.repeated_int32.append(3)
522    msg.repeated_nested_message.add(bb=1)
523    msg.repeated_nested_message.add(bb=2)
524    other_msg = message_module.TestAllTypes()
525    other_msg.repeated_nested_message.add(bb=3)
526    other_msg.repeated_nested_message.add(bb=4)
527    other_msg.repeated_int32.append(5)
528    other_msg.repeated_int32.append(7)
529
530    msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
531    self.assertEqual(4, len(msg.repeated_int32))
532
533    msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
534    self.assertEqual([1, 2, 3, 4],
535                     [m.bb for m in msg.repeated_nested_message])
536
537  def testAddWrongRepeatedNestedField(self, message_module):
538    msg = message_module.TestAllTypes()
539    try:
540      msg.repeated_nested_message.add('wrong')
541    except TypeError:
542      pass
543    try:
544      msg.repeated_nested_message.add(value_field='wrong')
545    except ValueError:
546      pass
547    self.assertEqual(len(msg.repeated_nested_message), 0)
548
549  def testRepeatedContains(self, message_module):
550    msg = message_module.TestAllTypes()
551    msg.repeated_int32.extend([1, 2, 3])
552    self.assertIn(2, msg.repeated_int32)
553    self.assertNotIn(0, msg.repeated_int32)
554
555    msg.repeated_nested_message.add(bb=1)
556    sub_msg1 = msg.repeated_nested_message[0]
557    sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2)
558    sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3)
559    msg.repeated_nested_message.append(sub_msg2)
560    msg.repeated_nested_message.insert(0, sub_msg3)
561    self.assertIn(sub_msg1, msg.repeated_nested_message)
562    self.assertIn(sub_msg2, msg.repeated_nested_message)
563    self.assertIn(sub_msg3, msg.repeated_nested_message)
564
565  def testRepeatedScalarIterable(self, message_module):
566    msg = message_module.TestAllTypes()
567    msg.repeated_int32.extend([1, 2, 3])
568    add = 0
569    for item in msg.repeated_int32:
570      add += item
571    self.assertEqual(add, 6)
572
573  def testRepeatedNestedFieldIteration(self, message_module):
574    msg = message_module.TestAllTypes()
575    msg.repeated_nested_message.add(bb=1)
576    msg.repeated_nested_message.add(bb=2)
577    msg.repeated_nested_message.add(bb=3)
578    msg.repeated_nested_message.add(bb=4)
579
580    self.assertEqual([1, 2, 3, 4],
581                     [m.bb for m in msg.repeated_nested_message])
582    self.assertEqual([4, 3, 2, 1],
583                     [m.bb for m in reversed(msg.repeated_nested_message)])
584    self.assertEqual([4, 3, 2, 1],
585                     [m.bb for m in msg.repeated_nested_message[::-1]])
586
587  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
588    """Check some different types with the default comparator."""
589    message = message_module.TestAllTypes()
590
591    # TODO(mattp): would testing more scalar types strengthen test?
592    message.repeated_int32.append(1)
593    message.repeated_int32.append(3)
594    message.repeated_int32.append(2)
595    message.repeated_int32.sort()
596    self.assertEqual(message.repeated_int32[0], 1)
597    self.assertEqual(message.repeated_int32[1], 2)
598    self.assertEqual(message.repeated_int32[2], 3)
599    self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
600
601    message.repeated_float.append(1.1)
602    message.repeated_float.append(1.3)
603    message.repeated_float.append(1.2)
604    message.repeated_float.sort()
605    self.assertAlmostEqual(message.repeated_float[0], 1.1)
606    self.assertAlmostEqual(message.repeated_float[1], 1.2)
607    self.assertAlmostEqual(message.repeated_float[2], 1.3)
608
609    message.repeated_string.append('a')
610    message.repeated_string.append('c')
611    message.repeated_string.append('b')
612    message.repeated_string.sort()
613    self.assertEqual(message.repeated_string[0], 'a')
614    self.assertEqual(message.repeated_string[1], 'b')
615    self.assertEqual(message.repeated_string[2], 'c')
616    self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
617
618    message.repeated_bytes.append(b'a')
619    message.repeated_bytes.append(b'c')
620    message.repeated_bytes.append(b'b')
621    message.repeated_bytes.sort()
622    self.assertEqual(message.repeated_bytes[0], b'a')
623    self.assertEqual(message.repeated_bytes[1], b'b')
624    self.assertEqual(message.repeated_bytes[2], b'c')
625    self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
626
627  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
628    """Check some different types with custom comparator."""
629    message = message_module.TestAllTypes()
630
631    message.repeated_int32.append(-3)
632    message.repeated_int32.append(-2)
633    message.repeated_int32.append(-1)
634    message.repeated_int32.sort(key=abs)
635    self.assertEqual(message.repeated_int32[0], -1)
636    self.assertEqual(message.repeated_int32[1], -2)
637    self.assertEqual(message.repeated_int32[2], -3)
638
639    message.repeated_string.append('aaa')
640    message.repeated_string.append('bb')
641    message.repeated_string.append('c')
642    message.repeated_string.sort(key=len)
643    self.assertEqual(message.repeated_string[0], 'c')
644    self.assertEqual(message.repeated_string[1], 'bb')
645    self.assertEqual(message.repeated_string[2], 'aaa')
646
647  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
648    """Check passing a custom comparator to sort a repeated composite field."""
649    message = message_module.TestAllTypes()
650
651    message.repeated_nested_message.add().bb = 1
652    message.repeated_nested_message.add().bb = 3
653    message.repeated_nested_message.add().bb = 2
654    message.repeated_nested_message.add().bb = 6
655    message.repeated_nested_message.add().bb = 5
656    message.repeated_nested_message.add().bb = 4
657    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
658    self.assertEqual(message.repeated_nested_message[0].bb, 1)
659    self.assertEqual(message.repeated_nested_message[1].bb, 2)
660    self.assertEqual(message.repeated_nested_message[2].bb, 3)
661    self.assertEqual(message.repeated_nested_message[3].bb, 4)
662    self.assertEqual(message.repeated_nested_message[4].bb, 5)
663    self.assertEqual(message.repeated_nested_message[5].bb, 6)
664    self.assertEqual(str(message.repeated_nested_message),
665                     '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
666
667  def testSortingRepeatedCompositeFieldsStable(self, message_module):
668    """Check passing a custom comparator to sort a repeated composite field."""
669    message = message_module.TestAllTypes()
670
671    message.repeated_nested_message.add().bb = 21
672    message.repeated_nested_message.add().bb = 20
673    message.repeated_nested_message.add().bb = 13
674    message.repeated_nested_message.add().bb = 33
675    message.repeated_nested_message.add().bb = 11
676    message.repeated_nested_message.add().bb = 24
677    message.repeated_nested_message.add().bb = 10
678    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
679    self.assertEqual(
680        [13, 11, 10, 21, 20, 24, 33],
681        [n.bb for n in message.repeated_nested_message])
682
683    # Make sure that for the C++ implementation, the underlying fields
684    # are actually reordered.
685    pb = message.SerializeToString()
686    message.Clear()
687    message.MergeFromString(pb)
688    self.assertEqual(
689        [13, 11, 10, 21, 20, 24, 33],
690        [n.bb for n in message.repeated_nested_message])
691
692  def testRepeatedCompositeFieldSortArguments(self, message_module):
693    """Check sorting a repeated composite field using list.sort() arguments."""
694    message = message_module.TestAllTypes()
695
696    get_bb = operator.attrgetter('bb')
697    cmp_bb = lambda a, b: cmp(a.bb, b.bb)
698    message.repeated_nested_message.add().bb = 1
699    message.repeated_nested_message.add().bb = 3
700    message.repeated_nested_message.add().bb = 2
701    message.repeated_nested_message.add().bb = 6
702    message.repeated_nested_message.add().bb = 5
703    message.repeated_nested_message.add().bb = 4
704    message.repeated_nested_message.sort(key=get_bb)
705    self.assertEqual([k.bb for k in message.repeated_nested_message],
706                     [1, 2, 3, 4, 5, 6])
707    message.repeated_nested_message.sort(key=get_bb, reverse=True)
708    self.assertEqual([k.bb for k in message.repeated_nested_message],
709                     [6, 5, 4, 3, 2, 1])
710    if sys.version_info >= (3,): return  # No cmp sorting in PY3.
711    message.repeated_nested_message.sort(sort_function=cmp_bb)
712    self.assertEqual([k.bb for k in message.repeated_nested_message],
713                     [1, 2, 3, 4, 5, 6])
714    message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
715    self.assertEqual([k.bb for k in message.repeated_nested_message],
716                     [6, 5, 4, 3, 2, 1])
717
718  def testRepeatedScalarFieldSortArguments(self, message_module):
719    """Check sorting a scalar field using list.sort() arguments."""
720    message = message_module.TestAllTypes()
721
722    message.repeated_int32.append(-3)
723    message.repeated_int32.append(-2)
724    message.repeated_int32.append(-1)
725    message.repeated_int32.sort(key=abs)
726    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
727    message.repeated_int32.sort(key=abs, reverse=True)
728    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
729    if sys.version_info < (3,):  # No cmp sorting in PY3.
730      abs_cmp = lambda a, b: cmp(abs(a), abs(b))
731      message.repeated_int32.sort(sort_function=abs_cmp)
732      self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
733      message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
734      self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
735
736    message.repeated_string.append('aaa')
737    message.repeated_string.append('bb')
738    message.repeated_string.append('c')
739    message.repeated_string.sort(key=len)
740    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
741    message.repeated_string.sort(key=len, reverse=True)
742    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
743    if sys.version_info < (3,):  # No cmp sorting in PY3.
744      len_cmp = lambda a, b: cmp(len(a), len(b))
745      message.repeated_string.sort(sort_function=len_cmp)
746      self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
747      message.repeated_string.sort(cmp=len_cmp, reverse=True)
748      self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
749
750  def testRepeatedFieldsComparable(self, message_module):
751    m1 = message_module.TestAllTypes()
752    m2 = message_module.TestAllTypes()
753    m1.repeated_int32.append(0)
754    m1.repeated_int32.append(1)
755    m1.repeated_int32.append(2)
756    m2.repeated_int32.append(0)
757    m2.repeated_int32.append(1)
758    m2.repeated_int32.append(2)
759    m1.repeated_nested_message.add().bb = 1
760    m1.repeated_nested_message.add().bb = 2
761    m1.repeated_nested_message.add().bb = 3
762    m2.repeated_nested_message.add().bb = 1
763    m2.repeated_nested_message.add().bb = 2
764    m2.repeated_nested_message.add().bb = 3
765
766    if sys.version_info >= (3,): return  # No cmp() in PY3.
767
768    # These comparisons should not raise errors.
769    _ = m1 < m2
770    _ = m1.repeated_nested_message < m2.repeated_nested_message
771
772    # Make sure cmp always works. If it wasn't defined, these would be
773    # id() comparisons and would all fail.
774    self.assertEqual(cmp(m1, m2), 0)
775    self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
776    self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
777    self.assertEqual(cmp(m1.repeated_nested_message,
778                         m2.repeated_nested_message), 0)
779    with self.assertRaises(TypeError):
780      # Can't compare repeated composite containers to lists.
781      cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
782
783    # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
784
785  def testRepeatedFieldsAreSequences(self, message_module):
786    m = message_module.TestAllTypes()
787    self.assertIsInstance(m.repeated_int32, collections_abc.MutableSequence)
788    self.assertIsInstance(m.repeated_nested_message,
789                          collections_abc.MutableSequence)
790
791  def testRepeatedFieldsNotHashable(self, message_module):
792    m = message_module.TestAllTypes()
793    with self.assertRaises(TypeError):
794      hash(m.repeated_int32)
795    with self.assertRaises(TypeError):
796      hash(m.repeated_nested_message)
797
798  def testRepeatedFieldInsideNestedMessage(self, message_module):
799    m = message_module.NestedTestAllTypes()
800    m.payload.repeated_int32.extend([])
801    self.assertTrue(m.HasField('payload'))
802
803  def testMergeFrom(self, message_module):
804    m1 = message_module.TestAllTypes()
805    m2 = message_module.TestAllTypes()
806    # Cpp extension will lazily create a sub message which is immutable.
807    nested = m1.optional_nested_message
808    self.assertEqual(0, nested.bb)
809    m2.optional_nested_message.bb = 1
810    # Make sure cmessage pointing to a mutable message after merge instead of
811    # the lazily created message.
812    m1.MergeFrom(m2)
813    self.assertEqual(1, nested.bb)
814
815    # Test more nested sub message.
816    msg1 = message_module.NestedTestAllTypes()
817    msg2 = message_module.NestedTestAllTypes()
818    nested = msg1.child.payload.optional_nested_message
819    self.assertEqual(0, nested.bb)
820    msg2.child.payload.optional_nested_message.bb = 1
821    msg1.MergeFrom(msg2)
822    self.assertEqual(1, nested.bb)
823
824    # Test repeated field.
825    self.assertEqual(msg1.payload.repeated_nested_message,
826                     msg1.payload.repeated_nested_message)
827    nested = msg2.payload.repeated_nested_message.add()
828    nested.bb = 1
829    msg1.MergeFrom(msg2)
830    self.assertEqual(1, len(msg1.payload.repeated_nested_message))
831    self.assertEqual(1, nested.bb)
832
833  def testMergeFromString(self, message_module):
834    m1 = message_module.TestAllTypes()
835    m2 = message_module.TestAllTypes()
836    # Cpp extension will lazily create a sub message which is immutable.
837    self.assertEqual(0, m1.optional_nested_message.bb)
838    m2.optional_nested_message.bb = 1
839    # Make sure cmessage pointing to a mutable message after merge instead of
840    # the lazily created message.
841    m1.MergeFromString(m2.SerializeToString())
842    self.assertEqual(1, m1.optional_nested_message.bb)
843
844  @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2')
845  def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module):
846    m2 = message_module.TestAllTypes()
847    m2.optional_string = 'scalar string'
848    m2.repeated_string.append('repeated string')
849    m2.optional_bytes = b'scalar bytes'
850    m2.repeated_bytes.append(b'repeated bytes')
851
852    serialized = m2.SerializeToString()
853    memview = memoryview(serialized)
854    m1 = message_module.TestAllTypes.FromString(memview)
855
856    self.assertEqual(m1.optional_bytes, b'scalar bytes')
857    self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
858    self.assertEqual(m1.optional_string, 'scalar string')
859    self.assertEqual(m1.repeated_string, ['repeated string'])
860    # Make sure that the memoryview was correctly converted to bytes, and
861    # that a sub-sliced memoryview is not being used.
862    self.assertIsInstance(m1.optional_bytes, bytes)
863    self.assertIsInstance(m1.repeated_bytes[0], bytes)
864    self.assertIsInstance(m1.optional_string, six.text_type)
865    self.assertIsInstance(m1.repeated_string[0], six.text_type)
866
867  @unittest.skipIf(six.PY3, 'memoryview is supported by py3')
868  def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module):
869    memview = memoryview(b'')
870    with self.assertRaises(TypeError):
871      message_module.TestAllTypes.FromString(memview)
872
873  def testMergeFromEmpty(self, message_module):
874    m1 = message_module.TestAllTypes()
875    # Cpp extension will lazily create a sub message which is immutable.
876    self.assertEqual(0, m1.optional_nested_message.bb)
877    self.assertFalse(m1.HasField('optional_nested_message'))
878    # Make sure the sub message is still immutable after merge from empty.
879    m1.MergeFromString(b'')  # field state should not change
880    self.assertFalse(m1.HasField('optional_nested_message'))
881
882  def ensureNestedMessageExists(self, msg, attribute):
883    """Make sure that a nested message object exists.
884
885    As soon as a nested message attribute is accessed, it will be present in the
886    _fields dict, without being marked as actually being set.
887    """
888    getattr(msg, attribute)
889    self.assertFalse(msg.HasField(attribute))
890
891  def testOneofGetCaseNonexistingField(self, message_module):
892    m = message_module.TestAllTypes()
893    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
894    self.assertRaises(Exception, m.WhichOneof, 0)
895
896  def testOneofDefaultValues(self, message_module):
897    m = message_module.TestAllTypes()
898    self.assertIs(None, m.WhichOneof('oneof_field'))
899    self.assertFalse(m.HasField('oneof_field'))
900    self.assertFalse(m.HasField('oneof_uint32'))
901
902    # Oneof is set even when setting it to a default value.
903    m.oneof_uint32 = 0
904    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
905    self.assertTrue(m.HasField('oneof_field'))
906    self.assertTrue(m.HasField('oneof_uint32'))
907    self.assertFalse(m.HasField('oneof_string'))
908
909    m.oneof_string = ""
910    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
911    self.assertTrue(m.HasField('oneof_string'))
912    self.assertFalse(m.HasField('oneof_uint32'))
913
914  def testOneofSemantics(self, message_module):
915    m = message_module.TestAllTypes()
916    self.assertIs(None, m.WhichOneof('oneof_field'))
917
918    m.oneof_uint32 = 11
919    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
920    self.assertTrue(m.HasField('oneof_uint32'))
921
922    m.oneof_string = u'foo'
923    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
924    self.assertFalse(m.HasField('oneof_uint32'))
925    self.assertTrue(m.HasField('oneof_string'))
926
927    # Read nested message accessor without accessing submessage.
928    m.oneof_nested_message
929    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
930    self.assertTrue(m.HasField('oneof_string'))
931    self.assertFalse(m.HasField('oneof_nested_message'))
932
933    # Read accessor of nested message without accessing submessage.
934    m.oneof_nested_message.bb
935    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
936    self.assertTrue(m.HasField('oneof_string'))
937    self.assertFalse(m.HasField('oneof_nested_message'))
938
939    m.oneof_nested_message.bb = 11
940    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
941    self.assertFalse(m.HasField('oneof_string'))
942    self.assertTrue(m.HasField('oneof_nested_message'))
943
944    m.oneof_bytes = b'bb'
945    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
946    self.assertFalse(m.HasField('oneof_nested_message'))
947    self.assertTrue(m.HasField('oneof_bytes'))
948
949  def testOneofCompositeFieldReadAccess(self, message_module):
950    m = message_module.TestAllTypes()
951    m.oneof_uint32 = 11
952
953    self.ensureNestedMessageExists(m, 'oneof_nested_message')
954    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
955    self.assertEqual(11, m.oneof_uint32)
956
957  def testOneofWhichOneof(self, message_module):
958    m = message_module.TestAllTypes()
959    self.assertIs(None, m.WhichOneof('oneof_field'))
960    if message_module is unittest_pb2:
961      self.assertFalse(m.HasField('oneof_field'))
962
963    m.oneof_uint32 = 11
964    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
965    if message_module is unittest_pb2:
966      self.assertTrue(m.HasField('oneof_field'))
967
968    m.oneof_bytes = b'bb'
969    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
970
971    m.ClearField('oneof_bytes')
972    self.assertIs(None, m.WhichOneof('oneof_field'))
973    if message_module is unittest_pb2:
974      self.assertFalse(m.HasField('oneof_field'))
975
976  def testOneofClearField(self, message_module):
977    m = message_module.TestAllTypes()
978    m.oneof_uint32 = 11
979    m.ClearField('oneof_field')
980    if message_module is unittest_pb2:
981      self.assertFalse(m.HasField('oneof_field'))
982    self.assertFalse(m.HasField('oneof_uint32'))
983    self.assertIs(None, m.WhichOneof('oneof_field'))
984
985  def testOneofClearSetField(self, message_module):
986    m = message_module.TestAllTypes()
987    m.oneof_uint32 = 11
988    m.ClearField('oneof_uint32')
989    if message_module is unittest_pb2:
990      self.assertFalse(m.HasField('oneof_field'))
991    self.assertFalse(m.HasField('oneof_uint32'))
992    self.assertIs(None, m.WhichOneof('oneof_field'))
993
994  def testOneofClearUnsetField(self, message_module):
995    m = message_module.TestAllTypes()
996    m.oneof_uint32 = 11
997    self.ensureNestedMessageExists(m, 'oneof_nested_message')
998    m.ClearField('oneof_nested_message')
999    self.assertEqual(11, m.oneof_uint32)
1000    if message_module is unittest_pb2:
1001      self.assertTrue(m.HasField('oneof_field'))
1002    self.assertTrue(m.HasField('oneof_uint32'))
1003    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
1004
1005  def testOneofDeserialize(self, message_module):
1006    m = message_module.TestAllTypes()
1007    m.oneof_uint32 = 11
1008    m2 = message_module.TestAllTypes()
1009    m2.ParseFromString(m.SerializeToString())
1010    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
1011
1012  def testOneofCopyFrom(self, message_module):
1013    m = message_module.TestAllTypes()
1014    m.oneof_uint32 = 11
1015    m2 = message_module.TestAllTypes()
1016    m2.CopyFrom(m)
1017    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
1018
1019  def testOneofNestedMergeFrom(self, message_module):
1020    m = message_module.NestedTestAllTypes()
1021    m.payload.oneof_uint32 = 11
1022    m2 = message_module.NestedTestAllTypes()
1023    m2.payload.oneof_bytes = b'bb'
1024    m2.child.payload.oneof_bytes = b'bb'
1025    m2.MergeFrom(m)
1026    self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
1027    self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
1028
1029  def testOneofMessageMergeFrom(self, message_module):
1030    m = message_module.NestedTestAllTypes()
1031    m.payload.oneof_nested_message.bb = 11
1032    m.child.payload.oneof_nested_message.bb = 12
1033    m2 = message_module.NestedTestAllTypes()
1034    m2.payload.oneof_uint32 = 13
1035    m2.MergeFrom(m)
1036    self.assertEqual('oneof_nested_message',
1037                     m2.payload.WhichOneof('oneof_field'))
1038    self.assertEqual('oneof_nested_message',
1039                     m2.child.payload.WhichOneof('oneof_field'))
1040
1041  def testOneofNestedMessageInit(self, message_module):
1042    m = message_module.TestAllTypes(
1043        oneof_nested_message=message_module.TestAllTypes.NestedMessage())
1044    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
1045
1046  def testOneofClear(self, message_module):
1047    m = message_module.TestAllTypes()
1048    m.oneof_uint32 = 11
1049    m.Clear()
1050    self.assertIsNone(m.WhichOneof('oneof_field'))
1051    m.oneof_bytes = b'bb'
1052    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
1053
1054  def testAssignByteStringToUnicodeField(self, message_module):
1055    """Assigning a byte string to a string field should result
1056    in the value being converted to a Unicode string."""
1057    m = message_module.TestAllTypes()
1058    m.optional_string = str('')
1059    self.assertIsInstance(m.optional_string, six.text_type)
1060
1061  def testLongValuedSlice(self, message_module):
1062    """It should be possible to use long-valued indicies in slices
1063
1064    This didn't used to work in the v2 C++ implementation.
1065    """
1066    m = message_module.TestAllTypes()
1067
1068    # Repeated scalar
1069    m.repeated_int32.append(1)
1070    sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
1071    self.assertEqual(len(m.repeated_int32), len(sl))
1072
1073    # Repeated composite
1074    m.repeated_nested_message.add().bb = 3
1075    sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
1076    self.assertEqual(len(m.repeated_nested_message), len(sl))
1077
1078  def testExtendShouldNotSwallowExceptions(self, message_module):
1079    """This didn't use to work in the v2 C++ implementation."""
1080    m = message_module.TestAllTypes()
1081    with self.assertRaises(NameError) as _:
1082      m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
1083    with self.assertRaises(NameError) as _:
1084      m.repeated_nested_enum.extend(
1085          a for i in range(10))  # pylint: disable=undefined-variable
1086
1087  FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
1088
1089  def testExtendInt32WithNothing(self, message_module):
1090    """Test no-ops extending repeated int32 fields."""
1091    m = message_module.TestAllTypes()
1092    self.assertSequenceEqual([], m.repeated_int32)
1093
1094    # TODO(ptucker): Deprecate this behavior. b/18413862
1095    for falsy_value in MessageTest.FALSY_VALUES:
1096      m.repeated_int32.extend(falsy_value)
1097      self.assertSequenceEqual([], m.repeated_int32)
1098
1099    m.repeated_int32.extend([])
1100    self.assertSequenceEqual([], m.repeated_int32)
1101
1102  def testExtendFloatWithNothing(self, message_module):
1103    """Test no-ops extending repeated float fields."""
1104    m = message_module.TestAllTypes()
1105    self.assertSequenceEqual([], m.repeated_float)
1106
1107    # TODO(ptucker): Deprecate this behavior. b/18413862
1108    for falsy_value in MessageTest.FALSY_VALUES:
1109      m.repeated_float.extend(falsy_value)
1110      self.assertSequenceEqual([], m.repeated_float)
1111
1112    m.repeated_float.extend([])
1113    self.assertSequenceEqual([], m.repeated_float)
1114
1115  def testExtendStringWithNothing(self, message_module):
1116    """Test no-ops extending repeated string fields."""
1117    m = message_module.TestAllTypes()
1118    self.assertSequenceEqual([], m.repeated_string)
1119
1120    # TODO(ptucker): Deprecate this behavior. b/18413862
1121    for falsy_value in MessageTest.FALSY_VALUES:
1122      m.repeated_string.extend(falsy_value)
1123      self.assertSequenceEqual([], m.repeated_string)
1124
1125    m.repeated_string.extend([])
1126    self.assertSequenceEqual([], m.repeated_string)
1127
1128  def testExtendInt32WithPythonList(self, message_module):
1129    """Test extending repeated int32 fields with python lists."""
1130    m = message_module.TestAllTypes()
1131    self.assertSequenceEqual([], m.repeated_int32)
1132    m.repeated_int32.extend([0])
1133    self.assertSequenceEqual([0], m.repeated_int32)
1134    m.repeated_int32.extend([1, 2])
1135    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1136    m.repeated_int32.extend([3, 4])
1137    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1138
1139  def testExtendFloatWithPythonList(self, message_module):
1140    """Test extending repeated float fields with python lists."""
1141    m = message_module.TestAllTypes()
1142    self.assertSequenceEqual([], m.repeated_float)
1143    m.repeated_float.extend([0.0])
1144    self.assertSequenceEqual([0.0], m.repeated_float)
1145    m.repeated_float.extend([1.0, 2.0])
1146    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1147    m.repeated_float.extend([3.0, 4.0])
1148    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1149
1150  def testExtendStringWithPythonList(self, message_module):
1151    """Test extending repeated string fields with python lists."""
1152    m = message_module.TestAllTypes()
1153    self.assertSequenceEqual([], m.repeated_string)
1154    m.repeated_string.extend([''])
1155    self.assertSequenceEqual([''], m.repeated_string)
1156    m.repeated_string.extend(['11', '22'])
1157    self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
1158    m.repeated_string.extend(['33', '44'])
1159    self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
1160
1161  def testExtendStringWithString(self, message_module):
1162    """Test extending repeated string fields with characters from a string."""
1163    m = message_module.TestAllTypes()
1164    self.assertSequenceEqual([], m.repeated_string)
1165    m.repeated_string.extend('abc')
1166    self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
1167
1168  class TestIterable(object):
1169    """This iterable object mimics the behavior of numpy.array.
1170
1171    __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
1172
1173    """
1174
1175    def __init__(self, values=None):
1176      self._list = values or []
1177
1178    def __nonzero__(self):
1179      size = len(self._list)
1180      if size == 0:
1181        return False
1182      if size == 1:
1183        return bool(self._list[0])
1184      raise ValueError('Truth value is ambiguous.')
1185
1186    def __len__(self):
1187      return len(self._list)
1188
1189    def __iter__(self):
1190      return self._list.__iter__()
1191
1192  def testExtendInt32WithIterable(self, message_module):
1193    """Test extending repeated int32 fields with iterable."""
1194    m = message_module.TestAllTypes()
1195    self.assertSequenceEqual([], m.repeated_int32)
1196    m.repeated_int32.extend(MessageTest.TestIterable([]))
1197    self.assertSequenceEqual([], m.repeated_int32)
1198    m.repeated_int32.extend(MessageTest.TestIterable([0]))
1199    self.assertSequenceEqual([0], m.repeated_int32)
1200    m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
1201    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1202    m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
1203    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1204
1205  def testExtendFloatWithIterable(self, message_module):
1206    """Test extending repeated float fields with iterable."""
1207    m = message_module.TestAllTypes()
1208    self.assertSequenceEqual([], m.repeated_float)
1209    m.repeated_float.extend(MessageTest.TestIterable([]))
1210    self.assertSequenceEqual([], m.repeated_float)
1211    m.repeated_float.extend(MessageTest.TestIterable([0.0]))
1212    self.assertSequenceEqual([0.0], m.repeated_float)
1213    m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
1214    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1215    m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
1216    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1217
1218  def testExtendStringWithIterable(self, message_module):
1219    """Test extending repeated string fields with iterable."""
1220    m = message_module.TestAllTypes()
1221    self.assertSequenceEqual([], m.repeated_string)
1222    m.repeated_string.extend(MessageTest.TestIterable([]))
1223    self.assertSequenceEqual([], m.repeated_string)
1224    m.repeated_string.extend(MessageTest.TestIterable(['']))
1225    self.assertSequenceEqual([''], m.repeated_string)
1226    m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
1227    self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
1228    m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
1229    self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
1230
1231  def testPickleRepeatedScalarContainer(self, message_module):
1232    # TODO(tibell): The pure-Python implementation support pickling of
1233    #   scalar containers in *some* cases. For now the cpp2 version
1234    #   throws an exception to avoid a segfault. Investigate if we
1235    #   want to support pickling of these fields.
1236    #
1237    # For more information see: https://b2.corp.google.com/u/0/issues/18677897
1238    if (api_implementation.Type() != 'cpp' or
1239        api_implementation.Version() == 2):
1240      return
1241    m = message_module.TestAllTypes()
1242    with self.assertRaises(pickle.PickleError) as _:
1243      pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
1244
1245  def testSortEmptyRepeatedCompositeContainer(self, message_module):
1246    """Exercise a scenario that has led to segfaults in the past.
1247    """
1248    m = message_module.TestAllTypes()
1249    m.repeated_nested_message.sort()
1250
1251  def testHasFieldOnRepeatedField(self, message_module):
1252    """Using HasField on a repeated field should raise an exception.
1253    """
1254    m = message_module.TestAllTypes()
1255    with self.assertRaises(ValueError) as _:
1256      m.HasField('repeated_int32')
1257
1258  def testRepeatedScalarFieldPop(self, message_module):
1259    m = message_module.TestAllTypes()
1260    with self.assertRaises(IndexError) as _:
1261      m.repeated_int32.pop()
1262    m.repeated_int32.extend(range(5))
1263    self.assertEqual(4, m.repeated_int32.pop())
1264    self.assertEqual(0, m.repeated_int32.pop(0))
1265    self.assertEqual(2, m.repeated_int32.pop(1))
1266    self.assertEqual([1, 3], m.repeated_int32)
1267
1268  def testRepeatedCompositeFieldPop(self, message_module):
1269    m = message_module.TestAllTypes()
1270    with self.assertRaises(IndexError) as _:
1271      m.repeated_nested_message.pop()
1272    with self.assertRaises(TypeError) as _:
1273      m.repeated_nested_message.pop('0')
1274    for i in range(5):
1275      n = m.repeated_nested_message.add()
1276      n.bb = i
1277    self.assertEqual(4, m.repeated_nested_message.pop().bb)
1278    self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
1279    self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
1280    self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
1281
1282  def testRepeatedCompareWithSelf(self, message_module):
1283    m = message_module.TestAllTypes()
1284    for i in range(5):
1285      m.repeated_int32.insert(i, i)
1286      n = m.repeated_nested_message.add()
1287      n.bb = i
1288    self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
1289    self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
1290
  def testReleasedNestedMessages(self, message_module):
    """A case that led to a segfault when a message detached from its parent
    container itself has a child container.
    """
    # `m` is deliberately rebound at every step, so this function holds no
    # reference to the intermediate parents.
    # NOTE(review): presumably this is what detaches the child from its
    # parent in the original crash; keep the rebinding as-is.
    m = message_module.NestedTestAllTypes()
    m = m.repeated_child.add()
    m = m.child
    m = m.repeated_child.add()
    # Reading a field through the deepest child must not crash.
    self.assertEqual(m.payload.optional_int32, 0)
1300
1301  def testSetRepeatedComposite(self, message_module):
1302    m = message_module.TestAllTypes()
1303    with self.assertRaises(AttributeError):
1304      m.repeated_int32 = []
1305    m.repeated_int32.append(1)
1306    with self.assertRaises(AttributeError):
1307      m.repeated_int32 = []
1308
1309  def testReturningType(self, message_module):
1310    m = message_module.TestAllTypes()
1311    self.assertEqual(float, type(m.optional_float))
1312    self.assertEqual(float, type(m.optional_double))
1313    self.assertEqual(bool, type(m.optional_bool))
1314    m.optional_float = 1
1315    m.optional_double = 1
1316    m.optional_bool = 1
1317    m.repeated_float.append(1)
1318    m.repeated_double.append(1)
1319    m.repeated_bool.append(1)
1320    m.ParseFromString(m.SerializeToString())
1321    self.assertEqual(float, type(m.optional_float))
1322    self.assertEqual(float, type(m.optional_double))
1323    self.assertEqual('1.0', str(m.optional_double))
1324    self.assertEqual(bool, type(m.optional_bool))
1325    self.assertEqual(float, type(m.repeated_float[0]))
1326    self.assertEqual(float, type(m.repeated_double[0]))
1327    self.assertEqual(bool, type(m.repeated_bool[0]))
1328    self.assertEqual(True, m.repeated_bool[0])
1329
1330
1331# Class to test proto2-only features (required, extensions, etc.)
1332@testing_refleaks.TestCase
1333class Proto2Test(unittest.TestCase):
1334
1335  def testFieldPresence(self):
1336    message = unittest_pb2.TestAllTypes()
1337
1338    self.assertFalse(message.HasField("optional_int32"))
1339    self.assertFalse(message.HasField("optional_bool"))
1340    self.assertFalse(message.HasField("optional_nested_message"))
1341
1342    with self.assertRaises(ValueError):
1343      message.HasField("field_doesnt_exist")
1344
1345    with self.assertRaises(ValueError):
1346      message.HasField("repeated_int32")
1347    with self.assertRaises(ValueError):
1348      message.HasField("repeated_nested_message")
1349
1350    self.assertEqual(0, message.optional_int32)
1351    self.assertEqual(False, message.optional_bool)
1352    self.assertEqual(0, message.optional_nested_message.bb)
1353
1354    # Fields are set even when setting the values to default values.
1355    message.optional_int32 = 0
1356    message.optional_bool = False
1357    message.optional_nested_message.bb = 0
1358    self.assertTrue(message.HasField("optional_int32"))
1359    self.assertTrue(message.HasField("optional_bool"))
1360    self.assertTrue(message.HasField("optional_nested_message"))
1361
1362    # Set the fields to non-default values.
1363    message.optional_int32 = 5
1364    message.optional_bool = True
1365    message.optional_nested_message.bb = 15
1366
1367    self.assertTrue(message.HasField(u"optional_int32"))
1368    self.assertTrue(message.HasField("optional_bool"))
1369    self.assertTrue(message.HasField("optional_nested_message"))
1370
1371    # Clearing the fields unsets them and resets their value to default.
1372    message.ClearField("optional_int32")
1373    message.ClearField(u"optional_bool")
1374    message.ClearField("optional_nested_message")
1375
1376    self.assertFalse(message.HasField("optional_int32"))
1377    self.assertFalse(message.HasField("optional_bool"))
1378    self.assertFalse(message.HasField("optional_nested_message"))
1379    self.assertEqual(0, message.optional_int32)
1380    self.assertEqual(False, message.optional_bool)
1381    self.assertEqual(0, message.optional_nested_message.bb)
1382
1383  def testAssignInvalidEnum(self):
1384    """Assigning an invalid enum number is not allowed in proto2."""
1385    m = unittest_pb2.TestAllTypes()
1386
1387    # Proto2 can not assign unknown enum.
1388    with self.assertRaises(ValueError) as _:
1389      m.optional_nested_enum = 1234567
1390    self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
1391    # Assignment is a different code path than append for the C++ impl.
1392    m.repeated_nested_enum.append(2)
1393    m.repeated_nested_enum[0] = 2
1394    with self.assertRaises(ValueError):
1395      m.repeated_nested_enum[0] = 123456
1396
1397    # Unknown enum value can be parsed but is ignored.
1398    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1399    m2.optional_nested_enum = 1234567
1400    m2.repeated_nested_enum.append(7654321)
1401    serialized = m2.SerializeToString()
1402
1403    m3 = unittest_pb2.TestAllTypes()
1404    m3.ParseFromString(serialized)
1405    self.assertFalse(m3.HasField('optional_nested_enum'))
1406    # 1 is the default value for optional_nested_enum.
1407    self.assertEqual(1, m3.optional_nested_enum)
1408    self.assertEqual(0, len(m3.repeated_nested_enum))
1409    m2.Clear()
1410    m2.ParseFromString(m3.SerializeToString())
1411    self.assertEqual(1234567, m2.optional_nested_enum)
1412    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1413
1414  def testUnknownEnumMap(self):
1415    m = map_proto2_unittest_pb2.TestEnumMap()
1416    m.known_map_field[123] = 0
1417    with self.assertRaises(ValueError):
1418      m.unknown_map_field[1] = 123
1419
1420  def testExtensionsErrors(self):
1421    msg = unittest_pb2.TestAllTypes()
1422    self.assertRaises(AttributeError, getattr, msg, 'Extensions')
1423
  def testMergeFromExtensions(self):
    """MergeFrom() copies an extension into a lazily created submessage.

    Reading an extension of msg1.submessage before anything is set returns
    the default and does not make 'submessage' present. A later MergeFrom()
    must leave msg1 reading the merged value.
    """
    msg1 = more_extensions_pb2.TopLevelMessage()
    msg2 = more_extensions_pb2.TopLevelMessage()
    # Cpp extension will lazily create a sub message which is immutable.
    # The read below goes through that lazily created submessage.
    self.assertEqual(0, msg1.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(msg1.HasField('submessage'))
    msg2.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 123
    # Make sure cmessage and extensions point to a mutable message after the
    # merge instead of the lazily created one.
    msg1.MergeFrom(msg2)
    self.assertEqual(123, msg1.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
1438
1439  def testGoldenExtensions(self):
1440    golden_data = test_util.GoldenFileData('golden_message')
1441    golden_message = unittest_pb2.TestAllExtensions()
1442    golden_message.ParseFromString(golden_data)
1443    all_set = unittest_pb2.TestAllExtensions()
1444    test_util.SetAllExtensions(all_set)
1445    self.assertEqual(all_set, golden_message)
1446    self.assertEqual(golden_data, golden_message.SerializeToString())
1447    golden_copy = copy.deepcopy(golden_message)
1448    self.assertEqual(golden_data, golden_copy.SerializeToString())
1449
1450  def testGoldenPackedExtensions(self):
1451    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
1452    golden_message = unittest_pb2.TestPackedExtensions()
1453    golden_message.ParseFromString(golden_data)
1454    all_set = unittest_pb2.TestPackedExtensions()
1455    test_util.SetAllPackedExtensions(all_set)
1456    self.assertEqual(all_set, golden_message)
1457    self.assertEqual(golden_data, all_set.SerializeToString())
1458    golden_copy = copy.deepcopy(golden_message)
1459    self.assertEqual(golden_data, golden_copy.SerializeToString())
1460
1461  def testPickleIncompleteProto(self):
1462    golden_message = unittest_pb2.TestRequired(a=1)
1463    pickled_message = pickle.dumps(golden_message)
1464
1465    unpickled_message = pickle.loads(pickled_message)
1466    self.assertEqual(unpickled_message, golden_message)
1467    self.assertEqual(unpickled_message.a, 1)
1468    # This is still an incomplete proto - so serializing should fail
1469    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
1470
1471
1472  # TODO(haberman): this isn't really a proto2-specific test except that this
1473  # message has a required field in it.  Should probably be factored out so
1474  # that we can test the other parts with proto3.
1475  def testParsingMerge(self):
1476    """Check the merge behavior when a required or optional field appears
1477    multiple times in the input."""
1478    messages = [
1479        unittest_pb2.TestAllTypes(),
1480        unittest_pb2.TestAllTypes(),
1481        unittest_pb2.TestAllTypes() ]
1482    messages[0].optional_int32 = 1
1483    messages[1].optional_int64 = 2
1484    messages[2].optional_int32 = 3
1485    messages[2].optional_string = 'hello'
1486
1487    merged_message = unittest_pb2.TestAllTypes()
1488    merged_message.optional_int32 = 3
1489    merged_message.optional_int64 = 2
1490    merged_message.optional_string = 'hello'
1491
1492    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
1493    generator.field1.extend(messages)
1494    generator.field2.extend(messages)
1495    generator.field3.extend(messages)
1496    generator.ext1.extend(messages)
1497    generator.ext2.extend(messages)
1498    generator.group1.add().field1.MergeFrom(messages[0])
1499    generator.group1.add().field1.MergeFrom(messages[1])
1500    generator.group1.add().field1.MergeFrom(messages[2])
1501    generator.group2.add().field1.MergeFrom(messages[0])
1502    generator.group2.add().field1.MergeFrom(messages[1])
1503    generator.group2.add().field1.MergeFrom(messages[2])
1504
1505    data = generator.SerializeToString()
1506    parsing_merge = unittest_pb2.TestParsingMerge()
1507    parsing_merge.ParseFromString(data)
1508
1509    # Required and optional fields should be merged.
1510    self.assertEqual(parsing_merge.required_all_types, merged_message)
1511    self.assertEqual(parsing_merge.optional_all_types, merged_message)
1512    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
1513                     merged_message)
1514    self.assertEqual(parsing_merge.Extensions[
1515                     unittest_pb2.TestParsingMerge.optional_ext],
1516                     merged_message)
1517
1518    # Repeated fields should not be merged.
1519    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
1520    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
1521    self.assertEqual(len(parsing_merge.Extensions[
1522        unittest_pb2.TestParsingMerge.repeated_ext]), 3)
1523
1524  def testPythonicInit(self):
1525    message = unittest_pb2.TestAllTypes(
1526        optional_int32=100,
1527        optional_fixed32=200,
1528        optional_float=300.5,
1529        optional_bytes=b'x',
1530        optionalgroup={'a': 400},
1531        optional_nested_message={'bb': 500},
1532        optional_foreign_message={},
1533        optional_nested_enum='BAZ',
1534        repeatedgroup=[{'a': 600},
1535                       {'a': 700}],
1536        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
1537        default_int32=800,
1538        oneof_string='y')
1539    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
1540    self.assertEqual(100, message.optional_int32)
1541    self.assertEqual(200, message.optional_fixed32)
1542    self.assertEqual(300.5, message.optional_float)
1543    self.assertEqual(b'x', message.optional_bytes)
1544    self.assertEqual(400, message.optionalgroup.a)
1545    self.assertIsInstance(message.optional_nested_message,
1546                          unittest_pb2.TestAllTypes.NestedMessage)
1547    self.assertEqual(500, message.optional_nested_message.bb)
1548    self.assertTrue(message.HasField('optional_foreign_message'))
1549    self.assertEqual(message.optional_foreign_message,
1550                     unittest_pb2.ForeignMessage())
1551    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1552                     message.optional_nested_enum)
1553    self.assertEqual(2, len(message.repeatedgroup))
1554    self.assertEqual(600, message.repeatedgroup[0].a)
1555    self.assertEqual(700, message.repeatedgroup[1].a)
1556    self.assertEqual(2, len(message.repeated_nested_enum))
1557    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
1558                     message.repeated_nested_enum[0])
1559    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
1560                     message.repeated_nested_enum[1])
1561    self.assertEqual(800, message.default_int32)
1562    self.assertEqual('y', message.oneof_string)
1563    self.assertFalse(message.HasField('optional_int64'))
1564    self.assertEqual(0, len(message.repeated_float))
1565    self.assertEqual(42, message.default_int64)
1566
1567    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
1568    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1569                     message.optional_nested_enum)
1570
1571    with self.assertRaises(ValueError):
1572      unittest_pb2.TestAllTypes(
1573          optional_nested_message={'INVALID_NESTED_FIELD': 17})
1574
1575    with self.assertRaises(TypeError):
1576      unittest_pb2.TestAllTypes(
1577          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
1578
1579    with self.assertRaises(ValueError):
1580      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
1581
1582    with self.assertRaises(ValueError):
1583      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1584
1585  def testPythonicInitWithDict(self):
1586    # Both string/unicode field name keys should work.
1587    kwargs = {
1588        'optional_int32': 100,
1589        u'optional_fixed32': 200,
1590    }
1591    msg = unittest_pb2.TestAllTypes(**kwargs)
1592    self.assertEqual(100, msg.optional_int32)
1593    self.assertEqual(200, msg.optional_fixed32)
1594
1595
1596  def test_documentation(self):
1597    # Also used by the interactive help() function.
1598    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
1599    self.assertIn('class TestAllTypes', doc)
1600    self.assertIn('SerializePartialToString', doc)
1601    self.assertIn('repeated_float', doc)
1602    base = unittest_pb2.TestAllTypes.__bases__[0]
1603    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
1604
1605
1606# Class to test proto3-only features/behavior (updated field presence & enums)
1607@testing_refleaks.TestCase
1608class Proto3Test(unittest.TestCase):
1609
1610  # Utility method for comparing equality with a map.
1611  def assertMapIterEquals(self, map_iter, dict_value):
1612    # Avoid mutating caller's copy.
1613    dict_value = dict(dict_value)
1614
1615    for k, v in map_iter:
1616      self.assertEqual(v, dict_value[k])
1617      del dict_value[k]
1618
1619    self.assertEqual({}, dict_value)
1620
1621  def testFieldPresence(self):
1622    message = unittest_proto3_arena_pb2.TestAllTypes()
1623
1624    # We can't test presence of non-repeated, non-submessage fields.
1625    with self.assertRaises(ValueError):
1626      message.HasField('optional_int32')
1627    with self.assertRaises(ValueError):
1628      message.HasField('optional_float')
1629    with self.assertRaises(ValueError):
1630      message.HasField('optional_string')
1631    with self.assertRaises(ValueError):
1632      message.HasField('optional_bool')
1633
1634    # But we can still test presence of submessage fields.
1635    self.assertFalse(message.HasField('optional_nested_message'))
1636
1637    # As with proto2, we can't test presence of fields that don't exist, or
1638    # repeated fields.
1639    with self.assertRaises(ValueError):
1640      message.HasField('field_doesnt_exist')
1641
1642    with self.assertRaises(ValueError):
1643      message.HasField('repeated_int32')
1644    with self.assertRaises(ValueError):
1645      message.HasField('repeated_nested_message')
1646
1647    # Fields should default to their type-specific default.
1648    self.assertEqual(0, message.optional_int32)
1649    self.assertEqual(0, message.optional_float)
1650    self.assertEqual('', message.optional_string)
1651    self.assertEqual(False, message.optional_bool)
1652    self.assertEqual(0, message.optional_nested_message.bb)
1653
1654    # Setting a submessage should still return proper presence information.
1655    message.optional_nested_message.bb = 0
1656    self.assertTrue(message.HasField('optional_nested_message'))
1657
1658    # Set the fields to non-default values.
1659    message.optional_int32 = 5
1660    message.optional_float = 1.1
1661    message.optional_string = 'abc'
1662    message.optional_bool = True
1663    message.optional_nested_message.bb = 15
1664
1665    # Clearing the fields unsets them and resets their value to default.
1666    message.ClearField('optional_int32')
1667    message.ClearField('optional_float')
1668    message.ClearField('optional_string')
1669    message.ClearField('optional_bool')
1670    message.ClearField('optional_nested_message')
1671
1672    self.assertEqual(0, message.optional_int32)
1673    self.assertEqual(0, message.optional_float)
1674    self.assertEqual('', message.optional_string)
1675    self.assertEqual(False, message.optional_bool)
1676    self.assertEqual(0, message.optional_nested_message.bb)
1677
1678  def testProto3ParserDropDefaultScalar(self):
1679    message_proto2 = unittest_pb2.TestAllTypes()
1680    message_proto2.optional_int32 = 0
1681    message_proto2.optional_string = ''
1682    message_proto2.optional_bytes = b''
1683    self.assertEqual(len(message_proto2.ListFields()), 3)
1684
1685    message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
1686    message_proto3.ParseFromString(message_proto2.SerializeToString())
1687    self.assertEqual(len(message_proto3.ListFields()), 0)
1688
1689  def testProto3Optional(self):
1690    msg = test_proto3_optional_pb2.TestProto3Optional()
1691    self.assertFalse(msg.HasField('optional_int32'))
1692    self.assertFalse(msg.HasField('optional_float'))
1693    self.assertFalse(msg.HasField('optional_string'))
1694    self.assertFalse(msg.HasField('optional_nested_message'))
1695    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1696
1697    # Set fields.
1698    msg.optional_int32 = 1
1699    msg.optional_float = 1.0
1700    msg.optional_string = '123'
1701    msg.optional_nested_message.bb = 1
1702    self.assertTrue(msg.HasField('optional_int32'))
1703    self.assertTrue(msg.HasField('optional_float'))
1704    self.assertTrue(msg.HasField('optional_string'))
1705    self.assertTrue(msg.HasField('optional_nested_message'))
1706    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1707    # Set to default value does not clear the fields
1708    msg.optional_int32 = 0
1709    msg.optional_float = 0.0
1710    msg.optional_string = ''
1711    msg.optional_nested_message.bb = 0
1712    self.assertTrue(msg.HasField('optional_int32'))
1713    self.assertTrue(msg.HasField('optional_float'))
1714    self.assertTrue(msg.HasField('optional_string'))
1715    self.assertTrue(msg.HasField('optional_nested_message'))
1716    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1717
1718    # Test serialize
1719    msg2 = test_proto3_optional_pb2.TestProto3Optional()
1720    msg2.ParseFromString(msg.SerializeToString())
1721    self.assertTrue(msg2.HasField('optional_int32'))
1722    self.assertTrue(msg2.HasField('optional_float'))
1723    self.assertTrue(msg2.HasField('optional_string'))
1724    self.assertTrue(msg2.HasField('optional_nested_message'))
1725    self.assertTrue(msg2.optional_nested_message.HasField('bb'))
1726
1727    self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')
1728
1729    # Clear these fields.
1730    msg.ClearField('optional_int32')
1731    msg.ClearField('optional_float')
1732    msg.ClearField('optional_string')
1733    msg.ClearField('optional_nested_message')
1734    self.assertFalse(msg.HasField('optional_int32'))
1735    self.assertFalse(msg.HasField('optional_float'))
1736    self.assertFalse(msg.HasField('optional_string'))
1737    self.assertFalse(msg.HasField('optional_nested_message'))
1738    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1739
1740    self.assertEqual(msg.WhichOneof('_optional_int32'), None)
1741
1742  def testAssignUnknownEnum(self):
1743    """Assigning an unknown enum value is allowed and preserves the value."""
1744    m = unittest_proto3_arena_pb2.TestAllTypes()
1745
1746    # Proto3 can assign unknown enums.
1747    m.optional_nested_enum = 1234567
1748    self.assertEqual(1234567, m.optional_nested_enum)
1749    m.repeated_nested_enum.append(22334455)
1750    self.assertEqual(22334455, m.repeated_nested_enum[0])
1751    # Assignment is a different code path than append for the C++ impl.
1752    m.repeated_nested_enum[0] = 7654321
1753    self.assertEqual(7654321, m.repeated_nested_enum[0])
1754    serialized = m.SerializeToString()
1755
1756    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1757    m2.ParseFromString(serialized)
1758    self.assertEqual(1234567, m2.optional_nested_enum)
1759    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1760
1761  # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1762  # of google/protobuf/map_unittest.proto right now, so it's not easy to
1763  # test both with the same test like we do for the other proto2/proto3 tests.
1764  # (google/protobuf/map_proto2_unittest.proto is very different in the set
1765  # of messages and fields it contains).
1766  def testScalarMapDefaults(self):
1767    msg = map_unittest_pb2.TestMap()
1768
1769    # Scalars start out unset.
1770    self.assertFalse(-123 in msg.map_int32_int32)
1771    self.assertFalse(-2**33 in msg.map_int64_int64)
1772    self.assertFalse(123 in msg.map_uint32_uint32)
1773    self.assertFalse(2**33 in msg.map_uint64_uint64)
1774    self.assertFalse(123 in msg.map_int32_double)
1775    self.assertFalse(False in msg.map_bool_bool)
1776    self.assertFalse('abc' in msg.map_string_string)
1777    self.assertFalse(111 in msg.map_int32_bytes)
1778    self.assertFalse(888 in msg.map_int32_enum)
1779
1780    # Accessing an unset key returns the default.
1781    self.assertEqual(0, msg.map_int32_int32[-123])
1782    self.assertEqual(0, msg.map_int64_int64[-2**33])
1783    self.assertEqual(0, msg.map_uint32_uint32[123])
1784    self.assertEqual(0, msg.map_uint64_uint64[2**33])
1785    self.assertEqual(0.0, msg.map_int32_double[123])
1786    self.assertTrue(isinstance(msg.map_int32_double[123], float))
1787    self.assertEqual(False, msg.map_bool_bool[False])
1788    self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
1789    self.assertEqual('', msg.map_string_string['abc'])
1790    self.assertEqual(b'', msg.map_int32_bytes[111])
1791    self.assertEqual(0, msg.map_int32_enum[888])
1792
1793    # It also sets the value in the map
1794    self.assertTrue(-123 in msg.map_int32_int32)
1795    self.assertTrue(-2**33 in msg.map_int64_int64)
1796    self.assertTrue(123 in msg.map_uint32_uint32)
1797    self.assertTrue(2**33 in msg.map_uint64_uint64)
1798    self.assertTrue(123 in msg.map_int32_double)
1799    self.assertTrue(False in msg.map_bool_bool)
1800    self.assertTrue('abc' in msg.map_string_string)
1801    self.assertTrue(111 in msg.map_int32_bytes)
1802    self.assertTrue(888 in msg.map_int32_enum)
1803
1804    self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
1805
1806    # Accessing an unset key still throws TypeError if the type of the key
1807    # is incorrect.
1808    with self.assertRaises(TypeError):
1809      msg.map_string_string[123]
1810
1811    with self.assertRaises(TypeError):
1812      123 in msg.map_string_string
1813
1814  def testMapGet(self):
1815    # Need to test that get() properly returns the default, even though the dict
1816    # has defaultdict-like semantics.
1817    msg = map_unittest_pb2.TestMap()
1818
1819    self.assertIsNone(msg.map_int32_int32.get(5))
1820    self.assertEqual(10, msg.map_int32_int32.get(5, 10))
1821    self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10))
1822    self.assertIsNone(msg.map_int32_int32.get(5))
1823
1824    msg.map_int32_int32[5] = 15
1825    self.assertEqual(15, msg.map_int32_int32.get(5))
1826    self.assertEqual(15, msg.map_int32_int32.get(5))
1827    with self.assertRaises(TypeError):
1828      msg.map_int32_int32.get('')
1829
1830    self.assertIsNone(msg.map_int32_foreign_message.get(5))
1831    self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
1832    self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10))
1833
1834    submsg = msg.map_int32_foreign_message[5]
1835    self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
1836    with self.assertRaises(TypeError):
1837      msg.map_int32_foreign_message.get('')
1838
1839  def testScalarMap(self):
1840    msg = map_unittest_pb2.TestMap()
1841
1842    self.assertEqual(0, len(msg.map_int32_int32))
1843    self.assertFalse(5 in msg.map_int32_int32)
1844
1845    msg.map_int32_int32[-123] = -456
1846    msg.map_int64_int64[-2**33] = -2**34
1847    msg.map_uint32_uint32[123] = 456
1848    msg.map_uint64_uint64[2**33] = 2**34
1849    msg.map_int32_float[2] = 1.2
1850    msg.map_int32_double[1] = 3.3
1851    msg.map_string_string['abc'] = '123'
1852    msg.map_bool_bool[True] = True
1853    msg.map_int32_enum[888] = 2
1854    # Unknown numeric enum is supported in proto3.
1855    msg.map_int32_enum[123] = 456
1856
1857    self.assertEqual([], msg.FindInitializationErrors())
1858
1859    self.assertEqual(1, len(msg.map_string_string))
1860
1861    # Bad key.
1862    with self.assertRaises(TypeError):
1863      msg.map_string_string[123] = '123'
1864
1865    # Verify that trying to assign a bad key doesn't actually add a member to
1866    # the map.
1867    self.assertEqual(1, len(msg.map_string_string))
1868
1869    # Bad value.
1870    with self.assertRaises(TypeError):
1871      msg.map_string_string['123'] = 123
1872
1873    serialized = msg.SerializeToString()
1874    msg2 = map_unittest_pb2.TestMap()
1875    msg2.ParseFromString(serialized)
1876
1877    # Bad key.
1878    with self.assertRaises(TypeError):
1879      msg2.map_string_string[123] = '123'
1880
1881    # Bad value.
1882    with self.assertRaises(TypeError):
1883      msg2.map_string_string['123'] = 123
1884
1885    self.assertEqual(-456, msg2.map_int32_int32[-123])
1886    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
1887    self.assertEqual(456, msg2.map_uint32_uint32[123])
1888    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
1889    self.assertAlmostEqual(1.2, msg.map_int32_float[2])
1890    self.assertEqual(3.3, msg.map_int32_double[1])
1891    self.assertEqual('123', msg2.map_string_string['abc'])
1892    self.assertEqual(True, msg2.map_bool_bool[True])
1893    self.assertEqual(2, msg2.map_int32_enum[888])
1894    self.assertEqual(456, msg2.map_int32_enum[123])
1895    self.assertEqual('{-123: -456}',
1896                     str(msg2.map_int32_int32))
1897
1898  def testMapEntryAlwaysSerialized(self):
1899    msg = map_unittest_pb2.TestMap()
1900    msg.map_int32_int32[0] = 0
1901    msg.map_string_string[''] = ''
1902    self.assertEqual(msg.ByteSize(), 12)
1903    self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
1904                     msg.SerializeToString())
1905
1906  def testStringUnicodeConversionInMap(self):
1907    msg = map_unittest_pb2.TestMap()
1908
1909    unicode_obj = u'\u1234'
1910    bytes_obj = unicode_obj.encode('utf8')
1911
1912    msg.map_string_string[bytes_obj] = bytes_obj
1913
1914    (key, value) = list(msg.map_string_string.items())[0]
1915
1916    self.assertEqual(key, unicode_obj)
1917    self.assertEqual(value, unicode_obj)
1918
1919    self.assertIsInstance(key, six.text_type)
1920    self.assertIsInstance(value, six.text_type)
1921
1922  def testMessageMap(self):
1923    msg = map_unittest_pb2.TestMap()
1924
1925    self.assertEqual(0, len(msg.map_int32_foreign_message))
1926    self.assertFalse(5 in msg.map_int32_foreign_message)
1927
1928    msg.map_int32_foreign_message[123]
1929    # get_or_create() is an alias for getitem.
1930    msg.map_int32_foreign_message.get_or_create(-456)
1931
1932    self.assertEqual(2, len(msg.map_int32_foreign_message))
1933    self.assertIn(123, msg.map_int32_foreign_message)
1934    self.assertIn(-456, msg.map_int32_foreign_message)
1935    self.assertEqual(2, len(msg.map_int32_foreign_message))
1936
1937    # Bad key.
1938    with self.assertRaises(TypeError):
1939      msg.map_int32_foreign_message['123']
1940
1941    # Can't assign directly to submessage.
1942    with self.assertRaises(ValueError):
1943      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
1944
1945    # Verify that trying to assign a bad key doesn't actually add a member to
1946    # the map.
1947    self.assertEqual(2, len(msg.map_int32_foreign_message))
1948
1949    serialized = msg.SerializeToString()
1950    msg2 = map_unittest_pb2.TestMap()
1951    msg2.ParseFromString(serialized)
1952
1953    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1954    self.assertIn(123, msg2.map_int32_foreign_message)
1955    self.assertIn(-456, msg2.map_int32_foreign_message)
1956    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1957    msg2.map_int32_foreign_message[123].c = 1
1958    # TODO(jieluo): Fix text format for message map.
1959    self.assertIn(str(msg2.map_int32_foreign_message),
1960                  ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
1961
  def testNestedMessageMapItemDelete(self):
    """Deletes message-valued map entries, including in a freshly parsed map.

    A deleted key can be recreated by writing through it again, and every
    entry of a parsed map can be deleted one by one in a loop.
    """
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.bb = 1
    del msg.map_int32_all_types[1]
    msg.map_int32_all_types[2].optional_nested_message.bb = 2
    # Only key 2 remains after key 1 was deleted.
    self.assertEqual(1, len(msg.map_int32_all_types))
    # Writing through the deleted key creates its entry again.
    msg.map_int32_all_types[1].optional_nested_message.bb = 1
    self.assertEqual(2, len(msg.map_int32_all_types))

    serialized = msg.SerializeToString()
    msg2 = map_unittest_pb2.TestMap()
    msg2.ParseFromString(serialized)
    keys = [1, 2]
    # The loop triggers PyErr_Occurred() in c extension.
    # NOTE(review): presumably a regression check that deleting in a loop
    # does not leave a pending Python error behind; confirm against history.
    for key in keys:
      del msg2.map_int32_all_types[key]
1978
1979  def testMapByteSize(self):
1980    msg = map_unittest_pb2.TestMap()
1981    msg.map_int32_int32[1] = 1
1982    size = msg.ByteSize()
1983    msg.map_int32_int32[1] = 128
1984    self.assertEqual(msg.ByteSize(), size + 1)
1985
1986    msg.map_int32_foreign_message[19].c = 1
1987    size = msg.ByteSize()
1988    msg.map_int32_foreign_message[19].c = 128
1989    self.assertEqual(msg.ByteSize(), size + 1)
1990
  def testMergeFrom(self):
    """Message-level MergeFrom() on messages containing maps.

    Entries from the source overwrite same-keyed entries in the destination.
    Message values are replaced as a whole: 'd' set only in the destination
    is gone afterwards. Keys only in the destination are kept, and each key
    appears once. Deleting keys after the merge really removes them.
    """
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[12] = 34
    msg.map_int32_int32[56] = 78
    msg.map_int64_int64[22] = 33
    msg.map_int32_foreign_message[111].c = 5
    msg.map_int32_foreign_message[222].c = 10

    msg2 = map_unittest_pb2.TestMap()
    msg2.map_int32_int32[12] = 55
    msg2.map_int64_int64[88] = 99
    msg2.map_int32_foreign_message[222].c = 15
    msg2.map_int32_foreign_message[222].d = 20
    # Keep a reference to the value that the merge is about to replace.
    old_map_value = msg2.map_int32_foreign_message[222]

    msg2.MergeFrom(msg)
    # Compare with expected message instead of call
    # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not
    # sync with repeated field and there are no duplicated keys.
    # Expected result: everything from msg, plus the key 88 only msg2 had.
    expected_msg = map_unittest_pb2.TestMap()
    expected_msg.CopyFrom(msg)
    expected_msg.map_int64_int64[88] = 99
    self.assertEqual(msg2, expected_msg)

    self.assertEqual(34, msg2.map_int32_int32[12])
    self.assertEqual(78, msg2.map_int32_int32[56])
    self.assertEqual(33, msg2.map_int64_int64[22])
    self.assertEqual(99, msg2.map_int64_int64[88])
    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
    self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
    self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
    if api_implementation.Type() != 'cpp':
      # During the call to MergeFrom(), the C++ implementation will have
      # deallocated the underlying message, but this is very difficult to detect
      # properly. The line below is likely to cause a segmentation fault.
      # With the Python implementation, old_map_value is just 'detached' from
      # the main message. Using it will not crash of course, but since it still
      # has a reference to the parent message I'm sure we can find interesting
      # ways to cause inconsistencies.
      self.assertEqual(15, old_map_value.c)

    # Verify that there is only one entry per key, even though the MergeFrom
    # may have internally created multiple entries for a single key in the
    # list representation.
    as_dict = {}
    for key in msg2.map_int32_foreign_message:
      self.assertFalse(key in as_dict)
      as_dict[key] = msg2.map_int32_foreign_message[key].c

    self.assertEqual({111: 5, 222: 10}, as_dict)

    # Special case: test that delete of item really removes the item, even if
    # there might have physically been duplicate keys due to the previous merge.
    # This is only a special case for the C++ implementation which stores the
    # map as an array.
    del msg2.map_int32_int32[12]
    self.assertFalse(12 in msg2.map_int32_int32)

    del msg2.map_int32_foreign_message[222]
    self.assertFalse(222 in msg2.map_int32_foreign_message)
    # Deleting with a wrongly typed key raises TypeError.
    with self.assertRaises(TypeError):
      del msg2.map_int32_foreign_message['']
2053
2054  def testMapMergeFrom(self):
2055    msg = map_unittest_pb2.TestMap()
2056    msg.map_int32_int32[12] = 34
2057    msg.map_int32_int32[56] = 78
2058    msg.map_int64_int64[22] = 33
2059    msg.map_int32_foreign_message[111].c = 5
2060    msg.map_int32_foreign_message[222].c = 10
2061
2062    msg2 = map_unittest_pb2.TestMap()
2063    msg2.map_int32_int32[12] = 55
2064    msg2.map_int64_int64[88] = 99
2065    msg2.map_int32_foreign_message[222].c = 15
2066    msg2.map_int32_foreign_message[222].d = 20
2067
2068    msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
2069    self.assertEqual(34, msg2.map_int32_int32[12])
2070    self.assertEqual(78, msg2.map_int32_int32[56])
2071
2072    msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
2073    self.assertEqual(33, msg2.map_int64_int64[22])
2074    self.assertEqual(99, msg2.map_int64_int64[88])
2075
2076    msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
2077    # Compare with expected message instead of call
2078    # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
2079    # sync with repeated field and no duplicated keys.
2080    expected_msg = map_unittest_pb2.TestMap()
2081    expected_msg.CopyFrom(msg)
2082    expected_msg.map_int64_int64[88] = 99
2083    self.assertEqual(msg2, expected_msg)
2084
2085    # Test when cpp extension cache a map.
2086    m1 = map_unittest_pb2.TestMap()
2087    m2 = map_unittest_pb2.TestMap()
2088    self.assertEqual(m1.map_int32_foreign_message,
2089                     m1.map_int32_foreign_message)
2090    m2.map_int32_foreign_message[123].c = 10
2091    m1.MergeFrom(m2)
2092    self.assertEqual(10, m2.map_int32_foreign_message[123].c)
2093
2094    # Test merge maps within different message types.
2095    m1 = map_unittest_pb2.TestMap()
2096    m2 = map_unittest_pb2.TestMessageMap()
2097    m2.map_int32_message[123].optional_int32 = 10
2098    m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
2099    self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)
2100
2101    # Test overwrite message value map
2102    msg = map_unittest_pb2.TestMap()
2103    msg.map_int32_foreign_message[222].c = 123
2104    msg2 = map_unittest_pb2.TestMap()
2105    msg2.map_int32_foreign_message[222].d = 20
2106    msg.MergeFromString(msg2.SerializeToString())
2107    self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
2108    self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)
2109
2110  def testMergeFromBadType(self):
2111    msg = map_unittest_pb2.TestMap()
2112    with self.assertRaisesRegexp(
2113        TypeError,
2114        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
2115        r'.*TestMap got int\.'):
2116      msg.MergeFrom(1)
2117
2118  def testCopyFromBadType(self):
2119    msg = map_unittest_pb2.TestMap()
2120    with self.assertRaisesRegexp(
2121        TypeError,
2122        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
2123        r'expected .*TestMap got int\.'):
2124      msg.CopyFrom(1)
2125
2126  def testIntegerMapWithLongs(self):
2127    msg = map_unittest_pb2.TestMap()
2128    msg.map_int32_int32[long(-123)] = long(-456)
2129    msg.map_int64_int64[long(-2**33)] = long(-2**34)
2130    msg.map_uint32_uint32[long(123)] = long(456)
2131    msg.map_uint64_uint64[long(2**33)] = long(2**34)
2132
2133    serialized = msg.SerializeToString()
2134    msg2 = map_unittest_pb2.TestMap()
2135    msg2.ParseFromString(serialized)
2136
2137    self.assertEqual(-456, msg2.map_int32_int32[-123])
2138    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
2139    self.assertEqual(456, msg2.map_uint32_uint32[123])
2140    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
2141
2142  def testMapAssignmentCausesPresence(self):
2143    msg = map_unittest_pb2.TestMapSubmessage()
2144    msg.test_map.map_int32_int32[123] = 456
2145
2146    serialized = msg.SerializeToString()
2147    msg2 = map_unittest_pb2.TestMapSubmessage()
2148    msg2.ParseFromString(serialized)
2149
2150    self.assertEqual(msg, msg2)
2151
2152    # Now test that various mutations of the map properly invalidate the
2153    # cached size of the submessage.
2154    msg.test_map.map_int32_int32[888] = 999
2155    serialized = msg.SerializeToString()
2156    msg2.ParseFromString(serialized)
2157    self.assertEqual(msg, msg2)
2158
2159    msg.test_map.map_int32_int32.clear()
2160    serialized = msg.SerializeToString()
2161    msg2.ParseFromString(serialized)
2162    self.assertEqual(msg, msg2)
2163
2164  def testMapAssignmentCausesPresenceForSubmessages(self):
2165    msg = map_unittest_pb2.TestMapSubmessage()
2166    msg.test_map.map_int32_foreign_message[123].c = 5
2167
2168    serialized = msg.SerializeToString()
2169    msg2 = map_unittest_pb2.TestMapSubmessage()
2170    msg2.ParseFromString(serialized)
2171
2172    self.assertEqual(msg, msg2)
2173
2174    # Now test that various mutations of the map properly invalidate the
2175    # cached size of the submessage.
2176    msg.test_map.map_int32_foreign_message[888].c = 7
2177    serialized = msg.SerializeToString()
2178    msg2.ParseFromString(serialized)
2179    self.assertEqual(msg, msg2)
2180
2181    msg.test_map.map_int32_foreign_message[888].MergeFrom(
2182        msg.test_map.map_int32_foreign_message[123])
2183    serialized = msg.SerializeToString()
2184    msg2.ParseFromString(serialized)
2185    self.assertEqual(msg, msg2)
2186
2187    msg.test_map.map_int32_foreign_message.clear()
2188    serialized = msg.SerializeToString()
2189    msg2.ParseFromString(serialized)
2190    self.assertEqual(msg, msg2)
2191
2192  def testModifyMapWhileIterating(self):
2193    msg = map_unittest_pb2.TestMap()
2194
2195    string_string_iter = iter(msg.map_string_string)
2196    int32_foreign_iter = iter(msg.map_int32_foreign_message)
2197
2198    msg.map_string_string['abc'] = '123'
2199    msg.map_int32_foreign_message[5].c = 5
2200
2201    with self.assertRaises(RuntimeError):
2202      for key in string_string_iter:
2203        pass
2204
2205    with self.assertRaises(RuntimeError):
2206      for key in int32_foreign_iter:
2207        pass
2208
2209  def testSubmessageMap(self):
2210    msg = map_unittest_pb2.TestMap()
2211
2212    submsg = msg.map_int32_foreign_message[111]
2213    self.assertIs(submsg, msg.map_int32_foreign_message[111])
2214    self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
2215
2216    submsg.c = 5
2217
2218    serialized = msg.SerializeToString()
2219    msg2 = map_unittest_pb2.TestMap()
2220    msg2.ParseFromString(serialized)
2221
2222    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
2223
2224    # Doesn't allow direct submessage assignment.
2225    with self.assertRaises(ValueError):
2226      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
2227
2228  def testMapIteration(self):
2229    msg = map_unittest_pb2.TestMap()
2230
2231    for k, v in msg.map_int32_int32.items():
2232      # Should not be reached.
2233      self.assertTrue(False)
2234
2235    msg.map_int32_int32[2] = 4
2236    msg.map_int32_int32[3] = 6
2237    msg.map_int32_int32[4] = 8
2238    self.assertEqual(3, len(msg.map_int32_int32))
2239
2240    matching_dict = {2: 4, 3: 6, 4: 8}
2241    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
2242
2243  def testPython2Map(self):
2244    if sys.version_info < (3,):
2245      msg = map_unittest_pb2.TestMap()
2246      msg.map_int32_int32[2] = 4
2247      msg.map_int32_int32[3] = 6
2248      msg.map_int32_int32[4] = 8
2249      msg.map_int32_int32[5] = 10
2250      map_int32 = msg.map_int32_int32
2251      self.assertEqual(4, len(map_int32))
2252      msg2 = map_unittest_pb2.TestMap()
2253      msg2.ParseFromString(msg.SerializeToString())
2254
2255      def CheckItems(seq, iterator):
2256        self.assertEqual(next(iterator), seq[0])
2257        self.assertEqual(list(iterator), seq[1:])
2258
2259      CheckItems(map_int32.items(), map_int32.iteritems())
2260      CheckItems(map_int32.keys(), map_int32.iterkeys())
2261      CheckItems(map_int32.values(), map_int32.itervalues())
2262
2263      self.assertEqual(6, map_int32.get(3))
2264      self.assertEqual(None, map_int32.get(999))
2265      self.assertEqual(6, map_int32.pop(3))
2266      self.assertEqual(0, map_int32.pop(3))
2267      self.assertEqual(3, len(map_int32))
2268      key, value = map_int32.popitem()
2269      self.assertEqual(2 * key, value)
2270      self.assertEqual(2, len(map_int32))
2271      map_int32.clear()
2272      self.assertEqual(0, len(map_int32))
2273
2274      with self.assertRaises(KeyError):
2275        map_int32.popitem()
2276
2277      self.assertEqual(0, map_int32.setdefault(2))
2278      self.assertEqual(1, len(map_int32))
2279
2280      map_int32.update(msg2.map_int32_int32)
2281      self.assertEqual(4, len(map_int32))
2282
2283      with self.assertRaises(TypeError):
2284        map_int32.update(msg2.map_int32_int32,
2285                         msg2.map_int32_int32)
2286      with self.assertRaises(TypeError):
2287        map_int32.update(0)
2288      with self.assertRaises(TypeError):
2289        map_int32.update(value=12)
2290
2291  def testMapItems(self):
2292    # Map items used to have strange behaviors when use c extension. Because
2293    # [] may reorder the map and invalidate any exsting iterators.
2294    # TODO(jieluo): Check if [] reordering the map is a bug or intended
2295    # behavior.
2296    msg = map_unittest_pb2.TestMap()
2297    msg.map_string_string['local_init_op'] = ''
2298    msg.map_string_string['trainable_variables'] = ''
2299    msg.map_string_string['variables'] = ''
2300    msg.map_string_string['init_op'] = ''
2301    msg.map_string_string['summaries'] = ''
2302    items1 = msg.map_string_string.items()
2303    items2 = msg.map_string_string.items()
2304    self.assertEqual(items1, items2)
2305
2306  def testMapDeterministicSerialization(self):
2307    golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
2308                   b'r\n\n\x05item1\x12\x01e'
2309                   b'r\n\n\x05item2\x12\x01f'
2310                   b'r\n\n\x05item3\x12\x01g'
2311                   b'r\x0b\n\x05item4\x12\x02QQ'
2312                   b'r\x12\n\rlocal_init_op\x12\x01a'
2313                   b'r\x0e\n\tsummaries\x12\x01e'
2314                   b'r\x18\n\x13trainable_variables\x12\x01b'
2315                   b'r\x0e\n\tvariables\x12\x01c')
2316    msg = map_unittest_pb2.TestMap()
2317    msg.map_string_string['local_init_op'] = 'a'
2318    msg.map_string_string['trainable_variables'] = 'b'
2319    msg.map_string_string['variables'] = 'c'
2320    msg.map_string_string['init_op'] = 'd'
2321    msg.map_string_string['summaries'] = 'e'
2322    msg.map_string_string['item1'] = 'e'
2323    msg.map_string_string['item2'] = 'f'
2324    msg.map_string_string['item3'] = 'g'
2325    msg.map_string_string['item4'] = 'QQ'
2326
2327    # If deterministic serialization is not working correctly, this will be
2328    # "flaky" depending on the exact python dict hash seed.
2329    #
2330    # Fortunately, there are enough items in this map that it is extremely
2331    # unlikely to ever hit the "right" in-order combination, so the test
2332    # itself should fail reliably.
2333    self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
2334
2335  def testMapIterationClearMessage(self):
2336    # Iterator needs to work even if message and map are deleted.
2337    msg = map_unittest_pb2.TestMap()
2338
2339    msg.map_int32_int32[2] = 4
2340    msg.map_int32_int32[3] = 6
2341    msg.map_int32_int32[4] = 8
2342
2343    it = msg.map_int32_int32.items()
2344    del msg
2345
2346    matching_dict = {2: 4, 3: 6, 4: 8}
2347    self.assertMapIterEquals(it, matching_dict)
2348
2349  def testMapConstruction(self):
2350    msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
2351    self.assertEqual(2, msg.map_int32_int32[1])
2352    self.assertEqual(4, msg.map_int32_int32[3])
2353
2354    msg = map_unittest_pb2.TestMap(
2355        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
2356    self.assertEqual(5, msg.map_int32_foreign_message[3].c)
2357
2358  def testMapScalarFieldConstruction(self):
2359    msg1 = map_unittest_pb2.TestMap()
2360    msg1.map_int32_int32[1] = 42
2361    msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32)
2362    self.assertEqual(42, msg2.map_int32_int32[1])
2363
2364  def testMapMessageFieldConstruction(self):
2365    msg1 = map_unittest_pb2.TestMap()
2366    msg1.map_string_foreign_message['test'].c = 42
2367    msg2 = map_unittest_pb2.TestMap(
2368      map_string_foreign_message=msg1.map_string_foreign_message)
2369    self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
2370
2371  def testMapFieldRaisesCorrectError(self):
2372    # Should raise a TypeError when given a non-iterable.
2373    with self.assertRaises(TypeError):
2374      map_unittest_pb2.TestMap(map_string_foreign_message=1)
2375
2376  def testMapValidAfterFieldCleared(self):
2377    # Map needs to work even if field is cleared.
2378    # For the C++ implementation this tests the correctness of
2379    # MapContainer::Release()
2380    msg = map_unittest_pb2.TestMap()
2381    int32_map = msg.map_int32_int32
2382
2383    int32_map[2] = 4
2384    int32_map[3] = 6
2385    int32_map[4] = 8
2386
2387    msg.ClearField('map_int32_int32')
2388    self.assertEqual(b'', msg.SerializeToString())
2389    matching_dict = {2: 4, 3: 6, 4: 8}
2390    self.assertMapIterEquals(int32_map.items(), matching_dict)
2391
2392  def testMessageMapValidAfterFieldCleared(self):
2393    # Map needs to work even if field is cleared.
2394    # For the C++ implementation this tests the correctness of
2395    # MapContainer::Release()
2396    msg = map_unittest_pb2.TestMap()
2397    int32_foreign_message = msg.map_int32_foreign_message
2398
2399    int32_foreign_message[2].c = 5
2400
2401    msg.ClearField('map_int32_foreign_message')
2402    self.assertEqual(b'', msg.SerializeToString())
2403    self.assertTrue(2 in int32_foreign_message.keys())
2404
  def testMessageMapItemValidAfterTopMessageCleared(self):
    """A value taken from a message map keeps its contents after msg.Clear()."""
    # Message map item needs to work even if it is cleared.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[2].optional_string = 'bar'

    if api_implementation.Type() == 'cpp':
      # Need to keep the map reference because of b/27942626.
      # TODO(jieluo): Remove it.
      unused_map = msg.map_int32_all_types  # pylint: disable=unused-variable
    # Hold on to the entry itself, then clear the whole top-level message.
    msg_value = msg.map_int32_all_types[2]
    msg.Clear()

    # Reset to trigger sync between repeated field and map in c++.
    msg.map_int32_all_types[3].optional_string = 'foo'
    # The entry obtained before Clear() must still report its old contents.
    self.assertEqual(msg_value.optional_string, 'bar')
2422
2423  def testMapIterInvalidatedByClearField(self):
2424    # Map iterator is invalidated when field is cleared.
2425    # But this case does need to not crash the interpreter.
2426    # For the C++ implementation this tests the correctness of
2427    # ScalarMapContainer::Release()
2428    msg = map_unittest_pb2.TestMap()
2429
2430    it = iter(msg.map_int32_int32)
2431
2432    msg.ClearField('map_int32_int32')
2433    with self.assertRaises(RuntimeError):
2434      for _ in it:
2435        pass
2436
2437    it = iter(msg.map_int32_foreign_message)
2438    msg.ClearField('map_int32_foreign_message')
2439    with self.assertRaises(RuntimeError):
2440      for _ in it:
2441        pass
2442
2443  def testMapDelete(self):
2444    msg = map_unittest_pb2.TestMap()
2445
2446    self.assertEqual(0, len(msg.map_int32_int32))
2447
2448    msg.map_int32_int32[4] = 6
2449    self.assertEqual(1, len(msg.map_int32_int32))
2450
2451    with self.assertRaises(KeyError):
2452      del msg.map_int32_int32[88]
2453
2454    del msg.map_int32_int32[4]
2455    self.assertEqual(0, len(msg.map_int32_int32))
2456
2457    with self.assertRaises(KeyError):
2458      del msg.map_int32_all_types[32]
2459
2460  def testMapsAreMapping(self):
2461    msg = map_unittest_pb2.TestMap()
2462    self.assertIsInstance(msg.map_int32_int32, collections_abc.Mapping)
2463    self.assertIsInstance(msg.map_int32_int32, collections_abc.MutableMapping)
2464    self.assertIsInstance(msg.map_int32_foreign_message, collections_abc.Mapping)
2465    self.assertIsInstance(msg.map_int32_foreign_message,
2466                          collections_abc.MutableMapping)
2467
2468  def testMapsCompare(self):
2469    msg = map_unittest_pb2.TestMap()
2470    msg.map_int32_int32[-123] = -456
2471    self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
2472    self.assertEqual(msg.map_int32_foreign_message,
2473                     msg.map_int32_foreign_message)
2474    self.assertNotEqual(msg.map_int32_int32, 0)
2475
2476  def testMapFindInitializationErrorsSmokeTest(self):
2477    msg = map_unittest_pb2.TestMap()
2478    msg.map_string_string['abc'] = '123'
2479    msg.map_int32_int32[35] = 64
2480    msg.map_string_foreign_message['foo'].c = 5
2481    self.assertEqual(0, len(msg.FindInitializationErrors()))
2482
2483  @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
2484  def testStrictUtf8Check(self):
2485    # Test u'\ud801' is rejected at parser in both python2 and python3.
2486    serialized = (b'r\x03\xed\xa0\x81')
2487    msg = unittest_proto3_arena_pb2.TestAllTypes()
2488    with self.assertRaises(Exception) as context:
2489      msg.MergeFromString(serialized)
2490    if api_implementation.Type() == 'python':
2491      self.assertIn('optional_string', str(context.exception))
2492    else:
2493      self.assertIn('Error parsing message', str(context.exception))
2494
2495    # Test optional_string=u'��' is accepted.
2496    serialized = unittest_proto3_arena_pb2.TestAllTypes(
2497        optional_string=u'��').SerializeToString()
2498    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2499    msg2.MergeFromString(serialized)
2500    self.assertEqual(msg2.optional_string, u'��')
2501
2502    msg = unittest_proto3_arena_pb2.TestAllTypes(
2503        optional_string=u'\ud001')
2504    self.assertEqual(msg.optional_string, u'\ud001')
2505
2506  @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
2507  def testSurrogatesInPython3(self):
2508    # Surrogates like U+D83D is an invalid unicode character, it is
2509    # supported by Python2 only because in some builds, unicode strings
2510    # use 2-bytes code units. Since Python 3.3, we don't have this problem.
2511    #
2512    # Surrogates are utf16 code units, in a unicode string they are invalid
2513    # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
2514    # Python3 reject such cases at setters and parsers. Python2 accpect it
2515    # to keep same features with the language itself. 'Unpaired pairs'
2516    # like u'\ud801' are rejected at parsers when strict utf8 check is enabled
2517    # in proto3 to keep same behavior with c extension.
2518
2519    # Surrogates are rejected at setters in Python3.
2520    with self.assertRaises(ValueError):
2521      unittest_proto3_arena_pb2.TestAllTypes(
2522          optional_string=u'\ud801\udc01')
2523    with self.assertRaises(ValueError):
2524      unittest_proto3_arena_pb2.TestAllTypes(
2525          optional_string=b'\xed\xa0\x81')
2526    with self.assertRaises(ValueError):
2527      unittest_proto3_arena_pb2.TestAllTypes(
2528          optional_string=u'\ud801')
2529    with self.assertRaises(ValueError):
2530      unittest_proto3_arena_pb2.TestAllTypes(
2531          optional_string=u'\ud801\ud801')
2532
2533  @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE,
2534                   'Surrogates are rejected at setters in Python3')
2535  def testSurrogatesInPython2(self):
2536    # Test optional_string=u'\ud801\udc01'.
2537    # surrogate pair is acceptable in python2.
2538    msg = unittest_proto3_arena_pb2.TestAllTypes(
2539        optional_string=u'\ud801\udc01')
2540    # TODO(jieluo): Change pure python to have same behavior with c extension.
2541    # Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are
2542    # equal, some are not equal.
2543    if api_implementation.Type() == 'python':
2544      self.assertEqual(msg.optional_string, u'\ud801\udc01')
2545    else:
2546      self.assertEqual(msg.optional_string, u'\U00010401')
2547    serialized = msg.SerializeToString()
2548    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2549    msg2.MergeFromString(serialized)
2550    self.assertEqual(msg2.optional_string, u'\U00010401')
2551
2552    # Python2 does not reject surrogates at setters.
2553    msg = unittest_proto3_arena_pb2.TestAllTypes(
2554        optional_string=b'\xed\xa0\x81')
2555    unittest_proto3_arena_pb2.TestAllTypes(
2556        optional_string=u'\ud801')
2557    unittest_proto3_arena_pb2.TestAllTypes(
2558        optional_string=u'\ud801\ud801')
2559
2560
@testing_refleaks.TestCase
class ValidTypeNamesTest(unittest.TestCase):
  """Checks that repeated-field container types have importable names."""

  def assertImportFromName(self, msg, base_name):
    """Asserts type(msg) has an expected, importable container type name.

    Args:
      msg: a repeated-field container instance.
      base_name: e.g. 'Scalar' or 'Composite'; the type name must end with
        'Repeated<base_name>Container' or 'Repeated<base_name>FieldContainer'.
    """
    # Parse "<type 'some.module.ClassName'>" (Python 2) or
    # "<class 'some.module.ClassName'>" (Python 3) to extract the dotted name.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    # Importing the defining module is what unpickling would need to do.
    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')

2582
@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):
  """Golden wire-format checks for packed vs. unpacked repeated fields."""

  def setMessage(self, message):
    """Appends a single 1-valued element to every repeated field of message."""
    integral_kinds = ('int32', 'int64', 'uint32', 'uint64', 'sint32',
                      'sint64', 'fixed32', 'fixed64', 'sfixed32', 'sfixed64')
    for kind in integral_kinds:
      getattr(message, 'repeated_' + kind).append(1)
    message.repeated_float.append(1.0)
    message.repeated_double.append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    packed = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(packed)
    # Per field, in field-number order: tag (wire type 2), byte length, data.
    # The sint32/sint64 value 1 is ZigZag-encoded as 2.
    expected_wire = (b'\x0A\x01\x01'
                     b'\x12\x01\x01'
                     b'\x1A\x01\x01'
                     b'\x22\x01\x01'
                     b'\x2A\x01\x02'
                     b'\x32\x01\x02'
                     b'\x3A\x04\x01\x00\x00\x00'
                     b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                     b'\x4A\x04\x01\x00\x00\x00'
                     b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                     b'\x5A\x04\x00\x00\x80\x3f'
                     b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                     b'\x6A\x01\x01'
                     b'\x72\x01\x01')
    self.assertEqual(expected_wire, packed.SerializeToString())

  def testUnpackedFields(self):
    unpacked = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(unpacked)
    # Per field, in field-number order: tag (scalar wire type), then data.
    expected_wire = (b'\x08\x01'
                     b'\x10\x01'
                     b'\x18\x01'
                     b'\x20\x01'
                     b'\x28\x02'
                     b'\x30\x02'
                     b'\x3D\x01\x00\x00\x00'
                     b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                     b'\x4D\x01\x00\x00\x00'
                     b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                     b'\x5D\x00\x00\x80\x3f'
                     b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                     b'\x68\x01'
                     b'\x70\x01')
    self.assertEqual(expected_wire, unpacked.SerializeToString())
2639
2640
@unittest.skipIf(api_implementation.Type() != 'cpp' or
                 sys.version_info < (2, 7),
                 'explicit tests of the C++ implementation for PY27 and above')
@testing_refleaks.TestCase
class OversizeProtosTest(unittest.TestCase):
  """Tests the C++ extension's SetAllowOversizeProtos() switch."""

  @classmethod
  def setUpClass(cls):
    # At the moment, reference cycles between DescriptorPool and Message classes
    # are not detected and these objects are never freed.
    # To avoid errors with ReferenceLeakChecker, we create the class only once.
    file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    pool = descriptor_pool.DescriptorPool()
    desc = descriptor_pb2.FileDescriptorProto()
    text_format.Parse(file_desc, desc)
    pool.Add(desc)
    cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
        pool.FindMessageTypeByName('f.msg2'))

  def setUp(self):
    # A nested string payload of 64 MiB + 1 bytes.
    self.p = self.proto_cls()
    self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(False)
    q = self.proto_cls()
    # assertRaises (rather than try/except) makes the test fail if parsing
    # unexpectedly succeeds instead of passing silently.
    with self.assertRaises(message.DecodeError) as context:
      q.ParseFromString(self.p_serialized)
    self.assertEqual(str(context.exception), 'Error parsing message')

  def testSucceedOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(True)
    q = self.proto_cls()
    q.ParseFromString(self.p_serialized)
    self.assertEqual(self.p.field.payload, q.field.payload)
2702
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
2705