• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4#
5# Use of this source code is governed by a BSD-style
6# license that can be found in the LICENSE file or at
7# https://developers.google.com/open-source/licenses/bsd
8
9"""Tests python protocol buffers against the golden message.
10
11Note that the golden messages exercise every known field type, thus this
12test ends up exercising and verifying nearly all of the parsing and
13serialization code in the whole library.
14"""
15
16__author__ = 'gps@google.com (Gregory P. Smith)'
17
18import collections
19import copy
20import math
21import operator
22import pickle
23import pydoc
24import sys
25import types
26import unittest
27from unittest import mock
28import warnings
29
def cmp(x, y):
  """Three-way comparison, a stand-in for Python 2's builtin ``cmp``.

  Args:
    x: Left operand; must support ``<`` and ``>`` against ``y``.
    y: Right operand.

  Returns:
    -1 if ``x < y``, 1 if ``x > y``, otherwise 0.
  """
  # Booleans are ints, so True - False == 1 and False - True == -1.
  return (x > y) - (x < y)
31
32from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
33from google.protobuf.internal import encoder
34from google.protobuf.internal import enum_type_wrapper
35from google.protobuf.internal import more_extensions_pb2
36from google.protobuf.internal import more_messages_pb2
37from google.protobuf.internal import packed_field_test_pb2
38from google.protobuf.internal import self_recursive_pb2
39from google.protobuf.internal import test_proto3_optional_pb2
40from google.protobuf.internal import test_util
41from google.protobuf.internal import testing_refleaks
42from google.protobuf import descriptor
43from google.protobuf import message
44from google.protobuf.internal import _parameterized
45from google.protobuf import map_proto2_unittest_pb2
46from google.protobuf import map_unittest_pb2
47from google.protobuf import unittest_pb2
48from google.protobuf import unittest_proto3_arena_pb2
49
# Largest code point that fits in a single UCS-2 code unit.
UCS2_MAXUNICODE = 0xFFFF

# Turn every DeprecationWarning into a raised exception for the whole run.
warnings.simplefilter('error', category=DeprecationWarning)
53
54@_parameterized.named_parameters(('_proto2', unittest_pb2),
55                                ('_proto3', unittest_proto3_arena_pb2))
56@testing_refleaks.TestCase
57class MessageTest(unittest.TestCase):
58
59  def testBadUtf8String(self, message_module):
60    if api_implementation.Type() != 'python':
61      self.skipTest('Skipping testBadUtf8String, currently only the python '
62                    'api implementation raises UnicodeDecodeError when a '
63                    'string field contains bad utf-8.')
64    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
65    with self.assertRaises(UnicodeDecodeError) as context:
66      message_module.TestAllTypes.FromString(bad_utf8_data)
67    self.assertIn('TestAllTypes.optional_string', str(context.exception))
68
69  def testParseErrors(self, message_module):
70    msg = message_module.TestAllTypes()
71    self.assertRaises(TypeError, msg.FromString, 0)
72    self.assertRaises(Exception, msg.FromString, '0')
73    # TODO: Fix cpp extension to raise error instead of warning.
74    # b/27494216
75    end_tag = encoder.TagBytes(1, 4)
76    if (api_implementation.Type() == 'python' or
77        api_implementation.Type() == 'upb'):
78      with self.assertRaises(message.DecodeError) as context:
79        msg.FromString(end_tag)
80      if api_implementation.Type() == 'python':
81        # Only pure-Python has an error message this specific.
82        self.assertEqual('Unexpected end-group tag.', str(context.exception))
83
84    # Field number 0 is illegal.
85    self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')
86
87  def testDeterminismParameters(self, message_module):
88    # This message is always deterministically serialized, even if determinism
89    # is disabled, so we can use it to verify that all the determinism
90    # parameters work correctly.
91    golden_data = (b'\xe2\x02\nOne string'
92                   b'\xe2\x02\nTwo string'
93                   b'\xe2\x02\nRed string'
94                   b'\xe2\x02\x0bBlue string')
95    golden_message = message_module.TestAllTypes()
96    golden_message.repeated_string.extend([
97        'One string',
98        'Two string',
99        'Red string',
100        'Blue string',
101    ])
102    self.assertEqual(golden_data,
103                     golden_message.SerializeToString(deterministic=None))
104    self.assertEqual(golden_data,
105                     golden_message.SerializeToString(deterministic=False))
106    self.assertEqual(golden_data,
107                     golden_message.SerializeToString(deterministic=True))
108
109    class BadArgError(Exception):
110      pass
111
112    class BadArg(object):
113
114      def __nonzero__(self):
115        raise BadArgError()
116
117      def __bool__(self):
118        raise BadArgError()
119
120    with self.assertRaises(BadArgError):
121      golden_message.SerializeToString(deterministic=BadArg())
122
123  def testPickleSupport(self, message_module):
124    golden_message = message_module.TestAllTypes()
125    test_util.SetAllFields(golden_message)
126    golden_data = golden_message.SerializeToString()
127    golden_message = message_module.TestAllTypes()
128    golden_message.ParseFromString(golden_data)
129    pickled_message = pickle.dumps(golden_message)
130
131    unpickled_message = pickle.loads(pickled_message)
132    self.assertEqual(unpickled_message, golden_message)
133
134  def testPickleNestedMessage(self, message_module):
135    golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1)
136    pickled_message = pickle.dumps(golden_message)
137    unpickled_message = pickle.loads(pickled_message)
138    self.assertEqual(unpickled_message, golden_message)
139
140  def testPickleNestedNestedMessage(self, message_module):
141    cls = message_module.TestPickleNestedMessage.NestedMessage
142    golden_message = cls.NestedNestedMessage(cc=1)
143    pickled_message = pickle.dumps(golden_message)
144    unpickled_message = pickle.loads(pickled_message)
145    self.assertEqual(unpickled_message, golden_message)
146
147  def testPositiveInfinity(self, message_module):
148    if message_module is unittest_pb2:
149      golden_data = (b'\x5D\x00\x00\x80\x7F'
150                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
151                     b'\xCD\x02\x00\x00\x80\x7F'
152                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
153    else:
154      golden_data = (b'\x5D\x00\x00\x80\x7F'
155                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
156                     b'\xCA\x02\x04\x00\x00\x80\x7F'
157                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
158
159    golden_message = message_module.TestAllTypes()
160    golden_message.ParseFromString(golden_data)
161    self.assertEqual(golden_message.optional_float, math.inf)
162    self.assertEqual(golden_message.optional_double, math.inf)
163    self.assertEqual(golden_message.repeated_float[0], math.inf)
164    self.assertEqual(golden_message.repeated_double[0], math.inf)
165    self.assertEqual(golden_data, golden_message.SerializeToString())
166
167  def testNegativeInfinity(self, message_module):
168    if message_module is unittest_pb2:
169      golden_data = (b'\x5D\x00\x00\x80\xFF'
170                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
171                     b'\xCD\x02\x00\x00\x80\xFF'
172                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
173    else:
174      golden_data = (b'\x5D\x00\x00\x80\xFF'
175                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
176                     b'\xCA\x02\x04\x00\x00\x80\xFF'
177                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
178
179    golden_message = message_module.TestAllTypes()
180    golden_message.ParseFromString(golden_data)
181    self.assertEqual(golden_message.optional_float, -math.inf)
182    self.assertEqual(golden_message.optional_double, -math.inf)
183    self.assertEqual(golden_message.repeated_float[0], -math.inf)
184    self.assertEqual(golden_message.repeated_double[0], -math.inf)
185    self.assertEqual(golden_data, golden_message.SerializeToString())
186
187  def testNotANumber(self, message_module):
188    golden_data = (b'\x5D\x00\x00\xC0\x7F'
189                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
190                   b'\xCD\x02\x00\x00\xC0\x7F'
191                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
192    golden_message = message_module.TestAllTypes()
193    golden_message.ParseFromString(golden_data)
194    self.assertTrue(math.isnan(golden_message.optional_float))
195    self.assertTrue(math.isnan(golden_message.optional_double))
196    self.assertTrue(math.isnan(golden_message.repeated_float[0]))
197    self.assertTrue(math.isnan(golden_message.repeated_double[0]))
198
199    # The protocol buffer may serialize to any one of multiple different
200    # representations of a NaN.  Rather than verify a specific representation,
201    # verify the serialized string can be converted into a correctly
202    # behaving protocol buffer.
203    serialized = golden_message.SerializeToString()
204    message = message_module.TestAllTypes()
205    message.ParseFromString(serialized)
206    self.assertTrue(math.isnan(message.optional_float))
207    self.assertTrue(math.isnan(message.optional_double))
208    self.assertTrue(math.isnan(message.repeated_float[0]))
209    self.assertTrue(math.isnan(message.repeated_double[0]))
210
211  def testPositiveInfinityPacked(self, message_module):
212    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
213                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
214    golden_message = message_module.TestPackedTypes()
215    golden_message.ParseFromString(golden_data)
216    self.assertEqual(golden_message.packed_float[0], math.inf)
217    self.assertEqual(golden_message.packed_double[0], math.inf)
218    self.assertEqual(golden_data, golden_message.SerializeToString())
219
220  def testNegativeInfinityPacked(self, message_module):
221    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
222                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
223    golden_message = message_module.TestPackedTypes()
224    golden_message.ParseFromString(golden_data)
225    self.assertEqual(golden_message.packed_float[0], -math.inf)
226    self.assertEqual(golden_message.packed_double[0], -math.inf)
227    self.assertEqual(golden_data, golden_message.SerializeToString())
228
229  def testNotANumberPacked(self, message_module):
230    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
231                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
232    golden_message = message_module.TestPackedTypes()
233    golden_message.ParseFromString(golden_data)
234    self.assertTrue(math.isnan(golden_message.packed_float[0]))
235    self.assertTrue(math.isnan(golden_message.packed_double[0]))
236
237    serialized = golden_message.SerializeToString()
238    message = message_module.TestPackedTypes()
239    message.ParseFromString(serialized)
240    self.assertTrue(math.isnan(message.packed_float[0]))
241    self.assertTrue(math.isnan(message.packed_double[0]))
242
243  def testExtremeFloatValues(self, message_module):
244    message = message_module.TestAllTypes()
245
246    # Most positive exponent, no significand bits set.
247    kMostPosExponentNoSigBits = math.pow(2, 127)
248    message.optional_float = kMostPosExponentNoSigBits
249    message.ParseFromString(message.SerializeToString())
250    self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
251
252    # Most positive exponent, one significand bit set.
253    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
254    message.optional_float = kMostPosExponentOneSigBit
255    message.ParseFromString(message.SerializeToString())
256    self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
257
258    # Repeat last two cases with values of same magnitude, but negative.
259    message.optional_float = -kMostPosExponentNoSigBits
260    message.ParseFromString(message.SerializeToString())
261    self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
262
263    message.optional_float = -kMostPosExponentOneSigBit
264    message.ParseFromString(message.SerializeToString())
265    self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
266
267    # Most negative exponent, no significand bits set.
268    kMostNegExponentNoSigBits = math.pow(2, -127)
269    message.optional_float = kMostNegExponentNoSigBits
270    message.ParseFromString(message.SerializeToString())
271    self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
272
273    # Most negative exponent, one significand bit set.
274    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
275    message.optional_float = kMostNegExponentOneSigBit
276    message.ParseFromString(message.SerializeToString())
277    self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
278
279    # Repeat last two cases with values of the same magnitude, but negative.
280    message.optional_float = -kMostNegExponentNoSigBits
281    message.ParseFromString(message.SerializeToString())
282    self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
283
284    message.optional_float = -kMostNegExponentOneSigBit
285    message.ParseFromString(message.SerializeToString())
286    self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
287
288    # Max 4 bytes float value
289    max_float = float.fromhex('0x1.fffffep+127')
290    message.optional_float = max_float
291    self.assertAlmostEqual(message.optional_float, max_float)
292    serialized_data = message.SerializeToString()
293    message.ParseFromString(serialized_data)
294    self.assertAlmostEqual(message.optional_float, max_float)
295
296    # Test set double to float field.
297    message.optional_float = 3.4028235e+39
298    self.assertEqual(message.optional_float, float('inf'))
299    serialized_data = message.SerializeToString()
300    message.ParseFromString(serialized_data)
301    self.assertEqual(message.optional_float, float('inf'))
302
303    message.optional_float = -3.4028235e+39
304    self.assertEqual(message.optional_float, float('-inf'))
305
306    message.optional_float = 1.4028235e-39
307    self.assertAlmostEqual(message.optional_float, 1.4028235e-39)
308
309  def testExtremeDoubleValues(self, message_module):
310    message = message_module.TestAllTypes()
311
312    # Most positive exponent, no significand bits set.
313    kMostPosExponentNoSigBits = math.pow(2, 1023)
314    message.optional_double = kMostPosExponentNoSigBits
315    message.ParseFromString(message.SerializeToString())
316    self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
317
318    # Most positive exponent, one significand bit set.
319    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
320    message.optional_double = kMostPosExponentOneSigBit
321    message.ParseFromString(message.SerializeToString())
322    self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
323
324    # Repeat last two cases with values of same magnitude, but negative.
325    message.optional_double = -kMostPosExponentNoSigBits
326    message.ParseFromString(message.SerializeToString())
327    self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
328
329    message.optional_double = -kMostPosExponentOneSigBit
330    message.ParseFromString(message.SerializeToString())
331    self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
332
333    # Most negative exponent, no significand bits set.
334    kMostNegExponentNoSigBits = math.pow(2, -1023)
335    message.optional_double = kMostNegExponentNoSigBits
336    message.ParseFromString(message.SerializeToString())
337    self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
338
339    # Most negative exponent, one significand bit set.
340    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
341    message.optional_double = kMostNegExponentOneSigBit
342    message.ParseFromString(message.SerializeToString())
343    self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
344
345    # Repeat last two cases with values of the same magnitude, but negative.
346    message.optional_double = -kMostNegExponentNoSigBits
347    message.ParseFromString(message.SerializeToString())
348    self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
349
350    message.optional_double = -kMostNegExponentOneSigBit
351    message.ParseFromString(message.SerializeToString())
352    self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
353
354  def testFloatPrinting(self, message_module):
355    message = message_module.TestAllTypes()
356    message.optional_float = 2.0
357    self.assertEqual(str(message), 'optional_float: 2.0\n')
358
359  def testFloatNanPrinting(self, message_module):
360    message = message_module.TestAllTypes()
361    message.optional_float = float('nan')
362    self.assertEqual(str(message), 'optional_float: nan\n')
363
364  def testHighPrecisionFloatPrinting(self, message_module):
365    msg = message_module.TestAllTypes()
366    msg.optional_float = 0.12345678912345678
367    old_float = msg.optional_float
368    msg.ParseFromString(msg.SerializeToString())
369    self.assertEqual(old_float, msg.optional_float)
370
371  def testDoubleNanPrinting(self, message_module):
372    message = message_module.TestAllTypes()
373    message.optional_double = float('nan')
374    self.assertEqual(str(message), 'optional_double: nan\n')
375
376  def testHighPrecisionDoublePrinting(self, message_module):
377    msg = message_module.TestAllTypes()
378    msg.optional_double = 0.12345678912345678
379    self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')
380
381  def testUnknownFieldPrinting(self, message_module):
382    populated = message_module.TestAllTypes()
383    test_util.SetAllNonLazyFields(populated)
384    empty = message_module.TestEmptyMessage()
385    empty.ParseFromString(populated.SerializeToString())
386    self.assertEqual(str(empty), '')
387
388  def testCopyFromEmpty(self, message_module):
389    msg = message_module.NestedTestAllTypes()
390    test_msg = message_module.NestedTestAllTypes()
391    test_util.SetAllFields(test_msg.payload)
392    self.assertTrue(test_msg.HasField('payload'))
393    # Copy from empty message
394    test_msg.CopyFrom(msg)
395    self.assertEqual(0, len(test_msg.ListFields()))
396
397    test_util.SetAllFields(test_msg.payload)
398    self.assertTrue(test_msg.HasField('payload'))
399    # Copy from a non exist message
400    test_msg.CopyFrom(msg.child)
401    self.assertFalse(test_msg.HasField('payload'))
402    self.assertEqual(0, len(test_msg.ListFields()))
403
404  def testAppendRepeatedCompositeField(self, message_module):
405    msg = message_module.TestAllTypes()
406    msg.repeated_nested_message.append(
407        message_module.TestAllTypes.NestedMessage(bb=1))
408    nested = message_module.TestAllTypes.NestedMessage(bb=2)
409    msg.repeated_nested_message.append(nested)
410    try:
411      msg.repeated_nested_message.append(1)
412    except TypeError:
413      pass
414    self.assertEqual(2, len(msg.repeated_nested_message))
415    self.assertEqual([1, 2], [m.bb for m in msg.repeated_nested_message])
416
417  def testInsertRepeatedCompositeField(self, message_module):
418    msg = message_module.TestAllTypes()
419    msg.repeated_nested_message.insert(
420        -1, message_module.TestAllTypes.NestedMessage(bb=1))
421    sub_msg = msg.repeated_nested_message[0]
422    msg.repeated_nested_message.insert(
423        0, message_module.TestAllTypes.NestedMessage(bb=2))
424    msg.repeated_nested_message.insert(
425        99, message_module.TestAllTypes.NestedMessage(bb=3))
426    msg.repeated_nested_message.insert(
427        -2, message_module.TestAllTypes.NestedMessage(bb=-1))
428    msg.repeated_nested_message.insert(
429        -1000, message_module.TestAllTypes.NestedMessage(bb=-1000))
430    try:
431      msg.repeated_nested_message.insert(1, 999)
432    except TypeError:
433      pass
434    self.assertEqual(5, len(msg.repeated_nested_message))
435    self.assertEqual([-1000, 2, -1, 1, 3],
436                     [m.bb for m in msg.repeated_nested_message])
437    self.assertEqual(
438        str(msg), 'repeated_nested_message {\n'
439        '  bb: -1000\n'
440        '}\n'
441        'repeated_nested_message {\n'
442        '  bb: 2\n'
443        '}\n'
444        'repeated_nested_message {\n'
445        '  bb: -1\n'
446        '}\n'
447        'repeated_nested_message {\n'
448        '  bb: 1\n'
449        '}\n'
450        'repeated_nested_message {\n'
451        '  bb: 3\n'
452        '}\n')
453    self.assertEqual(sub_msg.bb, 1)
454
455  def testAssignRepeatedField(self, message_module):
456    msg = message_module.NestedTestAllTypes()
457    msg.payload.repeated_int32[:] = [1, 2, 3, 4]
458    self.assertEqual(4, len(msg.payload.repeated_int32))
459    self.assertEqual([1, 2, 3, 4], msg.payload.repeated_int32)
460
461  def testMergeFromRepeatedField(self, message_module):
462    msg = message_module.TestAllTypes()
463    msg.repeated_int32.append(1)
464    msg.repeated_int32.append(3)
465    msg.repeated_nested_message.add(bb=1)
466    msg.repeated_nested_message.add(bb=2)
467    other_msg = message_module.TestAllTypes()
468    other_msg.repeated_nested_message.add(bb=3)
469    other_msg.repeated_nested_message.add(bb=4)
470    other_msg.repeated_int32.append(5)
471    other_msg.repeated_int32.append(7)
472
473    msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
474    self.assertEqual(4, len(msg.repeated_int32))
475
476    msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
477    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
478
479  def testInternalMergeWithMissingRequiredField(self, message_module):
480    req = more_messages_pb2.RequiredField()
481    more_messages_pb2.RequiredWrapper(request=req)
482
483  def testMergeFromMissingRequiredField(self, message_module):
484    msg = more_messages_pb2.RequiredField()
485    message = more_messages_pb2.RequiredField()
486    message.MergeFrom(msg)
487    self.assertEqual(msg, message)
488
489  def testAddWrongRepeatedNestedField(self, message_module):
490    msg = message_module.TestAllTypes()
491    try:
492      msg.repeated_nested_message.add('wrong')
493    except TypeError:
494      pass
495    try:
496      msg.repeated_nested_message.add(value_field='wrong')
497    except ValueError:
498      pass
499    self.assertEqual(len(msg.repeated_nested_message), 0)
500
501  def testRepeatedContains(self, message_module):
502    msg = message_module.TestAllTypes()
503    msg.repeated_int32.extend([1, 2, 3])
504    self.assertIn(2, msg.repeated_int32)
505    self.assertNotIn(0, msg.repeated_int32)
506
507    msg.repeated_nested_message.add(bb=1)
508    sub_msg1 = msg.repeated_nested_message[0]
509    sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2)
510    sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3)
511    msg.repeated_nested_message.append(sub_msg2)
512    msg.repeated_nested_message.insert(0, sub_msg3)
513    self.assertIn(sub_msg1, msg.repeated_nested_message)
514    self.assertIn(sub_msg2, msg.repeated_nested_message)
515    self.assertIn(sub_msg3, msg.repeated_nested_message)
516
517  def testRepeatedScalarIterable(self, message_module):
518    msg = message_module.TestAllTypes()
519    msg.repeated_int32.extend([1, 2, 3])
520    add = 0
521    for item in msg.repeated_int32:
522      add += item
523    self.assertEqual(add, 6)
524
525  def testRepeatedNestedFieldIteration(self, message_module):
526    msg = message_module.TestAllTypes()
527    msg.repeated_nested_message.add(bb=1)
528    msg.repeated_nested_message.add(bb=2)
529    msg.repeated_nested_message.add(bb=3)
530    msg.repeated_nested_message.add(bb=4)
531
532    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
533    self.assertEqual([4, 3, 2, 1],
534                     [m.bb for m in reversed(msg.repeated_nested_message)])
535    self.assertEqual([4, 3, 2, 1],
536                     [m.bb for m in msg.repeated_nested_message[::-1]])
537
538  def testSortEmptyRepeated(self, message_module):
539    message = message_module.NestedTestAllTypes()
540    self.assertFalse(message.HasField('child'))
541    self.assertFalse(message.HasField('payload'))
542    message.child.repeated_child.sort()
543    message.payload.repeated_int32.sort()
544    self.assertFalse(message.HasField('child'))
545    self.assertFalse(message.HasField('payload'))
546
547  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
548    """Check some different types with the default comparator."""
549    message = message_module.TestAllTypes()
550
551    # TODO: would testing more scalar types strengthen test?
552    message.repeated_int32.append(1)
553    message.repeated_int32.append(3)
554    message.repeated_int32.append(2)
555    message.repeated_int32.sort()
556    self.assertEqual(message.repeated_int32[0], 1)
557    self.assertEqual(message.repeated_int32[1], 2)
558    self.assertEqual(message.repeated_int32[2], 3)
559    self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
560
561    message.repeated_float.append(1.1)
562    message.repeated_float.append(1.3)
563    message.repeated_float.append(1.2)
564    message.repeated_float.sort()
565    self.assertAlmostEqual(message.repeated_float[0], 1.1)
566    self.assertAlmostEqual(message.repeated_float[1], 1.2)
567    self.assertAlmostEqual(message.repeated_float[2], 1.3)
568
569    message.repeated_string.append('a')
570    message.repeated_string.append('c')
571    message.repeated_string.append('b')
572    message.repeated_string.sort()
573    self.assertEqual(message.repeated_string[0], 'a')
574    self.assertEqual(message.repeated_string[1], 'b')
575    self.assertEqual(message.repeated_string[2], 'c')
576    self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
577
578    message.repeated_bytes.append(b'a')
579    message.repeated_bytes.append(b'c')
580    message.repeated_bytes.append(b'b')
581    message.repeated_bytes.sort()
582    self.assertEqual(message.repeated_bytes[0], b'a')
583    self.assertEqual(message.repeated_bytes[1], b'b')
584    self.assertEqual(message.repeated_bytes[2], b'c')
585    self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
586
587  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
588    """Check some different types with custom comparator."""
589    message = message_module.TestAllTypes()
590
591    message.repeated_int32.append(-3)
592    message.repeated_int32.append(-2)
593    message.repeated_int32.append(-1)
594    message.repeated_int32.sort(key=abs)
595    self.assertEqual(message.repeated_int32[0], -1)
596    self.assertEqual(message.repeated_int32[1], -2)
597    self.assertEqual(message.repeated_int32[2], -3)
598
599    message.repeated_string.append('aaa')
600    message.repeated_string.append('bb')
601    message.repeated_string.append('c')
602    message.repeated_string.sort(key=len)
603    self.assertEqual(message.repeated_string[0], 'c')
604    self.assertEqual(message.repeated_string[1], 'bb')
605    self.assertEqual(message.repeated_string[2], 'aaa')
606
607  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
608    """Check passing a custom comparator to sort a repeated composite field."""
609    message = message_module.TestAllTypes()
610
611    message.repeated_nested_message.add().bb = 1
612    message.repeated_nested_message.add().bb = 3
613    message.repeated_nested_message.add().bb = 2
614    message.repeated_nested_message.add().bb = 6
615    message.repeated_nested_message.add().bb = 5
616    message.repeated_nested_message.add().bb = 4
617    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
618    self.assertEqual(message.repeated_nested_message[0].bb, 1)
619    self.assertEqual(message.repeated_nested_message[1].bb, 2)
620    self.assertEqual(message.repeated_nested_message[2].bb, 3)
621    self.assertEqual(message.repeated_nested_message[3].bb, 4)
622    self.assertEqual(message.repeated_nested_message[4].bb, 5)
623    self.assertEqual(message.repeated_nested_message[5].bb, 6)
624    self.assertEqual(
625        str(message.repeated_nested_message),
626        '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
627
628  def testSortingRepeatedCompositeFieldsStable(self, message_module):
629    """Check passing a custom comparator to sort a repeated composite field."""
630    message = message_module.TestAllTypes()
631
632    message.repeated_nested_message.add().bb = 21
633    message.repeated_nested_message.add().bb = 20
634    message.repeated_nested_message.add().bb = 13
635    message.repeated_nested_message.add().bb = 33
636    message.repeated_nested_message.add().bb = 11
637    message.repeated_nested_message.add().bb = 24
638    message.repeated_nested_message.add().bb = 10
639    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
640    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
641                     [n.bb for n in message.repeated_nested_message])
642
643    # Make sure that for the C++ implementation, the underlying fields
644    # are actually reordered.
645    pb = message.SerializeToString()
646    message.Clear()
647    message.MergeFromString(pb)
648    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
649                     [n.bb for n in message.repeated_nested_message])
650
651  def testRepeatedCompositeFieldSortArguments(self, message_module):
652    """Check sorting a repeated composite field using list.sort() arguments."""
653    message = message_module.TestAllTypes()
654
655    get_bb = operator.attrgetter('bb')
656    message.repeated_nested_message.add().bb = 1
657    message.repeated_nested_message.add().bb = 3
658    message.repeated_nested_message.add().bb = 2
659    message.repeated_nested_message.add().bb = 6
660    message.repeated_nested_message.add().bb = 5
661    message.repeated_nested_message.add().bb = 4
662    message.repeated_nested_message.sort(key=get_bb)
663    self.assertEqual([k.bb for k in message.repeated_nested_message],
664                     [1, 2, 3, 4, 5, 6])
665    message.repeated_nested_message.sort(key=get_bb, reverse=True)
666    self.assertEqual([k.bb for k in message.repeated_nested_message],
667                     [6, 5, 4, 3, 2, 1])
668
669  def testRepeatedScalarFieldSortArguments(self, message_module):
670    """Check sorting a scalar field using list.sort() arguments."""
671    message = message_module.TestAllTypes()
672
673    message.repeated_int32.append(-3)
674    message.repeated_int32.append(-2)
675    message.repeated_int32.append(-1)
676    message.repeated_int32.sort(key=abs)
677    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
678    message.repeated_int32.sort(key=abs, reverse=True)
679    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
680
681    message.repeated_string.append('aaa')
682    message.repeated_string.append('bb')
683    message.repeated_string.append('c')
684    message.repeated_string.sort(key=len)
685    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
686    message.repeated_string.sort(key=len, reverse=True)
687    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
688
689  def testRepeatedFieldsComparable(self, message_module):
690    m1 = message_module.TestAllTypes()
691    m2 = message_module.TestAllTypes()
692    m1.repeated_int32.append(0)
693    m1.repeated_int32.append(1)
694    m1.repeated_int32.append(2)
695    m2.repeated_int32.append(0)
696    m2.repeated_int32.append(1)
697    m2.repeated_int32.append(2)
698    m1.repeated_nested_message.add().bb = 1
699    m1.repeated_nested_message.add().bb = 2
700    m1.repeated_nested_message.add().bb = 3
701    m2.repeated_nested_message.add().bb = 1
702    m2.repeated_nested_message.add().bb = 2
703    m2.repeated_nested_message.add().bb = 3
704
705  def testRepeatedFieldsAreSequences(self, message_module):
706    m = message_module.TestAllTypes()
707    self.assertIsInstance(m.repeated_int32, collections.abc.MutableSequence)
708    self.assertIsInstance(m.repeated_nested_message,
709                          collections.abc.MutableSequence)
710
711  def testRepeatedFieldsNotHashable(self, message_module):
712    m = message_module.TestAllTypes()
713    with self.assertRaises(TypeError):
714      hash(m.repeated_int32)
715    with self.assertRaises(TypeError):
716      hash(m.repeated_nested_message)
717
718  def testRepeatedFieldInsideNestedMessage(self, message_module):
719    m = message_module.NestedTestAllTypes()
720    m.payload.repeated_int32.extend([])
721    self.assertTrue(m.HasField('payload'))
722
723  def testMergeFrom(self, message_module):
724    m1 = message_module.TestAllTypes()
725    m2 = message_module.TestAllTypes()
726    # Cpp extension will lazily create a sub message which is immutable.
727    nested = m1.optional_nested_message
728    self.assertEqual(0, nested.bb)
729    m2.optional_nested_message.bb = 1
730    # Make sure cmessage pointing to a mutable message after merge instead of
731    # the lazily created message.
732    m1.MergeFrom(m2)
733    self.assertEqual(1, nested.bb)
734
735    # Test more nested sub message.
736    msg1 = message_module.NestedTestAllTypes()
737    msg2 = message_module.NestedTestAllTypes()
738    nested = msg1.child.payload.optional_nested_message
739    self.assertEqual(0, nested.bb)
740    msg2.child.payload.optional_nested_message.bb = 1
741    msg1.MergeFrom(msg2)
742    self.assertEqual(1, nested.bb)
743
744    # Test repeated field.
745    self.assertEqual(msg1.payload.repeated_nested_message,
746                     msg1.payload.repeated_nested_message)
747    nested = msg2.payload.repeated_nested_message.add()
748    nested.bb = 1
749    msg1.MergeFrom(msg2)
750    self.assertEqual(1, len(msg1.payload.repeated_nested_message))
751    self.assertEqual(1, nested.bb)
752
753  def testMergeFromString(self, message_module):
754    m1 = message_module.TestAllTypes()
755    m2 = message_module.TestAllTypes()
756    # Cpp extension will lazily create a sub message which is immutable.
757    self.assertEqual(0, m1.optional_nested_message.bb)
758    m2.optional_nested_message.bb = 1
759    # Make sure cmessage pointing to a mutable message after merge instead of
760    # the lazily created message.
761    m1.MergeFromString(m2.SerializeToString())
762    self.assertEqual(1, m1.optional_nested_message.bb)
763
764  def testMergeFromStringUsingMemoryView(self, message_module):
765    m2 = message_module.TestAllTypes()
766    m2.optional_string = 'scalar string'
767    m2.repeated_string.append('repeated string')
768    m2.optional_bytes = b'scalar bytes'
769    m2.repeated_bytes.append(b'repeated bytes')
770
771    serialized = m2.SerializeToString()
772    memview = memoryview(serialized)
773    m1 = message_module.TestAllTypes.FromString(memview)
774
775    self.assertEqual(m1.optional_bytes, b'scalar bytes')
776    self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
777    self.assertEqual(m1.optional_string, 'scalar string')
778    self.assertEqual(m1.repeated_string, ['repeated string'])
779    # Make sure that the memoryview was correctly converted to bytes, and
780    # that a sub-sliced memoryview is not being used.
781    self.assertIsInstance(m1.optional_bytes, bytes)
782    self.assertIsInstance(m1.repeated_bytes[0], bytes)
783    self.assertIsInstance(m1.optional_string, str)
784    self.assertIsInstance(m1.repeated_string[0], str)
785
786  def testMergeFromEmpty(self, message_module):
787    m1 = message_module.TestAllTypes()
788    # Cpp extension will lazily create a sub message which is immutable.
789    self.assertEqual(0, m1.optional_nested_message.bb)
790    self.assertFalse(m1.HasField('optional_nested_message'))
791    # Make sure the sub message is still immutable after merge from empty.
792    m1.MergeFromString(b'')  # field state should not change
793    self.assertFalse(m1.HasField('optional_nested_message'))
794
795  def ensureNestedMessageExists(self, msg, attribute):
796    """Make sure that a nested message object exists.
797
798    As soon as a nested message attribute is accessed, it will be present in the
799    _fields dict, without being marked as actually being set.
800    """
801    getattr(msg, attribute)
802    self.assertFalse(msg.HasField(attribute))
803
804  def testOneofGetCaseNonexistingField(self, message_module):
805    m = message_module.TestAllTypes()
806    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
807    self.assertRaises(Exception, m.WhichOneof, 0)
808
809  def testOneofDefaultValues(self, message_module):
810    m = message_module.TestAllTypes()
811    self.assertIs(None, m.WhichOneof('oneof_field'))
812    self.assertFalse(m.HasField('oneof_field'))
813    self.assertFalse(m.HasField('oneof_uint32'))
814
815    # Oneof is set even when setting it to a default value.
816    m.oneof_uint32 = 0
817    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
818    self.assertTrue(m.HasField('oneof_field'))
819    self.assertTrue(m.HasField('oneof_uint32'))
820    self.assertFalse(m.HasField('oneof_string'))
821
822    m.oneof_string = ''
823    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
824    self.assertTrue(m.HasField('oneof_string'))
825    self.assertFalse(m.HasField('oneof_uint32'))
826
827  def testOneofSemantics(self, message_module):
828    m = message_module.TestAllTypes()
829    self.assertIs(None, m.WhichOneof('oneof_field'))
830
831    m.oneof_uint32 = 11
832    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
833    self.assertTrue(m.HasField('oneof_uint32'))
834
835    m.oneof_string = u'foo'
836    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
837    self.assertFalse(m.HasField('oneof_uint32'))
838    self.assertTrue(m.HasField('oneof_string'))
839
840    # Read nested message accessor without accessing submessage.
841    m.oneof_nested_message
842    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
843    self.assertTrue(m.HasField('oneof_string'))
844    self.assertFalse(m.HasField('oneof_nested_message'))
845
846    # Read accessor of nested message without accessing submessage.
847    m.oneof_nested_message.bb
848    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
849    self.assertTrue(m.HasField('oneof_string'))
850    self.assertFalse(m.HasField('oneof_nested_message'))
851
852    m.oneof_nested_message.bb = 11
853    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
854    self.assertFalse(m.HasField('oneof_string'))
855    self.assertTrue(m.HasField('oneof_nested_message'))
856
857    m.oneof_bytes = b'bb'
858    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
859    self.assertFalse(m.HasField('oneof_nested_message'))
860    self.assertTrue(m.HasField('oneof_bytes'))
861
862  def testOneofCompositeFieldReadAccess(self, message_module):
863    m = message_module.TestAllTypes()
864    m.oneof_uint32 = 11
865
866    self.ensureNestedMessageExists(m, 'oneof_nested_message')
867    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
868    self.assertEqual(11, m.oneof_uint32)
869
870  def testOneofWhichOneof(self, message_module):
871    m = message_module.TestAllTypes()
872    self.assertIs(None, m.WhichOneof('oneof_field'))
873    if message_module is unittest_pb2:
874      self.assertFalse(m.HasField('oneof_field'))
875
876    m.oneof_uint32 = 11
877    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
878    if message_module is unittest_pb2:
879      self.assertTrue(m.HasField('oneof_field'))
880
881    m.oneof_bytes = b'bb'
882    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
883
884    m.ClearField('oneof_bytes')
885    self.assertIs(None, m.WhichOneof('oneof_field'))
886    if message_module is unittest_pb2:
887      self.assertFalse(m.HasField('oneof_field'))
888
889  def testOneofClearField(self, message_module):
890    m = message_module.TestAllTypes()
891    m.ClearField('oneof_field')
892    m.oneof_uint32 = 11
893    m.ClearField('oneof_field')
894    if message_module is unittest_pb2:
895      self.assertFalse(m.HasField('oneof_field'))
896    self.assertFalse(m.HasField('oneof_uint32'))
897    self.assertIs(None, m.WhichOneof('oneof_field'))
898
899  def testOneofClearSetField(self, message_module):
900    m = message_module.TestAllTypes()
901    m.oneof_uint32 = 11
902    m.ClearField('oneof_uint32')
903    if message_module is unittest_pb2:
904      self.assertFalse(m.HasField('oneof_field'))
905    self.assertFalse(m.HasField('oneof_uint32'))
906    self.assertIs(None, m.WhichOneof('oneof_field'))
907
908  def testOneofClearUnsetField(self, message_module):
909    m = message_module.TestAllTypes()
910    m.oneof_uint32 = 11
911    self.ensureNestedMessageExists(m, 'oneof_nested_message')
912    m.ClearField('oneof_nested_message')
913    self.assertEqual(11, m.oneof_uint32)
914    if message_module is unittest_pb2:
915      self.assertTrue(m.HasField('oneof_field'))
916    self.assertTrue(m.HasField('oneof_uint32'))
917    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
918
919  def testOneofDeserialize(self, message_module):
920    m = message_module.TestAllTypes()
921    m.oneof_uint32 = 11
922    m2 = message_module.TestAllTypes()
923    m2.ParseFromString(m.SerializeToString())
924    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
925
926  def testOneofCopyFrom(self, message_module):
927    m = message_module.TestAllTypes()
928    m.oneof_uint32 = 11
929    m2 = message_module.TestAllTypes()
930    m2.CopyFrom(m)
931    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
932
933  def testOneofNestedMergeFrom(self, message_module):
934    m = message_module.NestedTestAllTypes()
935    m.payload.oneof_uint32 = 11
936    m2 = message_module.NestedTestAllTypes()
937    m2.payload.oneof_bytes = b'bb'
938    m2.child.payload.oneof_bytes = b'bb'
939    m2.MergeFrom(m)
940    self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
941    self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
942
943  def testOneofMessageMergeFrom(self, message_module):
944    m = message_module.NestedTestAllTypes()
945    m.payload.oneof_nested_message.bb = 11
946    m.child.payload.oneof_nested_message.bb = 12
947    m2 = message_module.NestedTestAllTypes()
948    m2.payload.oneof_uint32 = 13
949    m2.MergeFrom(m)
950    self.assertEqual('oneof_nested_message',
951                     m2.payload.WhichOneof('oneof_field'))
952    self.assertEqual('oneof_nested_message',
953                     m2.child.payload.WhichOneof('oneof_field'))
954
955  def testOneofNestedMessageInit(self, message_module):
956    m = message_module.TestAllTypes(
957        oneof_nested_message=message_module.TestAllTypes.NestedMessage())
958    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
959
960  def testOneofClear(self, message_module):
961    m = message_module.TestAllTypes()
962    m.oneof_uint32 = 11
963    m.Clear()
964    self.assertIsNone(m.WhichOneof('oneof_field'))
965    m.oneof_bytes = b'bb'
966    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
967
968  def testAssignByteStringToUnicodeField(self, message_module):
969    """Assigning a byte string to a string field should result
970
971    in the value being converted to a Unicode string.
972    """
973    m = message_module.TestAllTypes()
974    m.optional_string = str('')
975    self.assertIsInstance(m.optional_string, str)
976
977  def testLongValuedSlice(self, message_module):
978    """It should be possible to use int-valued indices in slices.
979
980    This didn't used to work in the v2 C++ implementation.
981    """
982    m = message_module.TestAllTypes()
983
984    # Repeated scalar
985    m.repeated_int32.append(1)
986    sl = m.repeated_int32[int(0):int(len(m.repeated_int32))]
987    self.assertEqual(len(m.repeated_int32), len(sl))
988
989    # Repeated composite
990    m.repeated_nested_message.add().bb = 3
991    sl = m.repeated_nested_message[int(0):int(len(m.repeated_nested_message))]
992    self.assertEqual(len(m.repeated_nested_message), len(sl))
993
994  def testExtendShouldNotSwallowExceptions(self, message_module):
995    """This didn't use to work in the v2 C++ implementation."""
996    m = message_module.TestAllTypes()
997    with self.assertRaises(NameError) as _:
998      m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
999    with self.assertRaises(NameError) as _:
1000      m.repeated_nested_enum.extend(a for i in range(10))  # pylint: disable=undefined-variable
1001
1002  FALSY_VALUES = [None, False, 0, 0.0]
1003  EMPTY_VALUES = [b'', u'', bytearray(), [], {}, set()]
1004
1005  def testExtendInt32WithNothing(self, message_module):
1006    """Test no-ops extending repeated int32 fields."""
1007    m = message_module.TestAllTypes()
1008    self.assertSequenceEqual([], m.repeated_int32)
1009
1010    for falsy_value in MessageTest.FALSY_VALUES:
1011      with self.assertRaises(TypeError) as context:
1012        m.repeated_int32.extend(falsy_value)
1013      self.assertIn('iterable', str(context.exception))
1014      self.assertSequenceEqual([], m.repeated_int32)
1015
1016    for empty_value in MessageTest.EMPTY_VALUES:
1017      m.repeated_int32.extend(empty_value)
1018      self.assertSequenceEqual([], m.repeated_int32)
1019
1020  def testExtendFloatWithNothing(self, message_module):
1021    """Test no-ops extending repeated float fields."""
1022    m = message_module.TestAllTypes()
1023    self.assertSequenceEqual([], m.repeated_float)
1024
1025    for falsy_value in MessageTest.FALSY_VALUES:
1026      with self.assertRaises(TypeError) as context:
1027        m.repeated_float.extend(falsy_value)
1028      self.assertIn('iterable', str(context.exception))
1029      self.assertSequenceEqual([], m.repeated_float)
1030
1031    for empty_value in MessageTest.EMPTY_VALUES:
1032      m.repeated_float.extend(empty_value)
1033      self.assertSequenceEqual([], m.repeated_float)
1034
1035  def testExtendStringWithNothing(self, message_module):
1036    """Test no-ops extending repeated string fields."""
1037    m = message_module.TestAllTypes()
1038    self.assertSequenceEqual([], m.repeated_string)
1039
1040    for falsy_value in MessageTest.FALSY_VALUES:
1041      with self.assertRaises(TypeError) as context:
1042        m.repeated_string.extend(falsy_value)
1043      self.assertIn('iterable', str(context.exception))
1044      self.assertSequenceEqual([], m.repeated_string)
1045
1046    for empty_value in MessageTest.EMPTY_VALUES:
1047      m.repeated_string.extend(empty_value)
1048      self.assertSequenceEqual([], m.repeated_string)
1049
1050  def testExtendInt32WithPythonList(self, message_module):
1051    """Test extending repeated int32 fields with python lists."""
1052    m = message_module.TestAllTypes()
1053    self.assertSequenceEqual([], m.repeated_int32)
1054    m.repeated_int32.extend([0])
1055    self.assertSequenceEqual([0], m.repeated_int32)
1056    m.repeated_int32.extend([1, 2])
1057    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1058    m.repeated_int32.extend([3, 4])
1059    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1060
1061  def testExtendFloatWithPythonList(self, message_module):
1062    """Test extending repeated float fields with python lists."""
1063    m = message_module.TestAllTypes()
1064    self.assertSequenceEqual([], m.repeated_float)
1065    m.repeated_float.extend([0.0])
1066    self.assertSequenceEqual([0.0], m.repeated_float)
1067    m.repeated_float.extend([1.0, 2.0])
1068    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1069    m.repeated_float.extend([3.0, 4.0])
1070    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1071
1072  def testExtendStringWithPythonList(self, message_module):
1073    """Test extending repeated string fields with python lists."""
1074    m = message_module.TestAllTypes()
1075    self.assertSequenceEqual([], m.repeated_string)
1076    m.repeated_string.extend([''])
1077    self.assertSequenceEqual([''], m.repeated_string)
1078    m.repeated_string.extend(['11', '22'])
1079    self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
1080    m.repeated_string.extend(['33', '44'])
1081    self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
1082
1083  def testExtendStringWithString(self, message_module):
1084    """Test extending repeated string fields with characters from a string."""
1085    m = message_module.TestAllTypes()
1086    self.assertSequenceEqual([], m.repeated_string)
1087    m.repeated_string.extend('abc')
1088    self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
1089
1090  class TestIterable(object):
1091    """This iterable object mimics the behavior of numpy.array.
1092
1093    __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
1094
1095    """
1096
1097    def __init__(self, values=None):
1098      self._list = values or []
1099
1100    def __nonzero__(self):
1101      size = len(self._list)
1102      if size == 0:
1103        return False
1104      if size == 1:
1105        return bool(self._list[0])
1106      raise ValueError('Truth value is ambiguous.')
1107
1108    def __len__(self):
1109      return len(self._list)
1110
1111    def __iter__(self):
1112      return self._list.__iter__()
1113
1114  def testExtendInt32WithIterable(self, message_module):
1115    """Test extending repeated int32 fields with iterable."""
1116    m = message_module.TestAllTypes()
1117    self.assertSequenceEqual([], m.repeated_int32)
1118    m.repeated_int32.extend(MessageTest.TestIterable([]))
1119    self.assertSequenceEqual([], m.repeated_int32)
1120    m.repeated_int32.extend(MessageTest.TestIterable([0]))
1121    self.assertSequenceEqual([0], m.repeated_int32)
1122    m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
1123    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1124    m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
1125    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1126
1127  def testExtendFloatWithIterable(self, message_module):
1128    """Test extending repeated float fields with iterable."""
1129    m = message_module.TestAllTypes()
1130    self.assertSequenceEqual([], m.repeated_float)
1131    m.repeated_float.extend(MessageTest.TestIterable([]))
1132    self.assertSequenceEqual([], m.repeated_float)
1133    m.repeated_float.extend(MessageTest.TestIterable([0.0]))
1134    self.assertSequenceEqual([0.0], m.repeated_float)
1135    m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
1136    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1137    m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
1138    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1139
1140  def testExtendStringWithIterable(self, message_module):
1141    """Test extending repeated string fields with iterable."""
1142    m = message_module.TestAllTypes()
1143    self.assertSequenceEqual([], m.repeated_string)
1144    m.repeated_string.extend(MessageTest.TestIterable([]))
1145    self.assertSequenceEqual([], m.repeated_string)
1146    m.repeated_string.extend(MessageTest.TestIterable(['']))
1147    self.assertSequenceEqual([''], m.repeated_string)
1148    m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
1149    self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
1150    m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
1151    self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
1152
1153  class TestIndex(object):
1154    """This index object mimics the behavior of numpy.int64 and other types."""
1155
1156    def __init__(self, value=None):
1157      self.value = value
1158
1159    def __index__(self):
1160      return self.value
1161
1162  def testRepeatedIndexingWithIntIndex(self, message_module):
1163    msg = message_module.TestAllTypes()
1164    msg.repeated_int32.extend([1, 2, 3])
1165    self.assertEqual(1, msg.repeated_int32[MessageTest.TestIndex(0)])
1166
1167  def testRepeatedIndexingWithNegative1IntIndex(self, message_module):
1168    msg = message_module.TestAllTypes()
1169    msg.repeated_int32.extend([1, 2, 3])
1170    self.assertEqual(3, msg.repeated_int32[MessageTest.TestIndex(-1)])
1171
1172  def testRepeatedIndexingWithNegative1Int(self, message_module):
1173    msg = message_module.TestAllTypes()
1174    msg.repeated_int32.extend([1, 2, 3])
1175    self.assertEqual(3, msg.repeated_int32[-1])
1176
1177  def testPickleRepeatedScalarContainer(self, message_module):
1178    # Pickle repeated scalar container is not supported.
1179    m = message_module.TestAllTypes()
1180    with self.assertRaises(pickle.PickleError) as _:
1181      pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
1182
1183  def testSortEmptyRepeatedCompositeContainer(self, message_module):
1184    """Exercise a scenario that has led to segfaults in the past."""
1185    m = message_module.TestAllTypes()
1186    m.repeated_nested_message.sort()
1187
1188  def testHasFieldOnRepeatedField(self, message_module):
1189    """Using HasField on a repeated field should raise an exception."""
1190    m = message_module.TestAllTypes()
1191    with self.assertRaises(ValueError) as _:
1192      m.HasField('repeated_int32')
1193
1194  def testRepeatedScalarFieldPop(self, message_module):
1195    m = message_module.TestAllTypes()
1196    with self.assertRaises(IndexError) as _:
1197      m.repeated_int32.pop()
1198    m.repeated_int32.extend(range(5))
1199    self.assertEqual(4, m.repeated_int32.pop())
1200    self.assertEqual(0, m.repeated_int32.pop(0))
1201    self.assertEqual(2, m.repeated_int32.pop(1))
1202    self.assertEqual([1, 3], m.repeated_int32)
1203
1204  def testRepeatedCompositeFieldPop(self, message_module):
1205    m = message_module.TestAllTypes()
1206    with self.assertRaises(IndexError) as _:
1207      m.repeated_nested_message.pop()
1208    with self.assertRaises(TypeError) as _:
1209      m.repeated_nested_message.pop('0')
1210    for i in range(5):
1211      n = m.repeated_nested_message.add()
1212      n.bb = i
1213    self.assertEqual(4, m.repeated_nested_message.pop().bb)
1214    self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
1215    self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
1216    self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
1217
1218  def testRepeatedCompareWithSelf(self, message_module):
1219    m = message_module.TestAllTypes()
1220    for i in range(5):
1221      m.repeated_int32.insert(i, i)
1222      n = m.repeated_nested_message.add()
1223      n.bb = i
1224    self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
1225    self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
1226
1227  def testReleasedNestedMessages(self, message_module):
1228    """A case that lead to a segfault when a message detached from its parent
1229
1230    container has itself a child container.
1231    """
1232    m = message_module.NestedTestAllTypes()
1233    m = m.repeated_child.add()
1234    m = m.child
1235    m = m.repeated_child.add()
1236    self.assertEqual(m.payload.optional_int32, 0)
1237
1238  def testSetRepeatedComposite(self, message_module):
1239    m = message_module.TestAllTypes()
1240    with self.assertRaises(AttributeError):
1241      m.repeated_int32 = []
1242    m.repeated_int32.append(1)
1243    with self.assertRaises(AttributeError):
1244      m.repeated_int32 = []
1245
1246  def testReturningType(self, message_module):
1247    m = message_module.TestAllTypes()
1248    self.assertEqual(float, type(m.optional_float))
1249    self.assertEqual(float, type(m.optional_double))
1250    self.assertEqual(bool, type(m.optional_bool))
1251    m.optional_float = 1
1252    m.optional_double = 1
1253    m.optional_bool = 1
1254    m.repeated_float.append(1)
1255    m.repeated_double.append(1)
1256    m.repeated_bool.append(1)
1257    m.ParseFromString(m.SerializeToString())
1258    self.assertEqual(float, type(m.optional_float))
1259    self.assertEqual(float, type(m.optional_double))
1260    self.assertEqual('1.0', str(m.optional_double))
1261    self.assertEqual(bool, type(m.optional_bool))
1262    self.assertEqual(float, type(m.repeated_float[0]))
1263    self.assertEqual(float, type(m.repeated_double[0]))
1264    self.assertEqual(bool, type(m.repeated_bool[0]))
1265    self.assertEqual(True, m.repeated_bool[0])
1266
1267  def testEquality(self, message_module):
1268    m = message_module.TestAllTypes()
1269    m2 = message_module.TestAllTypes()
1270    self.assertEqual(m, m)
1271    self.assertEqual(m, m2)
1272    self.assertEqual(m2, m)
1273
1274    different_m = message_module.TestAllTypes()
1275    different_m.repeated_float.append(1)
1276    self.assertNotEqual(m, different_m)
1277    self.assertNotEqual(different_m, m)
1278
1279    self.assertIsNotNone(m)
1280    self.assertIsNotNone(m)
1281    self.assertNotEqual(42, m)
1282    self.assertNotEqual(m, 42)
1283    self.assertNotEqual('foo', m)
1284    self.assertNotEqual(m, 'foo')
1285
1286    self.assertEqual(mock.ANY, m)
1287    self.assertEqual(m, mock.ANY)
1288
1289    class ComparesWithFoo(object):
1290
1291      def __eq__(self, other):
1292        if getattr(other, 'optional_string', 'not_foo') == 'foo':
1293          return True
1294        return NotImplemented
1295
1296    m.optional_string = 'foo'
1297    self.assertEqual(m, ComparesWithFoo())
1298    self.assertEqual(ComparesWithFoo(), m)
1299    m.optional_string = 'bar'
1300    self.assertNotEqual(m, ComparesWithFoo())
1301    self.assertNotEqual(ComparesWithFoo(), m)
1302
1303  def testTypeUnion(self, message_module):
1304    # Below python 3.10 you cannot create union types with the | operator, so we
1305    # skip testing for unions with old versions.
1306    if sys.version_info < (3, 10):
1307      return
1308    enum_type = enum_type_wrapper.EnumTypeWrapper(
1309        message_module.TestAllTypes.NestedEnum.DESCRIPTOR
1310    )
1311    union_type = enum_type | int
1312    self.assertIsInstance(union_type, types.UnionType)
1313
1314    def get_union() -> union_type:
1315      return enum_type
1316
1317    union = get_union()
1318    self.assertIsInstance(union, enum_type_wrapper.EnumTypeWrapper)
1319    self.assertEqual(
1320        union.DESCRIPTOR, message_module.TestAllTypes.NestedEnum.DESCRIPTOR
1321    )
1322
1323  def testIn(self, message_module):
1324    m = message_module.TestAllTypes()
1325    self.assertNotIn('optional_nested_message', m)
1326    self.assertNotIn('oneof_bytes', m)
1327    self.assertNotIn('oneof_string', m)
1328    with self.assertRaises(ValueError) as e:
1329      'repeated_int32' in m
1330    with self.assertRaises(ValueError) as e:
1331      'repeated_nested_message' in m
1332    with self.assertRaises(ValueError) as e:
1333      1 in m
1334    with self.assertRaises(ValueError) as e:
1335      'not_a_field' in m
1336    test_util.SetAllFields(m)
1337    self.assertIn('optional_nested_message', m)
1338    self.assertIn('oneof_bytes', m)
1339    self.assertNotIn('oneof_string', m)
1340
1341
@testing_refleaks.TestCase
class TestRecursiveGroup(unittest.TestCase):
  """Parsing of deeply nested self-recursive group messages."""

  def _MakeRecursiveGroupMessage(self, n):
    """Serializes a SelfRecursive message nested `n` sub_group levels deep.

    Args:
      n: number of sub_group levels; the innermost one gets `i = 1`.

    Returns:
      The serialized message bytes.
    """
    msg = self_recursive_pb2.SelfRecursive()
    sub = msg
    for _ in range(n):
      sub = sub.sub_group
    # Setting a field on the innermost group marks every ancestor as present.
    sub.i = 1
    return msg.SerializeToString()

  def testRecursiveGroups(self):
    recurse_msg = self_recursive_pb2.SelfRecursive()
    data = self._MakeRecursiveGroupMessage(100)
    recurse_msg.ParseFromString(data)
    self.assertTrue(recurse_msg.HasField('sub_group'))

  def testRecursiveGroupsException(self):
    if api_implementation.Type() != 'python':
      api_implementation._c_module.SetAllowOversizeProtos(False)
    recurse_msg = self_recursive_pb2.SelfRecursive()
    data = self._MakeRecursiveGroupMessage(300)
    with self.assertRaises(message.DecodeError) as context:
      recurse_msg.ParseFromString(data)
    self.assertIn('Error parsing message', str(context.exception))
    if api_implementation.Type() == 'python':
      self.assertIn('too many levels of nesting', str(context.exception))

  def testRecursiveGroupsUnknownFields(self):
    if api_implementation.Type() != 'python':
      api_implementation._c_module.SetAllowOversizeProtos(False)
    test_msg = unittest_pb2.TestAllTypes()
    data = self._MakeRecursiveGroupMessage(300)  # unknown to test_msg
    with self.assertRaises(message.DecodeError) as context:
      test_msg.ParseFromString(data)
    self.assertIn(
        'Error parsing message',
        str(context.exception),
    )
    if api_implementation.Type() == 'python':
      # Imported locally so this test does not depend on a module-level
      # `decoder` import being present.
      from google.protobuf.internal import decoder  # pylint: disable=g-import-not-at-top
      self.assertIn('too many levels of nesting', str(context.exception))
      # Raising the limit lets the same payload parse. Restore the default
      # even if parsing fails, so the global limit does not leak into other
      # tests.
      decoder.SetRecursionLimit(310)
      try:
        test_msg.ParseFromString(data)
      finally:
        decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
1386
1387
1388# Class to test proto2-only features (required, extensions, etc.)
1389@testing_refleaks.TestCase
1390class Proto2Test(unittest.TestCase):
1391
1392  def testFieldPresence(self):
1393    message = unittest_pb2.TestAllTypes()
1394
1395    self.assertFalse(message.HasField('optional_int32'))
1396    self.assertFalse(message.HasField('optional_bool'))
1397    self.assertFalse(message.HasField('optional_nested_message'))
1398
1399    with self.assertRaises(ValueError):
1400      message.HasField('field_doesnt_exist')
1401
1402    with self.assertRaises(ValueError):
1403      message.HasField('repeated_int32')
1404    with self.assertRaises(ValueError):
1405      message.HasField('repeated_nested_message')
1406
1407    self.assertEqual(0, message.optional_int32)
1408    self.assertEqual(False, message.optional_bool)
1409    self.assertEqual(0, message.optional_nested_message.bb)
1410
1411    # Fields are set even when setting the values to default values.
1412    message.optional_int32 = 0
1413    message.optional_bool = False
1414    message.optional_nested_message.bb = 0
1415    self.assertTrue(message.HasField('optional_int32'))
1416    self.assertTrue(message.HasField('optional_bool'))
1417    self.assertTrue(message.HasField('optional_nested_message'))
1418    self.assertIn('optional_int32', message)
1419    self.assertIn('optional_bool', message)
1420    self.assertIn('optional_nested_message', message)
1421
1422    # Set the fields to non-default values.
1423    message.optional_int32 = 5
1424    message.optional_bool = True
1425    message.optional_nested_message.bb = 15
1426
1427    self.assertTrue(message.HasField(u'optional_int32'))
1428    self.assertTrue(message.HasField('optional_bool'))
1429    self.assertTrue(message.HasField('optional_nested_message'))
1430
1431    # Clearing the fields unsets them and resets their value to default.
1432    message.ClearField('optional_int32')
1433    message.ClearField(u'optional_bool')
1434    message.ClearField('optional_nested_message')
1435
1436    self.assertFalse(message.HasField('optional_int32'))
1437    self.assertFalse(message.HasField('optional_bool'))
1438    self.assertFalse(message.HasField('optional_nested_message'))
1439    self.assertNotIn('optional_int32', message)
1440    self.assertNotIn('optional_bool', message)
1441    self.assertNotIn('optional_nested_message', message)
1442    self.assertEqual(0, message.optional_int32)
1443    self.assertEqual(False, message.optional_bool)
1444    self.assertEqual(0, message.optional_nested_message.bb)
1445
1446  def testDel(self):
1447    msg = unittest_pb2.TestAllTypes()
1448
1449    # Fields cannot be deleted.
1450    with self.assertRaises(AttributeError):
1451      del msg.optional_int32
1452    with self.assertRaises(AttributeError):
1453      del msg.optional_bool
1454    with self.assertRaises(AttributeError):
1455      del msg.repeated_nested_message
1456
1457  def testAssignInvalidEnum(self):
1458    """Assigning an invalid enum number is not allowed for closed enums."""
1459    m = unittest_pb2.TestAllTypes()
1460
1461    # TODO Enable these once upb's behavior is made conformant.
1462    if api_implementation.Type() != 'upb':
1463      # Can not assign unknown enum to closed enums.
1464      with self.assertRaises(ValueError) as _:
1465        m.optional_nested_enum = 1234567
1466      self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
1467      # Assignment is a different code path than append for the C++ impl.
1468      m.repeated_nested_enum.append(2)
1469      m.repeated_nested_enum[0] = 2
1470      with self.assertRaises(ValueError):
1471        m.repeated_nested_enum[0] = 123456
1472    else:
1473      m.optional_nested_enum = 1234567
1474      m.repeated_nested_enum.append(1234567)
1475      m.repeated_nested_enum.append(2)
1476      m.repeated_nested_enum[0] = 2
1477      m.repeated_nested_enum[0] = 123456
1478
1479    # Unknown enum value can be parsed but is ignored.
1480    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1481    m2.optional_nested_enum = 1234567
1482    m2.repeated_nested_enum.append(7654321)
1483    serialized = m2.SerializeToString()
1484
1485    m3 = unittest_pb2.TestAllTypes()
1486    m3.ParseFromString(serialized)
1487    self.assertFalse(m3.HasField('optional_nested_enum'))
1488    # 1 is the default value for optional_nested_enum.
1489    self.assertEqual(1, m3.optional_nested_enum)
1490    self.assertEqual(0, len(m3.repeated_nested_enum))
1491    m2.Clear()
1492    m2.ParseFromString(m3.SerializeToString())
1493    self.assertEqual(1234567, m2.optional_nested_enum)
1494    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1495
1496  def testUnknownEnumMap(self):
1497    m = map_proto2_unittest_pb2.TestEnumMap()
1498    m.known_map_field[123] = 0
1499    with self.assertRaises(ValueError):
1500      m.unknown_map_field[1] = 123
1501
1502  def testDeepCopyClosedEnum(self):
1503    m = map_proto2_unittest_pb2.TestEnumMap()
1504    m.known_map_field[123] = 0
1505    m2 = copy.deepcopy(m)
1506    self.assertEqual(m, m2)
1507
1508  def testExtensionsErrors(self):
1509    msg = unittest_pb2.TestAllTypes()
1510    self.assertRaises(AttributeError, getattr, msg, 'Extensions')
1511
1512  def testMergeFromExtensions(self):
1513    msg1 = more_extensions_pb2.TopLevelMessage()
1514    msg2 = more_extensions_pb2.TopLevelMessage()
1515    # Cpp extension will lazily create a sub message which is immutable.
1516    self.assertEqual(
1517        0,
1518        msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
1519    self.assertFalse(msg1.HasField('submessage'))
1520    msg2.submessage.Extensions[more_extensions_pb2.optional_int_extension] = 123
1521    # Make sure cmessage and extensions pointing to a mutable message
1522    # after merge instead of the lazily created message.
1523    msg1.MergeFrom(msg2)
1524    self.assertEqual(
1525        123,
1526        msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
1527
1528  def testCopyFromAll(self):
1529    message = unittest_pb2.TestAllTypes()
1530    test_util.SetAllFields(message)
1531    copy = unittest_pb2.TestAllTypes()
1532    copy.CopyFrom(message)
1533    self.assertEqual(message, copy)
1534    message.repeated_nested_message.add().bb = 123
1535    self.assertNotEqual(message, copy)
1536
1537  def testCopyFromAllExtensions(self):
1538    all_set = unittest_pb2.TestAllExtensions()
1539    test_util.SetAllExtensions(all_set)
1540    copy =  unittest_pb2.TestAllExtensions()
1541    copy.CopyFrom(all_set)
1542    self.assertEqual(all_set, copy)
1543    all_set.Extensions[unittest_pb2.repeatedgroup_extension].add().a = 321
1544    self.assertNotEqual(all_set, copy)
1545
1546  def testCopyFromAllPackedExtensions(self):
1547    all_set = unittest_pb2.TestPackedExtensions()
1548    test_util.SetAllPackedExtensions(all_set)
1549    copy =  unittest_pb2.TestPackedExtensions()
1550    copy.CopyFrom(all_set)
1551    self.assertEqual(all_set, copy)
1552    all_set.Extensions[unittest_pb2.packed_float_extension].extend([61.0, 71.0])
1553    self.assertNotEqual(all_set, copy)
1554
1555  def testPickleIncompleteProto(self):
1556    golden_message = unittest_pb2.TestRequired(a=1)
1557    pickled_message = pickle.dumps(golden_message)
1558
1559    unpickled_message = pickle.loads(pickled_message)
1560    self.assertEqual(unpickled_message, golden_message)
1561    self.assertEqual(unpickled_message.a, 1)
1562    # This is still an incomplete proto - so serializing should fail
1563    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
1564
1565  # TODO: this isn't really a proto2-specific test except that this
1566  # message has a required field in it.  Should probably be factored out so
1567  # that we can test the other parts with proto3.
1568  def testParsingMerge(self):
1569    """Check the merge behavior when a required or optional field appears
1570
1571    multiple times in the input.
1572    """
1573    messages = [
1574        unittest_pb2.TestAllTypes(),
1575        unittest_pb2.TestAllTypes(),
1576        unittest_pb2.TestAllTypes()
1577    ]
1578    messages[0].optional_int32 = 1
1579    messages[1].optional_int64 = 2
1580    messages[2].optional_int32 = 3
1581    messages[2].optional_string = 'hello'
1582
1583    merged_message = unittest_pb2.TestAllTypes()
1584    merged_message.optional_int32 = 3
1585    merged_message.optional_int64 = 2
1586    merged_message.optional_string = 'hello'
1587
1588    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
1589    generator.field1.extend(messages)
1590    generator.field2.extend(messages)
1591    generator.field3.extend(messages)
1592    generator.ext1.extend(messages)
1593    generator.ext2.extend(messages)
1594    generator.group1.add().field1.MergeFrom(messages[0])
1595    generator.group1.add().field1.MergeFrom(messages[1])
1596    generator.group1.add().field1.MergeFrom(messages[2])
1597    generator.group2.add().field1.MergeFrom(messages[0])
1598    generator.group2.add().field1.MergeFrom(messages[1])
1599    generator.group2.add().field1.MergeFrom(messages[2])
1600
1601    data = generator.SerializeToString()
1602    parsing_merge = unittest_pb2.TestParsingMerge()
1603    parsing_merge.ParseFromString(data)
1604
1605    # Required and optional fields should be merged.
1606    self.assertEqual(parsing_merge.required_all_types, merged_message)
1607    self.assertEqual(parsing_merge.optional_all_types, merged_message)
1608    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
1609                     merged_message)
1610    self.assertEqual(
1611        parsing_merge.Extensions[unittest_pb2.TestParsingMerge.optional_ext],
1612        merged_message)
1613
1614    # Repeated fields should not be merged.
1615    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
1616    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
1617    self.assertEqual(
1618        len(parsing_merge.Extensions[
1619            unittest_pb2.TestParsingMerge.repeated_ext]), 3)
1620
1621  def testPythonicInit(self):
1622    message = unittest_pb2.TestAllTypes(
1623        optional_int32=100,
1624        optional_fixed32=200,
1625        optional_float=300.5,
1626        optional_bytes=b'x',
1627        optionalgroup={'a': 400},
1628        optional_nested_message={'bb': 500},
1629        optional_foreign_message={},
1630        optional_nested_enum='BAZ',
1631        repeatedgroup=[{
1632            'a': 600
1633        }, {
1634            'a': 700
1635        }],
1636        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
1637        default_int32=800,
1638        oneof_string='y')
1639    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
1640    self.assertEqual(100, message.optional_int32)
1641    self.assertEqual(200, message.optional_fixed32)
1642    self.assertEqual(300.5, message.optional_float)
1643    self.assertEqual(b'x', message.optional_bytes)
1644    self.assertEqual(400, message.optionalgroup.a)
1645    self.assertIsInstance(message.optional_nested_message,
1646                          unittest_pb2.TestAllTypes.NestedMessage)
1647    self.assertEqual(500, message.optional_nested_message.bb)
1648    self.assertTrue(message.HasField('optional_foreign_message'))
1649    self.assertEqual(message.optional_foreign_message,
1650                     unittest_pb2.ForeignMessage())
1651    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1652                     message.optional_nested_enum)
1653    self.assertEqual(2, len(message.repeatedgroup))
1654    self.assertEqual(600, message.repeatedgroup[0].a)
1655    self.assertEqual(700, message.repeatedgroup[1].a)
1656    self.assertEqual(2, len(message.repeated_nested_enum))
1657    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
1658                     message.repeated_nested_enum[0])
1659    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
1660                     message.repeated_nested_enum[1])
1661    self.assertEqual(800, message.default_int32)
1662    self.assertEqual('y', message.oneof_string)
1663    self.assertFalse(message.HasField('optional_int64'))
1664    self.assertEqual(0, len(message.repeated_float))
1665    self.assertEqual(42, message.default_int64)
1666
1667    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
1668    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1669                     message.optional_nested_enum)
1670
1671    with self.assertRaises(ValueError):
1672      unittest_pb2.TestAllTypes(
1673          optional_nested_message={'INVALID_NESTED_FIELD': 17})
1674
1675    with self.assertRaises(TypeError):
1676      unittest_pb2.TestAllTypes(
1677          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
1678
1679    with self.assertRaises(ValueError):
1680      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
1681
1682    with self.assertRaises(ValueError):
1683      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1684
1685  def testPythonicInitWithDict(self):
1686    # Both string/unicode field name keys should work.
1687    kwargs = {
1688        'optional_int32': 100,
1689        u'optional_fixed32': 200,
1690    }
1691    msg = unittest_pb2.TestAllTypes(**kwargs)
1692    self.assertEqual(100, msg.optional_int32)
1693    self.assertEqual(200, msg.optional_fixed32)
1694
1695  def test_documentation(self):
1696    # Also used by the interactive help() function.
1697    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
1698    self.assertIn('class TestAllTypes', doc)
1699    self.assertIn('SerializePartialToString', doc)
1700    self.assertIn('repeated_float', doc)
1701    base = unittest_pb2.TestAllTypes.__bases__[0]
1702    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
1703
1704
1705# Class to test proto3-only features/behavior (updated field presence & enums)
1706@testing_refleaks.TestCase
1707class Proto3Test(unittest.TestCase):
1708
1709  # Utility method for comparing equality with a map.
1710  def assertMapIterEquals(self, map_iter, dict_value):
1711    # Avoid mutating caller's copy.
1712    dict_value = dict(dict_value)
1713
1714    for k, v in map_iter:
1715      self.assertEqual(v, dict_value[k])
1716      del dict_value[k]
1717
1718    self.assertEqual({}, dict_value)
1719
1720  def testFieldPresence(self):
1721    message = unittest_proto3_arena_pb2.TestAllTypes()
1722
1723    # We can't test presence of non-repeated, non-submessage fields.
1724    with self.assertRaises(ValueError):
1725      message.HasField('optional_int32')
1726    with self.assertRaises(ValueError):
1727      message.HasField('optional_float')
1728    with self.assertRaises(ValueError):
1729      message.HasField('optional_string')
1730    with self.assertRaises(ValueError):
1731      message.HasField('optional_bool')
1732
1733    # But we can still test presence of submessage fields.
1734    self.assertFalse(message.HasField('optional_nested_message'))
1735
1736    # As with proto2, we can't test presence of fields that don't exist, or
1737    # repeated fields.
1738    with self.assertRaises(ValueError):
1739      message.HasField('field_doesnt_exist')
1740
1741    with self.assertRaises(ValueError):
1742      message.HasField('repeated_int32')
1743    with self.assertRaises(ValueError):
1744      message.HasField('repeated_nested_message')
1745
1746    # Can not test "in" operator.
1747    with self.assertRaises(ValueError):
1748      'repeated_int32' in message
1749    with self.assertRaises(ValueError):
1750      'repeated_nested_message' in message
1751
1752    # Fields should default to their type-specific default.
1753    self.assertEqual(0, message.optional_int32)
1754    self.assertEqual(0, message.optional_float)
1755    self.assertEqual('', message.optional_string)
1756    self.assertEqual(False, message.optional_bool)
1757    self.assertEqual(0, message.optional_nested_message.bb)
1758
1759    # Setting a submessage should still return proper presence information.
1760    message.optional_nested_message.bb = 0
1761    self.assertTrue(message.HasField('optional_nested_message'))
1762    self.assertIn('optional_nested_message', message)
1763
1764    # Set the fields to non-default values.
1765    message.optional_int32 = 5
1766    message.optional_float = 1.1
1767    message.optional_string = 'abc'
1768    message.optional_bool = True
1769    message.optional_nested_message.bb = 15
1770
1771    # Clearing the fields unsets them and resets their value to default.
1772    message.ClearField('optional_int32')
1773    message.ClearField('optional_float')
1774    message.ClearField('optional_string')
1775    message.ClearField('optional_bool')
1776    message.ClearField('optional_nested_message')
1777
1778    self.assertEqual(0, message.optional_int32)
1779    self.assertEqual(0, message.optional_float)
1780    self.assertEqual('', message.optional_string)
1781    self.assertEqual(False, message.optional_bool)
1782    self.assertEqual(0, message.optional_nested_message.bb)
1783
1784  def testProto3ParserDropDefaultScalar(self):
1785    message_proto2 = unittest_pb2.TestAllTypes()
1786    message_proto2.optional_int32 = 0
1787    message_proto2.optional_string = ''
1788    message_proto2.optional_bytes = b''
1789    self.assertEqual(len(message_proto2.ListFields()), 3)
1790
1791    message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
1792    message_proto3.ParseFromString(message_proto2.SerializeToString())
1793    self.assertEqual(len(message_proto3.ListFields()), 0)
1794
1795  def testProto3Optional(self):
1796    msg = test_proto3_optional_pb2.TestProto3Optional()
1797    self.assertFalse(msg.HasField('optional_int32'))
1798    self.assertFalse(msg.HasField('optional_float'))
1799    self.assertFalse(msg.HasField('optional_string'))
1800    self.assertFalse(msg.HasField('optional_nested_message'))
1801    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1802
1803    # Set fields.
1804    msg.optional_int32 = 1
1805    msg.optional_float = 1.0
1806    msg.optional_string = '123'
1807    msg.optional_nested_message.bb = 1
1808    self.assertTrue(msg.HasField('optional_int32'))
1809    self.assertTrue(msg.HasField('optional_float'))
1810    self.assertTrue(msg.HasField('optional_string'))
1811    self.assertTrue(msg.HasField('optional_nested_message'))
1812    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1813    # Set to default value does not clear the fields
1814    msg.optional_int32 = 0
1815    msg.optional_float = 0.0
1816    msg.optional_string = ''
1817    msg.optional_nested_message.bb = 0
1818    self.assertTrue(msg.HasField('optional_int32'))
1819    self.assertTrue(msg.HasField('optional_float'))
1820    self.assertTrue(msg.HasField('optional_string'))
1821    self.assertTrue(msg.HasField('optional_nested_message'))
1822    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1823
1824    # Test serialize
1825    msg2 = test_proto3_optional_pb2.TestProto3Optional()
1826    msg2.ParseFromString(msg.SerializeToString())
1827    self.assertTrue(msg2.HasField('optional_int32'))
1828    self.assertTrue(msg2.HasField('optional_float'))
1829    self.assertTrue(msg2.HasField('optional_string'))
1830    self.assertTrue(msg2.HasField('optional_nested_message'))
1831    self.assertTrue(msg2.optional_nested_message.HasField('bb'))
1832
1833    self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')
1834
1835    # Clear these fields.
1836    msg.ClearField('optional_int32')
1837    msg.ClearField('optional_float')
1838    msg.ClearField('optional_string')
1839    msg.ClearField('optional_nested_message')
1840    self.assertFalse(msg.HasField('optional_int32'))
1841    self.assertFalse(msg.HasField('optional_float'))
1842    self.assertFalse(msg.HasField('optional_string'))
1843    self.assertFalse(msg.HasField('optional_nested_message'))
1844    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1845
1846    self.assertEqual(msg.WhichOneof('_optional_int32'), None)
1847
1848    # Test has presence:
1849    for field in test_proto3_optional_pb2.TestProto3Optional.DESCRIPTOR.fields:
1850      if field.name.startswith('optional_'):
1851        self.assertTrue(field.has_presence)
1852    for field in unittest_pb2.TestAllTypes.DESCRIPTOR.fields:
1853      if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1854        self.assertFalse(field.has_presence)
1855      else:
1856        self.assertTrue(field.has_presence)
1857    proto3_descriptor = unittest_proto3_arena_pb2.TestAllTypes.DESCRIPTOR
1858    repeated_field = proto3_descriptor.fields_by_name['repeated_int32']
1859    self.assertFalse(repeated_field.has_presence)
1860    singular_field = proto3_descriptor.fields_by_name['optional_int32']
1861    self.assertFalse(singular_field.has_presence)
1862    optional_field = proto3_descriptor.fields_by_name['proto3_optional_int32']
1863    self.assertTrue(optional_field.has_presence)
1864    message_field = proto3_descriptor.fields_by_name['optional_nested_message']
1865    self.assertTrue(message_field.has_presence)
1866    oneof_field = proto3_descriptor.fields_by_name['oneof_uint32']
1867    self.assertTrue(oneof_field.has_presence)
1868
1869  def testAssignUnknownEnum(self):
1870    """Assigning an unknown enum value is allowed and preserves the value."""
1871    m = unittest_proto3_arena_pb2.TestAllTypes()
1872
1873    # Proto3 can assign unknown enums.
1874    m.optional_nested_enum = 1234567
1875    self.assertEqual(1234567, m.optional_nested_enum)
1876    m.repeated_nested_enum.append(22334455)
1877    self.assertEqual(22334455, m.repeated_nested_enum[0])
1878    # Assignment is a different code path than append for the C++ impl.
1879    m.repeated_nested_enum[0] = 7654321
1880    self.assertEqual(7654321, m.repeated_nested_enum[0])
1881    serialized = m.SerializeToString()
1882
1883    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1884    m2.ParseFromString(serialized)
1885    self.assertEqual(1234567, m2.optional_nested_enum)
1886    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1887
1888  # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1889  # of google/protobuf/map_unittest.proto right now, so it's not easy to
1890  # test both with the same test like we do for the other proto2/proto3 tests.
1891  # (google/protobuf/map_proto2_unittest.proto is very different in the set
1892  # of messages and fields it contains).
1893  def testScalarMapDefaults(self):
1894    msg = map_unittest_pb2.TestMap()
1895
1896    # Scalars start out unset.
1897    self.assertFalse(-123 in msg.map_int32_int32)
1898    self.assertFalse(-2**33 in msg.map_int64_int64)
1899    self.assertFalse(123 in msg.map_uint32_uint32)
1900    self.assertFalse(2**33 in msg.map_uint64_uint64)
1901    self.assertFalse(123 in msg.map_int32_double)
1902    self.assertFalse(False in msg.map_bool_bool)
1903    self.assertFalse('abc' in msg.map_string_string)
1904    self.assertFalse(111 in msg.map_int32_bytes)
1905    self.assertFalse(888 in msg.map_int32_enum)
1906
1907    # Accessing an unset key returns the default.
1908    self.assertEqual(0, msg.map_int32_int32[-123])
1909    self.assertEqual(0, msg.map_int64_int64[-2**33])
1910    self.assertEqual(0, msg.map_uint32_uint32[123])
1911    self.assertEqual(0, msg.map_uint64_uint64[2**33])
1912    self.assertEqual(0.0, msg.map_int32_double[123])
1913    self.assertTrue(isinstance(msg.map_int32_double[123], float))
1914    self.assertEqual(False, msg.map_bool_bool[False])
1915    self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
1916    self.assertEqual('', msg.map_string_string['abc'])
1917    self.assertEqual(b'', msg.map_int32_bytes[111])
1918    self.assertEqual(0, msg.map_int32_enum[888])
1919
1920    # It also sets the value in the map
1921    self.assertTrue(-123 in msg.map_int32_int32)
1922    self.assertTrue(-2**33 in msg.map_int64_int64)
1923    self.assertTrue(123 in msg.map_uint32_uint32)
1924    self.assertTrue(2**33 in msg.map_uint64_uint64)
1925    self.assertTrue(123 in msg.map_int32_double)
1926    self.assertTrue(False in msg.map_bool_bool)
1927    self.assertTrue('abc' in msg.map_string_string)
1928    self.assertTrue(111 in msg.map_int32_bytes)
1929    self.assertTrue(888 in msg.map_int32_enum)
1930
1931    self.assertIsInstance(msg.map_string_string['abc'], str)
1932
1933    # Accessing an unset key still throws TypeError if the type of the key
1934    # is incorrect.
1935    with self.assertRaises(TypeError):
1936      msg.map_string_string[123]
1937
1938    with self.assertRaises(TypeError):
1939      123 in msg.map_string_string
1940
1941    with self.assertRaises(TypeError):
1942      msg.map_string_string.__contains__(123)
1943
1944  def testScalarMapComparison(self):
1945    msg1 = map_unittest_pb2.TestMap()
1946    msg2 = map_unittest_pb2.TestMap()
1947
1948    self.assertEqual(msg1.map_int32_int32, msg2.map_int32_int32)
1949
1950  def testMessageMapComparison(self):
1951    msg1 = map_unittest_pb2.TestMap()
1952    msg2 = map_unittest_pb2.TestMap()
1953
1954    self.assertEqual(msg1.map_int32_foreign_message,
1955                     msg2.map_int32_foreign_message)
1956
1957  def testMapGet(self):
1958    # Need to test that get() properly returns the default, even though the dict
1959    # has defaultdict-like semantics.
1960    msg = map_unittest_pb2.TestMap()
1961
1962    self.assertIsNone(msg.map_int32_int32.get(5))
1963    self.assertEqual(10, msg.map_int32_int32.get(5, 10))
1964    self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10))
1965    self.assertIsNone(msg.map_int32_int32.get(5))
1966
1967    msg.map_int32_int32[5] = 15
1968    self.assertEqual(15, msg.map_int32_int32.get(5))
1969    self.assertEqual(15, msg.map_int32_int32.get(5))
1970    with self.assertRaises(TypeError):
1971      msg.map_int32_int32.get('')
1972
1973    self.assertIsNone(msg.map_int32_foreign_message.get(5))
1974    self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
1975    self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10))
1976
1977    submsg = msg.map_int32_foreign_message[5]
1978    self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
1979    with self.assertRaises(TypeError):
1980      msg.map_int32_foreign_message.get('')
1981
1982  def testScalarMap(self):
1983    msg = map_unittest_pb2.TestMap()
1984
1985    self.assertEqual(0, len(msg.map_int32_int32))
1986    self.assertFalse(5 in msg.map_int32_int32)
1987
1988    msg.map_int32_int32[-123] = -456
1989    msg.map_int64_int64[-2**33] = -2**34
1990    msg.map_uint32_uint32[123] = 456
1991    msg.map_uint64_uint64[2**33] = 2**34
1992    msg.map_int32_float[2] = 1.2
1993    msg.map_int32_double[1] = 3.3
1994    msg.map_string_string['abc'] = '123'
1995    msg.map_bool_bool[True] = True
1996    msg.map_int32_enum[888] = 2
1997    # Unknown numeric enum is supported in proto3.
1998    msg.map_int32_enum[123] = 456
1999
2000    self.assertEqual([], msg.FindInitializationErrors())
2001
2002    self.assertEqual(1, len(msg.map_string_string))
2003
2004    # Bad key.
2005    with self.assertRaises(TypeError):
2006      msg.map_string_string[123] = '123'
2007
2008    # Verify that trying to assign a bad key doesn't actually add a member to
2009    # the map.
2010    self.assertEqual(1, len(msg.map_string_string))
2011
2012    # Bad value.
2013    with self.assertRaises(TypeError):
2014      msg.map_string_string['123'] = 123
2015
2016    serialized = msg.SerializeToString()
2017    msg2 = map_unittest_pb2.TestMap()
2018    msg2.ParseFromString(serialized)
2019
2020    # Bad key.
2021    with self.assertRaises(TypeError):
2022      msg2.map_string_string[123] = '123'
2023
2024    # Bad value.
2025    with self.assertRaises(TypeError):
2026      msg2.map_string_string['123'] = 123
2027
2028    self.assertEqual(-456, msg2.map_int32_int32[-123])
2029    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
2030    self.assertEqual(456, msg2.map_uint32_uint32[123])
2031    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
2032    self.assertAlmostEqual(1.2, msg.map_int32_float[2])
2033    self.assertEqual(3.3, msg.map_int32_double[1])
2034    self.assertEqual('123', msg2.map_string_string['abc'])
2035    self.assertEqual(True, msg2.map_bool_bool[True])
2036    self.assertEqual(2, msg2.map_int32_enum[888])
2037    self.assertEqual(456, msg2.map_int32_enum[123])
2038    self.assertEqual('{-123: -456}', str(msg2.map_int32_int32))
2039
2040  def testMapEntryAlwaysSerialized(self):
2041    msg = map_unittest_pb2.TestMap()
2042    msg.map_int32_int32[0] = 0
2043    msg.map_string_string[''] = ''
2044    self.assertEqual(msg.ByteSize(), 12)
2045    self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
2046                     msg.SerializeToString())
2047
2048  def testStringUnicodeConversionInMap(self):
2049    msg = map_unittest_pb2.TestMap()
2050
2051    unicode_obj = u'\u1234'
2052    bytes_obj = unicode_obj.encode('utf8')
2053
2054    msg.map_string_string[bytes_obj] = bytes_obj
2055
2056    (key, value) = list(msg.map_string_string.items())[0]
2057
2058    self.assertEqual(key, unicode_obj)
2059    self.assertEqual(value, unicode_obj)
2060
2061    self.assertIsInstance(key, str)
2062    self.assertIsInstance(value, str)
2063
2064  def testMessageMap(self):
2065    msg = map_unittest_pb2.TestMap()
2066
2067    self.assertEqual(0, len(msg.map_int32_foreign_message))
2068    self.assertFalse(5 in msg.map_int32_foreign_message)
2069
2070    msg.map_int32_foreign_message[123]
2071    # get_or_create() is an alias for getitem.
2072    msg.map_int32_foreign_message.get_or_create(-456)
2073
2074    self.assertEqual(2, len(msg.map_int32_foreign_message))
2075    self.assertIn(123, msg.map_int32_foreign_message)
2076    self.assertIn(-456, msg.map_int32_foreign_message)
2077    self.assertEqual(2, len(msg.map_int32_foreign_message))
2078
2079    # Bad key.
2080    with self.assertRaises(TypeError):
2081      msg.map_int32_foreign_message['123']
2082
2083    with self.assertRaises(TypeError):
2084      '123' in msg.map_int32_foreign_message
2085
2086    with self.assertRaises(TypeError):
2087      msg.map_int32_foreign_message.__contains__('123')
2088
2089    # Can't assign directly to submessage.
2090    with self.assertRaises(ValueError):
2091      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
2092
2093    # Verify that trying to assign a bad key doesn't actually add a member to
2094    # the map.
2095    self.assertEqual(2, len(msg.map_int32_foreign_message))
2096
2097    serialized = msg.SerializeToString()
2098    msg2 = map_unittest_pb2.TestMap()
2099    msg2.ParseFromString(serialized)
2100
2101    self.assertEqual(2, len(msg2.map_int32_foreign_message))
2102    self.assertIn(123, msg2.map_int32_foreign_message)
2103    self.assertIn(-456, msg2.map_int32_foreign_message)
2104    self.assertEqual(2, len(msg2.map_int32_foreign_message))
2105    msg2.map_int32_foreign_message[123].c = 1
2106    # TODO: Fix text format for message map.
2107    self.assertIn(
2108        str(msg2.map_int32_foreign_message),
2109        ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
2110
2111  def testNestedMessageMapItemDelete(self):
2112    msg = map_unittest_pb2.TestMap()
2113    msg.map_int32_all_types[1].optional_nested_message.bb = 1
2114    del msg.map_int32_all_types[1]
2115    msg.map_int32_all_types[2].optional_nested_message.bb = 2
2116    self.assertEqual(1, len(msg.map_int32_all_types))
2117    msg.map_int32_all_types[1].optional_nested_message.bb = 1
2118    self.assertEqual(2, len(msg.map_int32_all_types))
2119
2120    serialized = msg.SerializeToString()
2121    msg2 = map_unittest_pb2.TestMap()
2122    msg2.ParseFromString(serialized)
2123    keys = [1, 2]
2124    # The loop triggers PyErr_Occurred() in c extension.
2125    for key in keys:
2126      del msg2.map_int32_all_types[key]
2127
2128  def testMapByteSize(self):
2129    msg = map_unittest_pb2.TestMap()
2130    msg.map_int32_int32[1] = 1
2131    size = msg.ByteSize()
2132    msg.map_int32_int32[1] = 128
2133    self.assertEqual(msg.ByteSize(), size + 1)
2134
2135    msg.map_int32_foreign_message[19].c = 1
2136    size = msg.ByteSize()
2137    msg.map_int32_foreign_message[19].c = 128
2138    self.assertEqual(msg.ByteSize(), size + 1)
2139
  def testMergeFrom(self):
    """Message-level MergeFrom merges map fields entry by entry.

    For a key present in both messages, the source entry replaces the
    destination entry wholesale. A submessage field set only in the
    destination (`d` below) does not survive. Keys present only in the
    destination are kept.
    """
    # Source message.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[12] = 34
    msg.map_int32_int32[56] = 78
    msg.map_int64_int64[22] = 33
    msg.map_int32_foreign_message[111].c = 5
    msg.map_int32_foreign_message[222].c = 10

    # Destination: shares keys 12 and 222 with msg; key 88 exists only here.
    msg2 = map_unittest_pb2.TestMap()
    msg2.map_int32_int32[12] = 55
    msg2.map_int64_int64[88] = 99
    msg2.map_int32_foreign_message[222].c = 15
    msg2.map_int32_foreign_message[222].d = 20
    # Keep a reference to the entry that the merge is about to replace.
    old_map_value = msg2.map_int32_foreign_message[222]

    msg2.MergeFrom(msg)
    # Compare with expected message instead of call
    # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not
    # sync with repeated field and there is no duplicated keys.
    expected_msg = map_unittest_pb2.TestMap()
    expected_msg.CopyFrom(msg)
    expected_msg.map_int64_int64[88] = 99
    self.assertEqual(msg2, expected_msg)

    # Source values win on shared keys; destination-only key 88 is kept.
    self.assertEqual(34, msg2.map_int32_int32[12])
    self.assertEqual(78, msg2.map_int32_int32[56])
    self.assertEqual(33, msg2.map_int64_int64[22])
    self.assertEqual(99, msg2.map_int64_int64[88])
    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
    self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
    # The replaced submessage did not keep the destination-only field `d`.
    self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
    if api_implementation.Type() != 'cpp':
      # During the call to MergeFrom(), the C++ implementation will have
      # deallocated the underlying message, but this is very difficult to detect
      # properly. The line below is likely to cause a segmentation fault.
      # With the Python implementation, old_map_value is just 'detached' from
      # the main message. Using it will not crash of course, but since it still
      # has a reference to the parent message I'm sure we can find interesting
      # ways to cause inconsistencies.
      self.assertEqual(15, old_map_value.c)

    # Verify that there is only one entry per key, even though the MergeFrom
    # may have internally created multiple entries for a single key in the
    # list representation.
    as_dict = {}
    for key in msg2.map_int32_foreign_message:
      self.assertFalse(key in as_dict)
      as_dict[key] = msg2.map_int32_foreign_message[key].c

    self.assertEqual({111: 5, 222: 10}, as_dict)

    # Special case: test that delete of item really removes the item, even if
    # there might have physically been duplicate keys due to the previous merge.
    # This is only a special case for the C++ implementation which stores the
    # map as an array.
    del msg2.map_int32_int32[12]
    self.assertFalse(12 in msg2.map_int32_int32)

    del msg2.map_int32_foreign_message[222]
    self.assertFalse(222 in msg2.map_int32_foreign_message)
    # Deleting with a key of the wrong type raises TypeError.
    with self.assertRaises(TypeError):
      del msg2.map_int32_foreign_message['']
2202
2203  def testMapMergeFrom(self):
2204    msg = map_unittest_pb2.TestMap()
2205    msg.map_int32_int32[12] = 34
2206    msg.map_int32_int32[56] = 78
2207    msg.map_int64_int64[22] = 33
2208    msg.map_int32_foreign_message[111].c = 5
2209    msg.map_int32_foreign_message[222].c = 10
2210
2211    msg2 = map_unittest_pb2.TestMap()
2212    msg2.map_int32_int32[12] = 55
2213    msg2.map_int64_int64[88] = 99
2214    msg2.map_int32_foreign_message[222].c = 15
2215    msg2.map_int32_foreign_message[222].d = 20
2216
2217    msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
2218    self.assertEqual(34, msg2.map_int32_int32[12])
2219    self.assertEqual(78, msg2.map_int32_int32[56])
2220
2221    msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
2222    self.assertEqual(33, msg2.map_int64_int64[22])
2223    self.assertEqual(99, msg2.map_int64_int64[88])
2224
2225    msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
2226    # Compare with expected message instead of call
2227    # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
2228    # sync with repeated field and no duplicated keys.
2229    expected_msg = map_unittest_pb2.TestMap()
2230    expected_msg.CopyFrom(msg)
2231    expected_msg.map_int64_int64[88] = 99
2232    self.assertEqual(msg2, expected_msg)
2233
2234    # Test when cpp extension cache a map.
2235    m1 = map_unittest_pb2.TestMap()
2236    m2 = map_unittest_pb2.TestMap()
2237    self.assertEqual(m1.map_int32_foreign_message, m1.map_int32_foreign_message)
2238    m2.map_int32_foreign_message[123].c = 10
2239    m1.MergeFrom(m2)
2240    self.assertEqual(10, m2.map_int32_foreign_message[123].c)
2241
2242    # Test merge maps within different message types.
2243    m1 = map_unittest_pb2.TestMap()
2244    m2 = map_unittest_pb2.TestMessageMap()
2245    m2.map_int32_message[123].optional_int32 = 10
2246    m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
2247    self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)
2248
2249    # Test overwrite message value map
2250    msg = map_unittest_pb2.TestMap()
2251    msg.map_int32_foreign_message[222].c = 123
2252    msg2 = map_unittest_pb2.TestMap()
2253    msg2.map_int32_foreign_message[222].d = 20
2254    msg.MergeFromString(msg2.SerializeToString())
2255    self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
2256    self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)
2257
2258    # Merge a dict to map field is not accepted
2259    with self.assertRaises(AttributeError):
2260      m1.map_int32_all_types.MergeFrom(
2261          {1: unittest_proto3_arena_pb2.TestAllTypes()})
2262
2263  def testMergeFromBadType(self):
2264    msg = map_unittest_pb2.TestMap()
2265    with self.assertRaisesRegex(
2266        TypeError,
2267        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
2268        r'.+TestMap got int\.'):
2269      msg.MergeFrom(1)
2270
2271  def testCopyFromBadType(self):
2272    msg = map_unittest_pb2.TestMap()
2273    with self.assertRaisesRegex(
2274        TypeError,
2275        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
2276        r'expected .+TestMap got int\.'):
2277      msg.CopyFrom(1)
2278
2279  def testIntegerMapWithLongs(self):
2280    msg = map_unittest_pb2.TestMap()
2281    msg.map_int32_int32[int(-123)] = int(-456)
2282    msg.map_int64_int64[int(-2**33)] = int(-2**34)
2283    msg.map_uint32_uint32[int(123)] = int(456)
2284    msg.map_uint64_uint64[int(2**33)] = int(2**34)
2285
2286    serialized = msg.SerializeToString()
2287    msg2 = map_unittest_pb2.TestMap()
2288    msg2.ParseFromString(serialized)
2289
2290    self.assertEqual(-456, msg2.map_int32_int32[-123])
2291    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
2292    self.assertEqual(456, msg2.map_uint32_uint32[123])
2293    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
2294
2295  def testMapAssignmentCausesPresence(self):
2296    msg = map_unittest_pb2.TestMapSubmessage()
2297    msg.test_map.map_int32_int32[123] = 456
2298
2299    serialized = msg.SerializeToString()
2300    msg2 = map_unittest_pb2.TestMapSubmessage()
2301    msg2.ParseFromString(serialized)
2302
2303    self.assertEqual(msg, msg2)
2304
2305    # Now test that various mutations of the map properly invalidate the
2306    # cached size of the submessage.
2307    msg.test_map.map_int32_int32[888] = 999
2308    serialized = msg.SerializeToString()
2309    msg2.ParseFromString(serialized)
2310    self.assertEqual(msg, msg2)
2311
2312    msg.test_map.map_int32_int32.clear()
2313    serialized = msg.SerializeToString()
2314    msg2.ParseFromString(serialized)
2315    self.assertEqual(msg, msg2)
2316
2317  def testMapAssignmentCausesPresenceForSubmessages(self):
2318    msg = map_unittest_pb2.TestMapSubmessage()
2319    msg.test_map.map_int32_foreign_message[123].c = 5
2320
2321    serialized = msg.SerializeToString()
2322    msg2 = map_unittest_pb2.TestMapSubmessage()
2323    msg2.ParseFromString(serialized)
2324
2325    self.assertEqual(msg, msg2)
2326
2327    # Now test that various mutations of the map properly invalidate the
2328    # cached size of the submessage.
2329    msg.test_map.map_int32_foreign_message[888].c = 7
2330    serialized = msg.SerializeToString()
2331    msg2.ParseFromString(serialized)
2332    self.assertEqual(msg, msg2)
2333
2334    msg.test_map.map_int32_foreign_message[888].MergeFrom(
2335        msg.test_map.map_int32_foreign_message[123])
2336    serialized = msg.SerializeToString()
2337    msg2.ParseFromString(serialized)
2338    self.assertEqual(msg, msg2)
2339
2340    msg.test_map.map_int32_foreign_message.clear()
2341    serialized = msg.SerializeToString()
2342    msg2.ParseFromString(serialized)
2343    self.assertEqual(msg, msg2)
2344
  def testModifyMapWhileIterating(self):
    # Inserting new keys into a map after an iterator was created over it must
    # make consuming that stale iterator raise RuntimeError.
    msg = map_unittest_pb2.TestMap()

    # Both iterators are created while the maps are still empty, i.e. before
    # the insertions below.
    string_string_iter = iter(msg.map_string_string)
    int32_foreign_iter = iter(msg.map_int32_foreign_message)

    # Each map gains a new key after its iterator was created.
    msg.map_string_string['abc'] = '123'
    msg.map_int32_foreign_message[5].c = 5

    with self.assertRaises(RuntimeError):
      for key in string_string_iter:
        pass

    with self.assertRaises(RuntimeError):
      for key in int32_foreign_iter:
        pass
2361
2362  def testModifyMapEntryWhileIterating(self):
2363    msg = map_unittest_pb2.TestMap()
2364
2365    msg.map_string_string['abc'] = '123'
2366    msg.map_string_string['def'] = '456'
2367    msg.map_string_string['ghi'] = '789'
2368
2369    msg.map_int32_foreign_message[5].c = 5
2370    msg.map_int32_foreign_message[6].c = 6
2371    msg.map_int32_foreign_message[7].c = 7
2372
2373    string_string_keys = list(msg.map_string_string.keys())
2374    int32_foreign_keys = list(msg.map_int32_foreign_message.keys())
2375
2376    keys = []
2377    for key in msg.map_string_string:
2378      keys.append(key)
2379      msg.map_string_string[key] = '000'
2380    self.assertEqual(keys, string_string_keys)
2381    self.assertEqual(keys, list(msg.map_string_string.keys()))
2382
2383    keys = []
2384    for key in msg.map_int32_foreign_message:
2385      keys.append(key)
2386      msg.map_int32_foreign_message[key].c = 0
2387    self.assertEqual(keys, int32_foreign_keys)
2388    self.assertEqual(keys, list(msg.map_int32_foreign_message.keys()))
2389
2390  def testSubmessageMap(self):
2391    msg = map_unittest_pb2.TestMap()
2392
2393    submsg = msg.map_int32_foreign_message[111]
2394    self.assertIs(submsg, msg.map_int32_foreign_message[111])
2395    self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
2396
2397    submsg.c = 5
2398
2399    serialized = msg.SerializeToString()
2400    msg2 = map_unittest_pb2.TestMap()
2401    msg2.ParseFromString(serialized)
2402
2403    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
2404
2405    # Doesn't allow direct submessage assignment.
2406    with self.assertRaises(ValueError):
2407      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
2408
2409  def testMapIteration(self):
2410    msg = map_unittest_pb2.TestMap()
2411
2412    for k, v in msg.map_int32_int32.items():
2413      # Should not be reached.
2414      self.assertTrue(False)
2415
2416    msg.map_int32_int32[2] = 4
2417    msg.map_int32_int32[3] = 6
2418    msg.map_int32_int32[4] = 8
2419    self.assertEqual(3, len(msg.map_int32_int32))
2420
2421    matching_dict = {2: 4, 3: 6, 4: 8}
2422    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
2423
2424  def testMapItems(self):
2425    # Map items used to have strange behaviors when use c extension. Because
2426    # [] may reorder the map and invalidate any existing iterators.
2427    # TODO: Check if [] reordering the map is a bug or intended
2428    # behavior.
2429    msg = map_unittest_pb2.TestMap()
2430    msg.map_string_string['local_init_op'] = ''
2431    msg.map_string_string['trainable_variables'] = ''
2432    msg.map_string_string['variables'] = ''
2433    msg.map_string_string['init_op'] = ''
2434    msg.map_string_string['summaries'] = ''
2435    items1 = msg.map_string_string.items()
2436    items2 = msg.map_string_string.items()
2437    self.assertEqual(items1, items2)
2438
2439  def testMapDeterministicSerialization(self):
2440    golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
2441                   b'r\n\n\x05item1\x12\x01e'
2442                   b'r\n\n\x05item2\x12\x01f'
2443                   b'r\n\n\x05item3\x12\x01g'
2444                   b'r\x0b\n\x05item4\x12\x02QQ'
2445                   b'r\x12\n\rlocal_init_op\x12\x01a'
2446                   b'r\x0e\n\tsummaries\x12\x01e'
2447                   b'r\x18\n\x13trainable_variables\x12\x01b'
2448                   b'r\x0e\n\tvariables\x12\x01c')
2449    msg = map_unittest_pb2.TestMap()
2450    msg.map_string_string['local_init_op'] = 'a'
2451    msg.map_string_string['trainable_variables'] = 'b'
2452    msg.map_string_string['variables'] = 'c'
2453    msg.map_string_string['init_op'] = 'd'
2454    msg.map_string_string['summaries'] = 'e'
2455    msg.map_string_string['item1'] = 'e'
2456    msg.map_string_string['item2'] = 'f'
2457    msg.map_string_string['item3'] = 'g'
2458    msg.map_string_string['item4'] = 'QQ'
2459
2460    # If deterministic serialization is not working correctly, this will be
2461    # "flaky" depending on the exact python dict hash seed.
2462    #
2463    # Fortunately, there are enough items in this map that it is extremely
2464    # unlikely to ever hit the "right" in-order combination, so the test
2465    # itself should fail reliably.
2466    self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
2467
2468  def testMapIterationClearMessage(self):
2469    # Iterator needs to work even if message and map are deleted.
2470    msg = map_unittest_pb2.TestMap()
2471
2472    msg.map_int32_int32[2] = 4
2473    msg.map_int32_int32[3] = 6
2474    msg.map_int32_int32[4] = 8
2475
2476    it = msg.map_int32_int32.items()
2477    del msg
2478
2479    matching_dict = {2: 4, 3: 6, 4: 8}
2480    self.assertMapIterEquals(it, matching_dict)
2481
2482  def testMapConstruction(self):
2483    msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
2484    self.assertEqual(2, msg.map_int32_int32[1])
2485    self.assertEqual(4, msg.map_int32_int32[3])
2486
2487    msg = map_unittest_pb2.TestMap(
2488        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
2489    self.assertEqual(5, msg.map_int32_foreign_message[3].c)
2490
2491  def testMapScalarFieldConstruction(self):
2492    msg1 = map_unittest_pb2.TestMap()
2493    msg1.map_int32_int32[1] = 42
2494    msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32)
2495    self.assertEqual(42, msg2.map_int32_int32[1])
2496
2497  def testMapMessageFieldConstruction(self):
2498    msg1 = map_unittest_pb2.TestMap()
2499    msg1.map_string_foreign_message['test'].c = 42
2500    msg2 = map_unittest_pb2.TestMap(
2501        map_string_foreign_message=msg1.map_string_foreign_message)
2502    self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
2503
2504  def testMapFieldRaisesCorrectError(self):
2505    # Should raise a TypeError when given a non-iterable.
2506    with self.assertRaises(TypeError):
2507      map_unittest_pb2.TestMap(map_string_foreign_message=1)
2508
2509  def testMapValidAfterFieldCleared(self):
2510    # Map needs to work even if field is cleared.
2511    # For the C++ implementation this tests the correctness of
2512    # MapContainer::Release()
2513    msg = map_unittest_pb2.TestMap()
2514    int32_map = msg.map_int32_int32
2515
2516    int32_map[2] = 4
2517    int32_map[3] = 6
2518    int32_map[4] = 8
2519
2520    msg.ClearField('map_int32_int32')
2521    self.assertEqual(b'', msg.SerializeToString())
2522    matching_dict = {2: 4, 3: 6, 4: 8}
2523    self.assertMapIterEquals(int32_map.items(), matching_dict)
2524
2525  def testMessageMapValidAfterFieldCleared(self):
2526    # Map needs to work even if field is cleared.
2527    # For the C++ implementation this tests the correctness of
2528    # MapContainer::Release()
2529    msg = map_unittest_pb2.TestMap()
2530    int32_foreign_message = msg.map_int32_foreign_message
2531
2532    int32_foreign_message[2].c = 5
2533
2534    msg.ClearField('map_int32_foreign_message')
2535    self.assertEqual(b'', msg.SerializeToString())
2536    self.assertTrue(2 in int32_foreign_message.keys())
2537
  def testMessageMapItemValidAfterTopMessageCleared(self):
    # A submessage value obtained from a map must keep its contents after
    # Clear() is called on the top-level message that held it.
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[2].optional_string = 'bar'

    if api_implementation.Type() == 'cpp':
      # Need to keep the map reference because of b/27942626.
      # TODO: Remove it.
      unused_map = msg.map_int32_all_types  # pylint: disable=unused-variable
    # Grab the value before clearing its owner.
    msg_value = msg.map_int32_all_types[2]
    msg.Clear()

    # Reset to trigger sync between repeated field and map in c++.
    msg.map_int32_all_types[3].optional_string = 'foo'
    # The detached value still holds what it had before Clear().
    self.assertEqual(msg_value.optional_string, 'bar')
2555
  def testMapIterInvalidatedByClearField(self):
    # A map iterator is invalidated when its field is cleared; consuming it
    # afterwards must raise RuntimeError and must not crash the interpreter.
    # For the C++ implementation this tests the correctness of
    # ScalarMapContainer::Release()
    msg = map_unittest_pb2.TestMap()

    # Scalar-valued map: create the iterator first, then clear the field.
    it = iter(msg.map_int32_int32)

    msg.ClearField('map_int32_int32')
    with self.assertRaises(RuntimeError):
      for _ in it:
        pass

    # Same sequence for a message-valued map.
    it = iter(msg.map_int32_foreign_message)
    msg.ClearField('map_int32_foreign_message')
    with self.assertRaises(RuntimeError):
      for _ in it:
        pass
2575
2576  def testMapDelete(self):
2577    msg = map_unittest_pb2.TestMap()
2578
2579    self.assertEqual(0, len(msg.map_int32_int32))
2580
2581    msg.map_int32_int32[4] = 6
2582    self.assertEqual(1, len(msg.map_int32_int32))
2583
2584    with self.assertRaises(KeyError):
2585      del msg.map_int32_int32[88]
2586
2587    del msg.map_int32_int32[4]
2588    self.assertEqual(0, len(msg.map_int32_int32))
2589
2590    with self.assertRaises(KeyError):
2591      del msg.map_int32_all_types[32]
2592
2593  def testMapsAreMapping(self):
2594    msg = map_unittest_pb2.TestMap()
2595    self.assertIsInstance(msg.map_int32_int32, collections.abc.Mapping)
2596    self.assertIsInstance(msg.map_int32_int32, collections.abc.MutableMapping)
2597    self.assertIsInstance(msg.map_int32_foreign_message,
2598                          collections.abc.Mapping)
2599    self.assertIsInstance(msg.map_int32_foreign_message,
2600                          collections.abc.MutableMapping)
2601
2602  def testMapsCompare(self):
2603    msg = map_unittest_pb2.TestMap()
2604    msg.map_int32_int32[-123] = -456
2605    self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
2606    self.assertEqual(msg.map_int32_foreign_message,
2607                     msg.map_int32_foreign_message)
2608    self.assertNotEqual(msg.map_int32_int32, 0)
2609
2610  def testMapFindInitializationErrorsSmokeTest(self):
2611    msg = map_unittest_pb2.TestMap()
2612    msg.map_string_string['abc'] = '123'
2613    msg.map_int32_int32[35] = 64
2614    msg.map_string_foreign_message['foo'].c = 5
2615    self.assertEqual(0, len(msg.FindInitializationErrors()))
2616
2617  @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
2618  def testStrictUtf8Check(self):
2619    # Test u'\ud801' is rejected at parser in both python2 and python3.
2620    serialized = (b'r\x03\xed\xa0\x81')
2621    msg = unittest_proto3_arena_pb2.TestAllTypes()
2622    with self.assertRaises(Exception) as context:
2623      msg.MergeFromString(serialized)
2624    if api_implementation.Type() == 'python':
2625      self.assertIn('optional_string', str(context.exception))
2626    else:
2627      self.assertIn('Error parsing message', str(context.exception))
2628
2629    # Test optional_string=u'��' is accepted.
2630    serialized = unittest_proto3_arena_pb2.TestAllTypes(
2631        optional_string=u'��').SerializeToString()
2632    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2633    msg2.MergeFromString(serialized)
2634    self.assertEqual(msg2.optional_string, u'��')
2635
2636    msg = unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud001')
2637    self.assertEqual(msg.optional_string, u'\ud001')
2638
2639  def testSurrogatesInPython3(self):
2640    # Surrogates are rejected at setters in Python3.
2641    with self.assertRaises(ValueError):
2642      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\udc01')
2643    with self.assertRaises(ValueError):
2644      unittest_proto3_arena_pb2.TestAllTypes(optional_string=b'\xed\xa0\x81')
2645    with self.assertRaises(ValueError):
2646      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801')
2647    with self.assertRaises(ValueError):
2648      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801')
2649
2650  def testCrashNullAA(self):
2651    self.assertEqual(
2652        unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(),
2653        unittest_proto3_arena_pb2.TestAllTypes.NestedMessage())
2654
2655  def testCrashNullAB(self):
2656    self.assertEqual(
2657        unittest_proto3_arena_pb2.TestAllTypes.NestedMessage(),
2658        unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message)
2659
2660  def testCrashNullBA(self):
2661    self.assertEqual(
2662        unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message,
2663        unittest_proto3_arena_pb2.TestAllTypes.NestedMessage())
2664
2665  def testCrashNullBB(self):
2666    self.assertEqual(
2667        unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message,
2668        unittest_proto3_arena_pb2.TestAllTypes().optional_nested_message)
2669
2670
@testing_refleaks.TestCase
class ValidTypeNamesTest(unittest.TestCase):
  """Repeated-field container types must be importable by their type name."""

  def assertImportFromName(self, msg, base_name):
    """Asserts type(msg) is a Repeated<base_name> container importable by name.

    Args:
      msg: A repeated-field container instance.
      base_name: The container flavor used to build the expected class names,
        e.g. 'Scalar' or 'Composite'.
    """
    # Parse "<class 'some.module.ClassName'>" to extract
    # 'some.module.ClassName' as a string.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(
        tp_name.endswith(valid_names),
        '%r does not end with any of %r' % (tp_name, valid_names))

    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
2693
2694
2695# We can only test this case under proto2, because proto3 will reject invalid
2696# UTF-8 in the parser, so there should be no way of creating a string field
2697# that contains invalid UTF-8.
2698#
2699# We also can't test it in pure-Python, which validates all string fields for
2700# UTF-8 even when the spec says it shouldn't.
@unittest.skipIf(api_implementation.Type() == 'python',
                 'Python can\'t create invalid UTF-8 strings')
@testing_refleaks.TestCase
class InvalidUtf8Test(unittest.TestCase):
  """Text-format printing of string fields with invalid and valid UTF-8."""

  def testInvalidUtf8Printing(self):
    # Smuggle a 0xff byte into a string field through the wire format; the
    # text format must print it as an octal escape.
    one_bytes = unittest_pb2.OneBytes()
    one_bytes.data = b'ABC\xff123'
    one_string = unittest_pb2.OneString()
    one_string.ParseFromString(one_bytes.SerializeToString())
    self.assertIn('data: "ABC\\377123"', str(one_string))

  def testValidUtf8Printing(self):
    # Valid multi-byte UTF-8 must be printed as-is, not escaped. Characters
    # are spelled as escapes so the test does not depend on the encoding this
    # source file is read with.
    for text in ('\u00a3',  # POUND SIGN: 2 bytes in UTF-8
                 '\u20ac',  # EURO SIGN: 3 bytes in UTF-8
                 '\U0001F642'):  # outside the BMP: 4 bytes in UTF-8
      self.assertIn('data: "%s"' % text,
                    str(unittest_pb2.OneString(data=text)))
2717
2718
@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):
  """Golden wire-format checks for packed and unpacked repeated fields."""

  def setMessage(self, message):
    """Appends exactly one element to each repeated field of `message`."""
    # (field name, element) in the order the fields are populated.
    field_values = (
        ('repeated_int32', 1),
        ('repeated_int64', 1),
        ('repeated_uint32', 1),
        ('repeated_uint64', 1),
        ('repeated_sint32', 1),
        ('repeated_sint64', 1),
        ('repeated_fixed32', 1),
        ('repeated_fixed64', 1),
        ('repeated_sfixed32', 1),
        ('repeated_sfixed64', 1),
        ('repeated_float', 1.0),
        ('repeated_double', 1.0),
        ('repeated_bool', True),
        ('repeated_nested_enum', 1),
    )
    for field_name, element in field_values:
      getattr(message, field_name).append(element)

  def testPackedFields(self):
    packed = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(packed)
    # Each packed field is: tag (wire type 2), payload length, payload.
    # The sint fields zigzag-encode 1 as 2.
    golden_data = (b'\x0A\x01\x01'
                   b'\x12\x01\x01'
                   b'\x1A\x01\x01'
                   b'\x22\x01\x01'
                   b'\x2A\x01\x02'
                   b'\x32\x01\x02'
                   b'\x3A\x04\x01\x00\x00\x00'
                   b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x4A\x04\x01\x00\x00\x00'
                   b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x5A\x04\x00\x00\x80\x3f'
                   b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                   b'\x6A\x01\x01'
                   b'\x72\x01\x01')
    self.assertEqual(golden_data, packed.SerializeToString())

  def testUnpackedFields(self):
    unpacked = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(unpacked)
    # Each unpacked element is: tag (varint, 32-bit or 64-bit wire type),
    # then the value with no length prefix.
    golden_data = (b'\x08\x01'
                   b'\x10\x01'
                   b'\x18\x01'
                   b'\x20\x01'
                   b'\x28\x02'
                   b'\x30\x02'
                   b'\x3D\x01\x00\x00\x00'
                   b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x4D\x01\x00\x00\x00'
                   b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                   b'\x5D\x00\x00\x80\x3f'
                   b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                   b'\x68\x01'
                   b'\x70\x01')
    self.assertEqual(golden_data, unpacked.SerializeToString())
2775
2776
2777
@testing_refleaks.TestCase
class OversizeProtosTest(unittest.TestCase):
  """Tests the nesting-depth limit applied when parsing messages."""

  def GenerateNestedProto(self, n):
    """Returns a serialized TestRecursiveMessage nested `n` levels deep."""
    msg = unittest_pb2.TestRecursiveMessage()
    sub = msg
    for _ in range(n):
      sub = sub.a
    # Setting a scalar on the innermost message makes the whole chain present.
    sub.i = 0
    return msg.SerializeToString()

  def testSucceedOkSizedProto(self):
    msg = unittest_pb2.TestRecursiveMessage()
    msg.ParseFromString(self.GenerateNestedProto(100))

  def testAssertOversizeProto(self):
    if api_implementation.Type() != 'python':
      api_implementation._c_module.SetAllowOversizeProtos(False)
    msg = unittest_pb2.TestRecursiveMessage()
    with self.assertRaises(message.DecodeError) as context:
      msg.ParseFromString(self.GenerateNestedProto(101))
    self.assertIn('Error parsing message', str(context.exception))

  def testSucceedOversizeProto(self):
    if api_implementation.Type() == 'python':
      # `decoder` is not among this module's top-level imports; import it here
      # so this path does not fail with NameError.
      from google.protobuf.internal import decoder  # pylint: disable=g-import-not-at-top
      decoder.SetRecursionLimit(310)
      # Restore the limit even if parsing below fails, so later tests are not
      # affected.
      self.addCleanup(decoder.SetRecursionLimit,
                      decoder.DEFAULT_RECURSION_LIMIT)
    else:
      api_implementation._c_module.SetAllowOversizeProtos(True)
      # NOTE(review): restores the restrictive mode that
      # testAssertOversizeProto relies on; assumes False is the process
      # default -- TODO confirm.
      self.addCleanup(api_implementation._c_module.SetAllowOversizeProtos,
                      False)
    msg = unittest_pb2.TestRecursiveMessage()
    msg.ParseFromString(self.GenerateNestedProto(101))
2809
2810
if __name__ == '__main__':
  # Run every test case defined in this module when executed as a script.
  unittest.main()
2813