• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4# https://developers.google.com/protocol-buffers/
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32"""Tests python protocol buffers against the golden message.
33
34Note that the golden messages exercise every known field type, thus this
35test ends up exercising and verifying nearly all of the parsing and
36serialization code in the whole library.
37
38TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
39sense to call this a test of the "message" module, which only declares an
40abstract interface.
41"""
42
43__author__ = 'gps@google.com (Gregory P. Smith)'
44
45import collections
46import copy
47import math
48import operator
49import pickle
50import pydoc
51import sys
52import unittest
53import warnings
54
def cmp(x, y):
  """Three-way compare: return -1, 0 or 1 as x is <, == or > y.

  Stand-in for the Python 2 builtin of the same name.
  """
  return (x > y) - (x < y)
56
57from google.protobuf import map_proto2_unittest_pb2
58from google.protobuf import map_unittest_pb2
59from google.protobuf import unittest_pb2
60from google.protobuf import unittest_proto3_arena_pb2
61from google.protobuf import descriptor_pb2
62from google.protobuf import descriptor
63from google.protobuf import descriptor_pool
64from google.protobuf import message_factory
65from google.protobuf import text_format
66from google.protobuf.internal import api_implementation
67from google.protobuf.internal import encoder
68from google.protobuf.internal import more_extensions_pb2
69from google.protobuf.internal import packed_field_test_pb2
70from google.protobuf.internal import test_util
71from google.protobuf.internal import test_proto3_optional_pb2
72from google.protobuf.internal import testing_refleaks
73from google.protobuf import message
74from google.protobuf.internal import _parameterized
75
# Largest code point representable in UCS-2 (U+FFFF).
UCS2_MAXUNICODE = 0xFFFF

# Turn DeprecationWarnings into errors so deprecated API use fails loudly.
warnings.simplefilter(action='error', category=DeprecationWarning)
79
80
81@_parameterized.named_parameters(('_proto2', unittest_pb2),
82                                ('_proto3', unittest_proto3_arena_pb2))
83@testing_refleaks.TestCase
84class MessageTest(unittest.TestCase):
85
86  def testBadUtf8String(self, message_module):
87    if api_implementation.Type() != 'python':
88      self.skipTest('Skipping testBadUtf8String, currently only the python '
89                    'api implementation raises UnicodeDecodeError when a '
90                    'string field contains bad utf-8.')
91    bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
92    with self.assertRaises(UnicodeDecodeError) as context:
93      message_module.TestAllTypes.FromString(bad_utf8_data)
94    self.assertIn('TestAllTypes.optional_string', str(context.exception))
95
96  def testGoldenMessage(self, message_module):
97    # Proto3 doesn't have the "default_foo" members or foreign enums,
98    # and doesn't preserve unknown fields, so for proto3 we use a golden
99    # message that doesn't have these fields set.
100    if message_module is unittest_pb2:
101      golden_data = test_util.GoldenFileData('golden_message_oneof_implemented')
102    else:
103      golden_data = test_util.GoldenFileData('golden_message_proto3')
104
105    golden_message = message_module.TestAllTypes()
106    golden_message.ParseFromString(golden_data)
107    if message_module is unittest_pb2:
108      test_util.ExpectAllFieldsSet(self, golden_message)
109    self.assertEqual(golden_data, golden_message.SerializeToString())
110    golden_copy = copy.deepcopy(golden_message)
111    self.assertEqual(golden_data, golden_copy.SerializeToString())
112
113  def testGoldenPackedMessage(self, message_module):
114    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
115    golden_message = message_module.TestPackedTypes()
116    parsed_bytes = golden_message.ParseFromString(golden_data)
117    all_set = message_module.TestPackedTypes()
118    test_util.SetAllPackedFields(all_set)
119    self.assertEqual(parsed_bytes, len(golden_data))
120    self.assertEqual(all_set, golden_message)
121    self.assertEqual(golden_data, all_set.SerializeToString())
122    golden_copy = copy.deepcopy(golden_message)
123    self.assertEqual(golden_data, golden_copy.SerializeToString())
124
125  def testParseErrors(self, message_module):
126    msg = message_module.TestAllTypes()
127    self.assertRaises(TypeError, msg.FromString, 0)
128    self.assertRaises(Exception, msg.FromString, '0')
129    # TODO(jieluo): Fix cpp extension to raise error instead of warning.
130    # b/27494216
131    end_tag = encoder.TagBytes(1, 4)
132    if api_implementation.Type() == 'python':
133      with self.assertRaises(message.DecodeError) as context:
134        msg.FromString(end_tag)
135      self.assertEqual('Unexpected end-group tag.', str(context.exception))
136
137    # Field number 0 is illegal.
138    self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')
139
140  def testDeterminismParameters(self, message_module):
141    # This message is always deterministically serialized, even if determinism
142    # is disabled, so we can use it to verify that all the determinism
143    # parameters work correctly.
144    golden_data = (b'\xe2\x02\nOne string'
145                   b'\xe2\x02\nTwo string'
146                   b'\xe2\x02\nRed string'
147                   b'\xe2\x02\x0bBlue string')
148    golden_message = message_module.TestAllTypes()
149    golden_message.repeated_string.extend([
150        'One string',
151        'Two string',
152        'Red string',
153        'Blue string',
154    ])
155    self.assertEqual(golden_data,
156                     golden_message.SerializeToString(deterministic=None))
157    self.assertEqual(golden_data,
158                     golden_message.SerializeToString(deterministic=False))
159    self.assertEqual(golden_data,
160                     golden_message.SerializeToString(deterministic=True))
161
162    class BadArgError(Exception):
163      pass
164
165    class BadArg(object):
166
167      def __nonzero__(self):
168        raise BadArgError()
169
170      def __bool__(self):
171        raise BadArgError()
172
173    with self.assertRaises(BadArgError):
174      golden_message.SerializeToString(deterministic=BadArg())
175
176  def testPickleSupport(self, message_module):
177    golden_data = test_util.GoldenFileData('golden_message')
178    golden_message = message_module.TestAllTypes()
179    golden_message.ParseFromString(golden_data)
180    pickled_message = pickle.dumps(golden_message)
181
182    unpickled_message = pickle.loads(pickled_message)
183    self.assertEqual(unpickled_message, golden_message)
184
185  def testPickleNestedMessage(self, message_module):
186    golden_message = message_module.TestPickleNestedMessage.NestedMessage(bb=1)
187    pickled_message = pickle.dumps(golden_message)
188    unpickled_message = pickle.loads(pickled_message)
189    self.assertEqual(unpickled_message, golden_message)
190
191  def testPickleNestedNestedMessage(self, message_module):
192    cls = message_module.TestPickleNestedMessage.NestedMessage
193    golden_message = cls.NestedNestedMessage(cc=1)
194    pickled_message = pickle.dumps(golden_message)
195    unpickled_message = pickle.loads(pickled_message)
196    self.assertEqual(unpickled_message, golden_message)
197
198  def testPositiveInfinity(self, message_module):
199    if message_module is unittest_pb2:
200      golden_data = (b'\x5D\x00\x00\x80\x7F'
201                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
202                     b'\xCD\x02\x00\x00\x80\x7F'
203                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
204    else:
205      golden_data = (b'\x5D\x00\x00\x80\x7F'
206                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
207                     b'\xCA\x02\x04\x00\x00\x80\x7F'
208                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
209
210    golden_message = message_module.TestAllTypes()
211    golden_message.ParseFromString(golden_data)
212    self.assertEqual(golden_message.optional_float, math.inf)
213    self.assertEqual(golden_message.optional_double, math.inf)
214    self.assertEqual(golden_message.repeated_float[0], math.inf)
215    self.assertEqual(golden_message.repeated_double[0], math.inf)
216    self.assertEqual(golden_data, golden_message.SerializeToString())
217
218  def testNegativeInfinity(self, message_module):
219    if message_module is unittest_pb2:
220      golden_data = (b'\x5D\x00\x00\x80\xFF'
221                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
222                     b'\xCD\x02\x00\x00\x80\xFF'
223                     b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
224    else:
225      golden_data = (b'\x5D\x00\x00\x80\xFF'
226                     b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
227                     b'\xCA\x02\x04\x00\x00\x80\xFF'
228                     b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
229
230    golden_message = message_module.TestAllTypes()
231    golden_message.ParseFromString(golden_data)
232    self.assertEqual(golden_message.optional_float, -math.inf)
233    self.assertEqual(golden_message.optional_double, -math.inf)
234    self.assertEqual(golden_message.repeated_float[0], -math.inf)
235    self.assertEqual(golden_message.repeated_double[0], -math.inf)
236    self.assertEqual(golden_data, golden_message.SerializeToString())
237
238  def testNotANumber(self, message_module):
239    golden_data = (b'\x5D\x00\x00\xC0\x7F'
240                   b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
241                   b'\xCD\x02\x00\x00\xC0\x7F'
242                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
243    golden_message = message_module.TestAllTypes()
244    golden_message.ParseFromString(golden_data)
245    self.assertTrue(math.isnan(golden_message.optional_float))
246    self.assertTrue(math.isnan(golden_message.optional_double))
247    self.assertTrue(math.isnan(golden_message.repeated_float[0]))
248    self.assertTrue(math.isnan(golden_message.repeated_double[0]))
249
250    # The protocol buffer may serialize to any one of multiple different
251    # representations of a NaN.  Rather than verify a specific representation,
252    # verify the serialized string can be converted into a correctly
253    # behaving protocol buffer.
254    serialized = golden_message.SerializeToString()
255    message = message_module.TestAllTypes()
256    message.ParseFromString(serialized)
257    self.assertTrue(math.isnan(message.optional_float))
258    self.assertTrue(math.isnan(message.optional_double))
259    self.assertTrue(math.isnan(message.repeated_float[0]))
260    self.assertTrue(math.isnan(message.repeated_double[0]))
261
262  def testPositiveInfinityPacked(self, message_module):
263    golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
264                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
265    golden_message = message_module.TestPackedTypes()
266    golden_message.ParseFromString(golden_data)
267    self.assertEqual(golden_message.packed_float[0], math.inf)
268    self.assertEqual(golden_message.packed_double[0], math.inf)
269    self.assertEqual(golden_data, golden_message.SerializeToString())
270
271  def testNegativeInfinityPacked(self, message_module):
272    golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
273                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
274    golden_message = message_module.TestPackedTypes()
275    golden_message.ParseFromString(golden_data)
276    self.assertEqual(golden_message.packed_float[0], -math.inf)
277    self.assertEqual(golden_message.packed_double[0], -math.inf)
278    self.assertEqual(golden_data, golden_message.SerializeToString())
279
280  def testNotANumberPacked(self, message_module):
281    golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
282                   b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
283    golden_message = message_module.TestPackedTypes()
284    golden_message.ParseFromString(golden_data)
285    self.assertTrue(math.isnan(golden_message.packed_float[0]))
286    self.assertTrue(math.isnan(golden_message.packed_double[0]))
287
288    serialized = golden_message.SerializeToString()
289    message = message_module.TestPackedTypes()
290    message.ParseFromString(serialized)
291    self.assertTrue(math.isnan(message.packed_float[0]))
292    self.assertTrue(math.isnan(message.packed_double[0]))
293
294  def testExtremeFloatValues(self, message_module):
295    message = message_module.TestAllTypes()
296
297    # Most positive exponent, no significand bits set.
298    kMostPosExponentNoSigBits = math.pow(2, 127)
299    message.optional_float = kMostPosExponentNoSigBits
300    message.ParseFromString(message.SerializeToString())
301    self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
302
303    # Most positive exponent, one significand bit set.
304    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
305    message.optional_float = kMostPosExponentOneSigBit
306    message.ParseFromString(message.SerializeToString())
307    self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
308
309    # Repeat last two cases with values of same magnitude, but negative.
310    message.optional_float = -kMostPosExponentNoSigBits
311    message.ParseFromString(message.SerializeToString())
312    self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
313
314    message.optional_float = -kMostPosExponentOneSigBit
315    message.ParseFromString(message.SerializeToString())
316    self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
317
318    # Most negative exponent, no significand bits set.
319    kMostNegExponentNoSigBits = math.pow(2, -127)
320    message.optional_float = kMostNegExponentNoSigBits
321    message.ParseFromString(message.SerializeToString())
322    self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
323
324    # Most negative exponent, one significand bit set.
325    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
326    message.optional_float = kMostNegExponentOneSigBit
327    message.ParseFromString(message.SerializeToString())
328    self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
329
330    # Repeat last two cases with values of the same magnitude, but negative.
331    message.optional_float = -kMostNegExponentNoSigBits
332    message.ParseFromString(message.SerializeToString())
333    self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
334
335    message.optional_float = -kMostNegExponentOneSigBit
336    message.ParseFromString(message.SerializeToString())
337    self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
338
339    # Max 4 bytes float value
340    max_float = float.fromhex('0x1.fffffep+127')
341    message.optional_float = max_float
342    self.assertAlmostEqual(message.optional_float, max_float)
343    serialized_data = message.SerializeToString()
344    message.ParseFromString(serialized_data)
345    self.assertAlmostEqual(message.optional_float, max_float)
346
347    # Test set double to float field.
348    message.optional_float = 3.4028235e+39
349    self.assertEqual(message.optional_float, float('inf'))
350    serialized_data = message.SerializeToString()
351    message.ParseFromString(serialized_data)
352    self.assertEqual(message.optional_float, float('inf'))
353
354    message.optional_float = -3.4028235e+39
355    self.assertEqual(message.optional_float, float('-inf'))
356
357    message.optional_float = 1.4028235e-39
358    self.assertAlmostEqual(message.optional_float, 1.4028235e-39)
359
360  def testExtremeDoubleValues(self, message_module):
361    message = message_module.TestAllTypes()
362
363    # Most positive exponent, no significand bits set.
364    kMostPosExponentNoSigBits = math.pow(2, 1023)
365    message.optional_double = kMostPosExponentNoSigBits
366    message.ParseFromString(message.SerializeToString())
367    self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
368
369    # Most positive exponent, one significand bit set.
370    kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
371    message.optional_double = kMostPosExponentOneSigBit
372    message.ParseFromString(message.SerializeToString())
373    self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
374
375    # Repeat last two cases with values of same magnitude, but negative.
376    message.optional_double = -kMostPosExponentNoSigBits
377    message.ParseFromString(message.SerializeToString())
378    self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
379
380    message.optional_double = -kMostPosExponentOneSigBit
381    message.ParseFromString(message.SerializeToString())
382    self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
383
384    # Most negative exponent, no significand bits set.
385    kMostNegExponentNoSigBits = math.pow(2, -1023)
386    message.optional_double = kMostNegExponentNoSigBits
387    message.ParseFromString(message.SerializeToString())
388    self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
389
390    # Most negative exponent, one significand bit set.
391    kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
392    message.optional_double = kMostNegExponentOneSigBit
393    message.ParseFromString(message.SerializeToString())
394    self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
395
396    # Repeat last two cases with values of the same magnitude, but negative.
397    message.optional_double = -kMostNegExponentNoSigBits
398    message.ParseFromString(message.SerializeToString())
399    self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
400
401    message.optional_double = -kMostNegExponentOneSigBit
402    message.ParseFromString(message.SerializeToString())
403    self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
404
405  def testFloatPrinting(self, message_module):
406    message = message_module.TestAllTypes()
407    message.optional_float = 2.0
408    self.assertEqual(str(message), 'optional_float: 2.0\n')
409
410  def testHighPrecisionFloatPrinting(self, message_module):
411    msg = message_module.TestAllTypes()
412    msg.optional_float = 0.12345678912345678
413    old_float = msg.optional_float
414    msg.ParseFromString(msg.SerializeToString())
415    self.assertEqual(old_float, msg.optional_float)
416
417  def testHighPrecisionDoublePrinting(self, message_module):
418    msg = message_module.TestAllTypes()
419    msg.optional_double = 0.12345678912345678
420    self.assertEqual(str(msg), 'optional_double: 0.12345678912345678\n')
421
422  def testUnknownFieldPrinting(self, message_module):
423    populated = message_module.TestAllTypes()
424    test_util.SetAllNonLazyFields(populated)
425    empty = message_module.TestEmptyMessage()
426    empty.ParseFromString(populated.SerializeToString())
427    self.assertEqual(str(empty), '')
428
429  def testAppendRepeatedCompositeField(self, message_module):
430    msg = message_module.TestAllTypes()
431    msg.repeated_nested_message.append(
432        message_module.TestAllTypes.NestedMessage(bb=1))
433    nested = message_module.TestAllTypes.NestedMessage(bb=2)
434    msg.repeated_nested_message.append(nested)
435    try:
436      msg.repeated_nested_message.append(1)
437    except TypeError:
438      pass
439    self.assertEqual(2, len(msg.repeated_nested_message))
440    self.assertEqual([1, 2], [m.bb for m in msg.repeated_nested_message])
441
442  def testInsertRepeatedCompositeField(self, message_module):
443    msg = message_module.TestAllTypes()
444    msg.repeated_nested_message.insert(
445        -1, message_module.TestAllTypes.NestedMessage(bb=1))
446    sub_msg = msg.repeated_nested_message[0]
447    msg.repeated_nested_message.insert(
448        0, message_module.TestAllTypes.NestedMessage(bb=2))
449    msg.repeated_nested_message.insert(
450        99, message_module.TestAllTypes.NestedMessage(bb=3))
451    msg.repeated_nested_message.insert(
452        -2, message_module.TestAllTypes.NestedMessage(bb=-1))
453    msg.repeated_nested_message.insert(
454        -1000, message_module.TestAllTypes.NestedMessage(bb=-1000))
455    try:
456      msg.repeated_nested_message.insert(1, 999)
457    except TypeError:
458      pass
459    self.assertEqual(5, len(msg.repeated_nested_message))
460    self.assertEqual([-1000, 2, -1, 1, 3],
461                     [m.bb for m in msg.repeated_nested_message])
462    self.assertEqual(
463        str(msg), 'repeated_nested_message {\n'
464        '  bb: -1000\n'
465        '}\n'
466        'repeated_nested_message {\n'
467        '  bb: 2\n'
468        '}\n'
469        'repeated_nested_message {\n'
470        '  bb: -1\n'
471        '}\n'
472        'repeated_nested_message {\n'
473        '  bb: 1\n'
474        '}\n'
475        'repeated_nested_message {\n'
476        '  bb: 3\n'
477        '}\n')
478    self.assertEqual(sub_msg.bb, 1)
479
480  def testMergeFromRepeatedField(self, message_module):
481    msg = message_module.TestAllTypes()
482    msg.repeated_int32.append(1)
483    msg.repeated_int32.append(3)
484    msg.repeated_nested_message.add(bb=1)
485    msg.repeated_nested_message.add(bb=2)
486    other_msg = message_module.TestAllTypes()
487    other_msg.repeated_nested_message.add(bb=3)
488    other_msg.repeated_nested_message.add(bb=4)
489    other_msg.repeated_int32.append(5)
490    other_msg.repeated_int32.append(7)
491
492    msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
493    self.assertEqual(4, len(msg.repeated_int32))
494
495    msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
496    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
497
498  def testAddWrongRepeatedNestedField(self, message_module):
499    msg = message_module.TestAllTypes()
500    try:
501      msg.repeated_nested_message.add('wrong')
502    except TypeError:
503      pass
504    try:
505      msg.repeated_nested_message.add(value_field='wrong')
506    except ValueError:
507      pass
508    self.assertEqual(len(msg.repeated_nested_message), 0)
509
510  def testRepeatedContains(self, message_module):
511    msg = message_module.TestAllTypes()
512    msg.repeated_int32.extend([1, 2, 3])
513    self.assertIn(2, msg.repeated_int32)
514    self.assertNotIn(0, msg.repeated_int32)
515
516    msg.repeated_nested_message.add(bb=1)
517    sub_msg1 = msg.repeated_nested_message[0]
518    sub_msg2 = message_module.TestAllTypes.NestedMessage(bb=2)
519    sub_msg3 = message_module.TestAllTypes.NestedMessage(bb=3)
520    msg.repeated_nested_message.append(sub_msg2)
521    msg.repeated_nested_message.insert(0, sub_msg3)
522    self.assertIn(sub_msg1, msg.repeated_nested_message)
523    self.assertIn(sub_msg2, msg.repeated_nested_message)
524    self.assertIn(sub_msg3, msg.repeated_nested_message)
525
526  def testRepeatedScalarIterable(self, message_module):
527    msg = message_module.TestAllTypes()
528    msg.repeated_int32.extend([1, 2, 3])
529    add = 0
530    for item in msg.repeated_int32:
531      add += item
532    self.assertEqual(add, 6)
533
534  def testRepeatedNestedFieldIteration(self, message_module):
535    msg = message_module.TestAllTypes()
536    msg.repeated_nested_message.add(bb=1)
537    msg.repeated_nested_message.add(bb=2)
538    msg.repeated_nested_message.add(bb=3)
539    msg.repeated_nested_message.add(bb=4)
540
541    self.assertEqual([1, 2, 3, 4], [m.bb for m in msg.repeated_nested_message])
542    self.assertEqual([4, 3, 2, 1],
543                     [m.bb for m in reversed(msg.repeated_nested_message)])
544    self.assertEqual([4, 3, 2, 1],
545                     [m.bb for m in msg.repeated_nested_message[::-1]])
546
547  def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
548    """Check some different types with the default comparator."""
549    message = message_module.TestAllTypes()
550
551    # TODO(mattp): would testing more scalar types strengthen test?
552    message.repeated_int32.append(1)
553    message.repeated_int32.append(3)
554    message.repeated_int32.append(2)
555    message.repeated_int32.sort()
556    self.assertEqual(message.repeated_int32[0], 1)
557    self.assertEqual(message.repeated_int32[1], 2)
558    self.assertEqual(message.repeated_int32[2], 3)
559    self.assertEqual(str(message.repeated_int32), str([1, 2, 3]))
560
561    message.repeated_float.append(1.1)
562    message.repeated_float.append(1.3)
563    message.repeated_float.append(1.2)
564    message.repeated_float.sort()
565    self.assertAlmostEqual(message.repeated_float[0], 1.1)
566    self.assertAlmostEqual(message.repeated_float[1], 1.2)
567    self.assertAlmostEqual(message.repeated_float[2], 1.3)
568
569    message.repeated_string.append('a')
570    message.repeated_string.append('c')
571    message.repeated_string.append('b')
572    message.repeated_string.sort()
573    self.assertEqual(message.repeated_string[0], 'a')
574    self.assertEqual(message.repeated_string[1], 'b')
575    self.assertEqual(message.repeated_string[2], 'c')
576    self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c']))
577
578    message.repeated_bytes.append(b'a')
579    message.repeated_bytes.append(b'c')
580    message.repeated_bytes.append(b'b')
581    message.repeated_bytes.sort()
582    self.assertEqual(message.repeated_bytes[0], b'a')
583    self.assertEqual(message.repeated_bytes[1], b'b')
584    self.assertEqual(message.repeated_bytes[2], b'c')
585    self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c']))
586
587  def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
588    """Check some different types with custom comparator."""
589    message = message_module.TestAllTypes()
590
591    message.repeated_int32.append(-3)
592    message.repeated_int32.append(-2)
593    message.repeated_int32.append(-1)
594    message.repeated_int32.sort(key=abs)
595    self.assertEqual(message.repeated_int32[0], -1)
596    self.assertEqual(message.repeated_int32[1], -2)
597    self.assertEqual(message.repeated_int32[2], -3)
598
599    message.repeated_string.append('aaa')
600    message.repeated_string.append('bb')
601    message.repeated_string.append('c')
602    message.repeated_string.sort(key=len)
603    self.assertEqual(message.repeated_string[0], 'c')
604    self.assertEqual(message.repeated_string[1], 'bb')
605    self.assertEqual(message.repeated_string[2], 'aaa')
606
607  def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
608    """Check passing a custom comparator to sort a repeated composite field."""
609    message = message_module.TestAllTypes()
610
611    message.repeated_nested_message.add().bb = 1
612    message.repeated_nested_message.add().bb = 3
613    message.repeated_nested_message.add().bb = 2
614    message.repeated_nested_message.add().bb = 6
615    message.repeated_nested_message.add().bb = 5
616    message.repeated_nested_message.add().bb = 4
617    message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
618    self.assertEqual(message.repeated_nested_message[0].bb, 1)
619    self.assertEqual(message.repeated_nested_message[1].bb, 2)
620    self.assertEqual(message.repeated_nested_message[2].bb, 3)
621    self.assertEqual(message.repeated_nested_message[3].bb, 4)
622    self.assertEqual(message.repeated_nested_message[4].bb, 5)
623    self.assertEqual(message.repeated_nested_message[5].bb, 6)
624    self.assertEqual(
625        str(message.repeated_nested_message),
626        '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
627
628  def testSortingRepeatedCompositeFieldsStable(self, message_module):
629    """Check passing a custom comparator to sort a repeated composite field."""
630    message = message_module.TestAllTypes()
631
632    message.repeated_nested_message.add().bb = 21
633    message.repeated_nested_message.add().bb = 20
634    message.repeated_nested_message.add().bb = 13
635    message.repeated_nested_message.add().bb = 33
636    message.repeated_nested_message.add().bb = 11
637    message.repeated_nested_message.add().bb = 24
638    message.repeated_nested_message.add().bb = 10
639    message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
640    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
641                     [n.bb for n in message.repeated_nested_message])
642
643    # Make sure that for the C++ implementation, the underlying fields
644    # are actually reordered.
645    pb = message.SerializeToString()
646    message.Clear()
647    message.MergeFromString(pb)
648    self.assertEqual([13, 11, 10, 21, 20, 24, 33],
649                     [n.bb for n in message.repeated_nested_message])
650
651  def testRepeatedCompositeFieldSortArguments(self, message_module):
652    """Check sorting a repeated composite field using list.sort() arguments."""
653    message = message_module.TestAllTypes()
654
655    get_bb = operator.attrgetter('bb')
656    message.repeated_nested_message.add().bb = 1
657    message.repeated_nested_message.add().bb = 3
658    message.repeated_nested_message.add().bb = 2
659    message.repeated_nested_message.add().bb = 6
660    message.repeated_nested_message.add().bb = 5
661    message.repeated_nested_message.add().bb = 4
662    message.repeated_nested_message.sort(key=get_bb)
663    self.assertEqual([k.bb for k in message.repeated_nested_message],
664                     [1, 2, 3, 4, 5, 6])
665    message.repeated_nested_message.sort(key=get_bb, reverse=True)
666    self.assertEqual([k.bb for k in message.repeated_nested_message],
667                     [6, 5, 4, 3, 2, 1])
668
669  def testRepeatedScalarFieldSortArguments(self, message_module):
670    """Check sorting a scalar field using list.sort() arguments."""
671    message = message_module.TestAllTypes()
672
673    message.repeated_int32.append(-3)
674    message.repeated_int32.append(-2)
675    message.repeated_int32.append(-1)
676    message.repeated_int32.sort(key=abs)
677    self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
678    message.repeated_int32.sort(key=abs, reverse=True)
679    self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
680
681    message.repeated_string.append('aaa')
682    message.repeated_string.append('bb')
683    message.repeated_string.append('c')
684    message.repeated_string.sort(key=len)
685    self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
686    message.repeated_string.sort(key=len, reverse=True)
687    self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
688
689  def testRepeatedFieldsComparable(self, message_module):
690    m1 = message_module.TestAllTypes()
691    m2 = message_module.TestAllTypes()
692    m1.repeated_int32.append(0)
693    m1.repeated_int32.append(1)
694    m1.repeated_int32.append(2)
695    m2.repeated_int32.append(0)
696    m2.repeated_int32.append(1)
697    m2.repeated_int32.append(2)
698    m1.repeated_nested_message.add().bb = 1
699    m1.repeated_nested_message.add().bb = 2
700    m1.repeated_nested_message.add().bb = 3
701    m2.repeated_nested_message.add().bb = 1
702    m2.repeated_nested_message.add().bb = 2
703    m2.repeated_nested_message.add().bb = 3
704
705  def testRepeatedFieldsAreSequences(self, message_module):
706    m = message_module.TestAllTypes()
707    self.assertIsInstance(m.repeated_int32, collections.abc.MutableSequence)
708    self.assertIsInstance(m.repeated_nested_message,
709                          collections.abc.MutableSequence)
710
711  def testRepeatedFieldsNotHashable(self, message_module):
712    m = message_module.TestAllTypes()
713    with self.assertRaises(TypeError):
714      hash(m.repeated_int32)
715    with self.assertRaises(TypeError):
716      hash(m.repeated_nested_message)
717
718  def testRepeatedFieldInsideNestedMessage(self, message_module):
719    m = message_module.NestedTestAllTypes()
720    m.payload.repeated_int32.extend([])
721    self.assertTrue(m.HasField('payload'))
722
723  def testMergeFrom(self, message_module):
724    m1 = message_module.TestAllTypes()
725    m2 = message_module.TestAllTypes()
726    # Cpp extension will lazily create a sub message which is immutable.
727    nested = m1.optional_nested_message
728    self.assertEqual(0, nested.bb)
729    m2.optional_nested_message.bb = 1
730    # Make sure cmessage pointing to a mutable message after merge instead of
731    # the lazily created message.
732    m1.MergeFrom(m2)
733    self.assertEqual(1, nested.bb)
734
735    # Test more nested sub message.
736    msg1 = message_module.NestedTestAllTypes()
737    msg2 = message_module.NestedTestAllTypes()
738    nested = msg1.child.payload.optional_nested_message
739    self.assertEqual(0, nested.bb)
740    msg2.child.payload.optional_nested_message.bb = 1
741    msg1.MergeFrom(msg2)
742    self.assertEqual(1, nested.bb)
743
744    # Test repeated field.
745    self.assertEqual(msg1.payload.repeated_nested_message,
746                     msg1.payload.repeated_nested_message)
747    nested = msg2.payload.repeated_nested_message.add()
748    nested.bb = 1
749    msg1.MergeFrom(msg2)
750    self.assertEqual(1, len(msg1.payload.repeated_nested_message))
751    self.assertEqual(1, nested.bb)
752
753  def testMergeFromString(self, message_module):
754    m1 = message_module.TestAllTypes()
755    m2 = message_module.TestAllTypes()
756    # Cpp extension will lazily create a sub message which is immutable.
757    self.assertEqual(0, m1.optional_nested_message.bb)
758    m2.optional_nested_message.bb = 1
759    # Make sure cmessage pointing to a mutable message after merge instead of
760    # the lazily created message.
761    m1.MergeFromString(m2.SerializeToString())
762    self.assertEqual(1, m1.optional_nested_message.bb)
763
764  def testMergeFromStringUsingMemoryView(self, message_module):
765    m2 = message_module.TestAllTypes()
766    m2.optional_string = 'scalar string'
767    m2.repeated_string.append('repeated string')
768    m2.optional_bytes = b'scalar bytes'
769    m2.repeated_bytes.append(b'repeated bytes')
770
771    serialized = m2.SerializeToString()
772    memview = memoryview(serialized)
773    m1 = message_module.TestAllTypes.FromString(memview)
774
775    self.assertEqual(m1.optional_bytes, b'scalar bytes')
776    self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
777    self.assertEqual(m1.optional_string, 'scalar string')
778    self.assertEqual(m1.repeated_string, ['repeated string'])
779    # Make sure that the memoryview was correctly converted to bytes, and
780    # that a sub-sliced memoryview is not being used.
781    self.assertIsInstance(m1.optional_bytes, bytes)
782    self.assertIsInstance(m1.repeated_bytes[0], bytes)
783    self.assertIsInstance(m1.optional_string, str)
784    self.assertIsInstance(m1.repeated_string[0], str)
785
786  def testMergeFromEmpty(self, message_module):
787    m1 = message_module.TestAllTypes()
788    # Cpp extension will lazily create a sub message which is immutable.
789    self.assertEqual(0, m1.optional_nested_message.bb)
790    self.assertFalse(m1.HasField('optional_nested_message'))
791    # Make sure the sub message is still immutable after merge from empty.
792    m1.MergeFromString(b'')  # field state should not change
793    self.assertFalse(m1.HasField('optional_nested_message'))
794
795  def ensureNestedMessageExists(self, msg, attribute):
796    """Make sure that a nested message object exists.
797
798    As soon as a nested message attribute is accessed, it will be present in the
799    _fields dict, without being marked as actually being set.
800    """
801    getattr(msg, attribute)
802    self.assertFalse(msg.HasField(attribute))
803
804  def testOneofGetCaseNonexistingField(self, message_module):
805    m = message_module.TestAllTypes()
806    self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
807    self.assertRaises(Exception, m.WhichOneof, 0)
808
809  def testOneofDefaultValues(self, message_module):
810    m = message_module.TestAllTypes()
811    self.assertIs(None, m.WhichOneof('oneof_field'))
812    self.assertFalse(m.HasField('oneof_field'))
813    self.assertFalse(m.HasField('oneof_uint32'))
814
815    # Oneof is set even when setting it to a default value.
816    m.oneof_uint32 = 0
817    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
818    self.assertTrue(m.HasField('oneof_field'))
819    self.assertTrue(m.HasField('oneof_uint32'))
820    self.assertFalse(m.HasField('oneof_string'))
821
822    m.oneof_string = ''
823    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
824    self.assertTrue(m.HasField('oneof_string'))
825    self.assertFalse(m.HasField('oneof_uint32'))
826
827  def testOneofSemantics(self, message_module):
828    m = message_module.TestAllTypes()
829    self.assertIs(None, m.WhichOneof('oneof_field'))
830
831    m.oneof_uint32 = 11
832    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
833    self.assertTrue(m.HasField('oneof_uint32'))
834
835    m.oneof_string = u'foo'
836    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
837    self.assertFalse(m.HasField('oneof_uint32'))
838    self.assertTrue(m.HasField('oneof_string'))
839
840    # Read nested message accessor without accessing submessage.
841    m.oneof_nested_message
842    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
843    self.assertTrue(m.HasField('oneof_string'))
844    self.assertFalse(m.HasField('oneof_nested_message'))
845
846    # Read accessor of nested message without accessing submessage.
847    m.oneof_nested_message.bb
848    self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
849    self.assertTrue(m.HasField('oneof_string'))
850    self.assertFalse(m.HasField('oneof_nested_message'))
851
852    m.oneof_nested_message.bb = 11
853    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
854    self.assertFalse(m.HasField('oneof_string'))
855    self.assertTrue(m.HasField('oneof_nested_message'))
856
857    m.oneof_bytes = b'bb'
858    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
859    self.assertFalse(m.HasField('oneof_nested_message'))
860    self.assertTrue(m.HasField('oneof_bytes'))
861
862  def testOneofCompositeFieldReadAccess(self, message_module):
863    m = message_module.TestAllTypes()
864    m.oneof_uint32 = 11
865
866    self.ensureNestedMessageExists(m, 'oneof_nested_message')
867    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
868    self.assertEqual(11, m.oneof_uint32)
869
870  def testOneofWhichOneof(self, message_module):
871    m = message_module.TestAllTypes()
872    self.assertIs(None, m.WhichOneof('oneof_field'))
873    if message_module is unittest_pb2:
874      self.assertFalse(m.HasField('oneof_field'))
875
876    m.oneof_uint32 = 11
877    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
878    if message_module is unittest_pb2:
879      self.assertTrue(m.HasField('oneof_field'))
880
881    m.oneof_bytes = b'bb'
882    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
883
884    m.ClearField('oneof_bytes')
885    self.assertIs(None, m.WhichOneof('oneof_field'))
886    if message_module is unittest_pb2:
887      self.assertFalse(m.HasField('oneof_field'))
888
889  def testOneofClearField(self, message_module):
890    m = message_module.TestAllTypes()
891    m.oneof_uint32 = 11
892    m.ClearField('oneof_field')
893    if message_module is unittest_pb2:
894      self.assertFalse(m.HasField('oneof_field'))
895    self.assertFalse(m.HasField('oneof_uint32'))
896    self.assertIs(None, m.WhichOneof('oneof_field'))
897
898  def testOneofClearSetField(self, message_module):
899    m = message_module.TestAllTypes()
900    m.oneof_uint32 = 11
901    m.ClearField('oneof_uint32')
902    if message_module is unittest_pb2:
903      self.assertFalse(m.HasField('oneof_field'))
904    self.assertFalse(m.HasField('oneof_uint32'))
905    self.assertIs(None, m.WhichOneof('oneof_field'))
906
907  def testOneofClearUnsetField(self, message_module):
908    m = message_module.TestAllTypes()
909    m.oneof_uint32 = 11
910    self.ensureNestedMessageExists(m, 'oneof_nested_message')
911    m.ClearField('oneof_nested_message')
912    self.assertEqual(11, m.oneof_uint32)
913    if message_module is unittest_pb2:
914      self.assertTrue(m.HasField('oneof_field'))
915    self.assertTrue(m.HasField('oneof_uint32'))
916    self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
917
918  def testOneofDeserialize(self, message_module):
919    m = message_module.TestAllTypes()
920    m.oneof_uint32 = 11
921    m2 = message_module.TestAllTypes()
922    m2.ParseFromString(m.SerializeToString())
923    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
924
925  def testOneofCopyFrom(self, message_module):
926    m = message_module.TestAllTypes()
927    m.oneof_uint32 = 11
928    m2 = message_module.TestAllTypes()
929    m2.CopyFrom(m)
930    self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
931
932  def testOneofNestedMergeFrom(self, message_module):
933    m = message_module.NestedTestAllTypes()
934    m.payload.oneof_uint32 = 11
935    m2 = message_module.NestedTestAllTypes()
936    m2.payload.oneof_bytes = b'bb'
937    m2.child.payload.oneof_bytes = b'bb'
938    m2.MergeFrom(m)
939    self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
940    self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
941
942  def testOneofMessageMergeFrom(self, message_module):
943    m = message_module.NestedTestAllTypes()
944    m.payload.oneof_nested_message.bb = 11
945    m.child.payload.oneof_nested_message.bb = 12
946    m2 = message_module.NestedTestAllTypes()
947    m2.payload.oneof_uint32 = 13
948    m2.MergeFrom(m)
949    self.assertEqual('oneof_nested_message',
950                     m2.payload.WhichOneof('oneof_field'))
951    self.assertEqual('oneof_nested_message',
952                     m2.child.payload.WhichOneof('oneof_field'))
953
954  def testOneofNestedMessageInit(self, message_module):
955    m = message_module.TestAllTypes(
956        oneof_nested_message=message_module.TestAllTypes.NestedMessage())
957    self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
958
959  def testOneofClear(self, message_module):
960    m = message_module.TestAllTypes()
961    m.oneof_uint32 = 11
962    m.Clear()
963    self.assertIsNone(m.WhichOneof('oneof_field'))
964    m.oneof_bytes = b'bb'
965    self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
966
967  def testAssignByteStringToUnicodeField(self, message_module):
968    """Assigning a byte string to a string field should result
969
970    in the value being converted to a Unicode string.
971    """
972    m = message_module.TestAllTypes()
973    m.optional_string = str('')
974    self.assertIsInstance(m.optional_string, str)
975
976  def testLongValuedSlice(self, message_module):
977    """It should be possible to use int-valued indices in slices.
978
979    This didn't used to work in the v2 C++ implementation.
980    """
981    m = message_module.TestAllTypes()
982
983    # Repeated scalar
984    m.repeated_int32.append(1)
985    sl = m.repeated_int32[int(0):int(len(m.repeated_int32))]
986    self.assertEqual(len(m.repeated_int32), len(sl))
987
988    # Repeated composite
989    m.repeated_nested_message.add().bb = 3
990    sl = m.repeated_nested_message[int(0):int(len(m.repeated_nested_message))]
991    self.assertEqual(len(m.repeated_nested_message), len(sl))
992
993  def testExtendShouldNotSwallowExceptions(self, message_module):
994    """This didn't use to work in the v2 C++ implementation."""
995    m = message_module.TestAllTypes()
996    with self.assertRaises(NameError) as _:
997      m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
998    with self.assertRaises(NameError) as _:
999      m.repeated_nested_enum.extend(a for i in range(10))  # pylint: disable=undefined-variable
1000
1001  FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
1002
1003  def testExtendInt32WithNothing(self, message_module):
1004    """Test no-ops extending repeated int32 fields."""
1005    m = message_module.TestAllTypes()
1006    self.assertSequenceEqual([], m.repeated_int32)
1007
1008    # TODO(ptucker): Deprecate this behavior. b/18413862
1009    for falsy_value in MessageTest.FALSY_VALUES:
1010      m.repeated_int32.extend(falsy_value)
1011      self.assertSequenceEqual([], m.repeated_int32)
1012
1013    m.repeated_int32.extend([])
1014    self.assertSequenceEqual([], m.repeated_int32)
1015
1016  def testExtendFloatWithNothing(self, message_module):
1017    """Test no-ops extending repeated float fields."""
1018    m = message_module.TestAllTypes()
1019    self.assertSequenceEqual([], m.repeated_float)
1020
1021    # TODO(ptucker): Deprecate this behavior. b/18413862
1022    for falsy_value in MessageTest.FALSY_VALUES:
1023      m.repeated_float.extend(falsy_value)
1024      self.assertSequenceEqual([], m.repeated_float)
1025
1026    m.repeated_float.extend([])
1027    self.assertSequenceEqual([], m.repeated_float)
1028
1029  def testExtendStringWithNothing(self, message_module):
1030    """Test no-ops extending repeated string fields."""
1031    m = message_module.TestAllTypes()
1032    self.assertSequenceEqual([], m.repeated_string)
1033
1034    # TODO(ptucker): Deprecate this behavior. b/18413862
1035    for falsy_value in MessageTest.FALSY_VALUES:
1036      m.repeated_string.extend(falsy_value)
1037      self.assertSequenceEqual([], m.repeated_string)
1038
1039    m.repeated_string.extend([])
1040    self.assertSequenceEqual([], m.repeated_string)
1041
1042  def testExtendInt32WithPythonList(self, message_module):
1043    """Test extending repeated int32 fields with python lists."""
1044    m = message_module.TestAllTypes()
1045    self.assertSequenceEqual([], m.repeated_int32)
1046    m.repeated_int32.extend([0])
1047    self.assertSequenceEqual([0], m.repeated_int32)
1048    m.repeated_int32.extend([1, 2])
1049    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1050    m.repeated_int32.extend([3, 4])
1051    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1052
1053  def testExtendFloatWithPythonList(self, message_module):
1054    """Test extending repeated float fields with python lists."""
1055    m = message_module.TestAllTypes()
1056    self.assertSequenceEqual([], m.repeated_float)
1057    m.repeated_float.extend([0.0])
1058    self.assertSequenceEqual([0.0], m.repeated_float)
1059    m.repeated_float.extend([1.0, 2.0])
1060    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1061    m.repeated_float.extend([3.0, 4.0])
1062    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1063
1064  def testExtendStringWithPythonList(self, message_module):
1065    """Test extending repeated string fields with python lists."""
1066    m = message_module.TestAllTypes()
1067    self.assertSequenceEqual([], m.repeated_string)
1068    m.repeated_string.extend([''])
1069    self.assertSequenceEqual([''], m.repeated_string)
1070    m.repeated_string.extend(['11', '22'])
1071    self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
1072    m.repeated_string.extend(['33', '44'])
1073    self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
1074
1075  def testExtendStringWithString(self, message_module):
1076    """Test extending repeated string fields with characters from a string."""
1077    m = message_module.TestAllTypes()
1078    self.assertSequenceEqual([], m.repeated_string)
1079    m.repeated_string.extend('abc')
1080    self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
1081
1082  class TestIterable(object):
1083    """This iterable object mimics the behavior of numpy.array.
1084
1085    __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
1086
1087    """
1088
1089    def __init__(self, values=None):
1090      self._list = values or []
1091
1092    def __nonzero__(self):
1093      size = len(self._list)
1094      if size == 0:
1095        return False
1096      if size == 1:
1097        return bool(self._list[0])
1098      raise ValueError('Truth value is ambiguous.')
1099
1100    def __len__(self):
1101      return len(self._list)
1102
1103    def __iter__(self):
1104      return self._list.__iter__()
1105
1106  def testExtendInt32WithIterable(self, message_module):
1107    """Test extending repeated int32 fields with iterable."""
1108    m = message_module.TestAllTypes()
1109    self.assertSequenceEqual([], m.repeated_int32)
1110    m.repeated_int32.extend(MessageTest.TestIterable([]))
1111    self.assertSequenceEqual([], m.repeated_int32)
1112    m.repeated_int32.extend(MessageTest.TestIterable([0]))
1113    self.assertSequenceEqual([0], m.repeated_int32)
1114    m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
1115    self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
1116    m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
1117    self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
1118
1119  def testExtendFloatWithIterable(self, message_module):
1120    """Test extending repeated float fields with iterable."""
1121    m = message_module.TestAllTypes()
1122    self.assertSequenceEqual([], m.repeated_float)
1123    m.repeated_float.extend(MessageTest.TestIterable([]))
1124    self.assertSequenceEqual([], m.repeated_float)
1125    m.repeated_float.extend(MessageTest.TestIterable([0.0]))
1126    self.assertSequenceEqual([0.0], m.repeated_float)
1127    m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
1128    self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
1129    m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
1130    self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
1131
1132  def testExtendStringWithIterable(self, message_module):
1133    """Test extending repeated string fields with iterable."""
1134    m = message_module.TestAllTypes()
1135    self.assertSequenceEqual([], m.repeated_string)
1136    m.repeated_string.extend(MessageTest.TestIterable([]))
1137    self.assertSequenceEqual([], m.repeated_string)
1138    m.repeated_string.extend(MessageTest.TestIterable(['']))
1139    self.assertSequenceEqual([''], m.repeated_string)
1140    m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
1141    self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
1142    m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
1143    self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
1144
1145  class TestIndex(object):
1146    """This index object mimics the behavior of numpy.int64 and other types."""
1147
1148    def __init__(self, value=None):
1149      self.value = value
1150
1151    def __index__(self):
1152      return self.value
1153
1154  def testRepeatedIndexingWithIntIndex(self, message_module):
1155    msg = message_module.TestAllTypes()
1156    msg.repeated_int32.extend([1, 2, 3])
1157    self.assertEqual(1, msg.repeated_int32[MessageTest.TestIndex(0)])
1158
1159  def testRepeatedIndexingWithNegative1IntIndex(self, message_module):
1160    msg = message_module.TestAllTypes()
1161    msg.repeated_int32.extend([1, 2, 3])
1162    self.assertEqual(3, msg.repeated_int32[MessageTest.TestIndex(-1)])
1163
1164  def testRepeatedIndexingWithNegative1Int(self, message_module):
1165    msg = message_module.TestAllTypes()
1166    msg.repeated_int32.extend([1, 2, 3])
1167    self.assertEqual(3, msg.repeated_int32[-1])
1168
1169  def testPickleRepeatedScalarContainer(self, message_module):
1170    # Pickle repeated scalar container is not supported.
1171    m = message_module.TestAllTypes()
1172    with self.assertRaises(pickle.PickleError) as _:
1173      pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
1174
1175  def testSortEmptyRepeatedCompositeContainer(self, message_module):
1176    """Exercise a scenario that has led to segfaults in the past."""
1177    m = message_module.TestAllTypes()
1178    m.repeated_nested_message.sort()
1179
1180  def testHasFieldOnRepeatedField(self, message_module):
1181    """Using HasField on a repeated field should raise an exception."""
1182    m = message_module.TestAllTypes()
1183    with self.assertRaises(ValueError) as _:
1184      m.HasField('repeated_int32')
1185
1186  def testRepeatedScalarFieldPop(self, message_module):
1187    m = message_module.TestAllTypes()
1188    with self.assertRaises(IndexError) as _:
1189      m.repeated_int32.pop()
1190    m.repeated_int32.extend(range(5))
1191    self.assertEqual(4, m.repeated_int32.pop())
1192    self.assertEqual(0, m.repeated_int32.pop(0))
1193    self.assertEqual(2, m.repeated_int32.pop(1))
1194    self.assertEqual([1, 3], m.repeated_int32)
1195
1196  def testRepeatedCompositeFieldPop(self, message_module):
1197    m = message_module.TestAllTypes()
1198    with self.assertRaises(IndexError) as _:
1199      m.repeated_nested_message.pop()
1200    with self.assertRaises(TypeError) as _:
1201      m.repeated_nested_message.pop('0')
1202    for i in range(5):
1203      n = m.repeated_nested_message.add()
1204      n.bb = i
1205    self.assertEqual(4, m.repeated_nested_message.pop().bb)
1206    self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
1207    self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
1208    self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
1209
1210  def testRepeatedCompareWithSelf(self, message_module):
1211    m = message_module.TestAllTypes()
1212    for i in range(5):
1213      m.repeated_int32.insert(i, i)
1214      n = m.repeated_nested_message.add()
1215      n.bb = i
1216    self.assertSequenceEqual(m.repeated_int32, m.repeated_int32)
1217    self.assertEqual(m.repeated_nested_message, m.repeated_nested_message)
1218
1219  def testReleasedNestedMessages(self, message_module):
1220    """A case that lead to a segfault when a message detached from its parent
1221
1222    container has itself a child container.
1223    """
1224    m = message_module.NestedTestAllTypes()
1225    m = m.repeated_child.add()
1226    m = m.child
1227    m = m.repeated_child.add()
1228    self.assertEqual(m.payload.optional_int32, 0)
1229
1230  def testSetRepeatedComposite(self, message_module):
1231    m = message_module.TestAllTypes()
1232    with self.assertRaises(AttributeError):
1233      m.repeated_int32 = []
1234    m.repeated_int32.append(1)
1235    with self.assertRaises(AttributeError):
1236      m.repeated_int32 = []
1237
1238  def testReturningType(self, message_module):
1239    m = message_module.TestAllTypes()
1240    self.assertEqual(float, type(m.optional_float))
1241    self.assertEqual(float, type(m.optional_double))
1242    self.assertEqual(bool, type(m.optional_bool))
1243    m.optional_float = 1
1244    m.optional_double = 1
1245    m.optional_bool = 1
1246    m.repeated_float.append(1)
1247    m.repeated_double.append(1)
1248    m.repeated_bool.append(1)
1249    m.ParseFromString(m.SerializeToString())
1250    self.assertEqual(float, type(m.optional_float))
1251    self.assertEqual(float, type(m.optional_double))
1252    self.assertEqual('1.0', str(m.optional_double))
1253    self.assertEqual(bool, type(m.optional_bool))
1254    self.assertEqual(float, type(m.repeated_float[0]))
1255    self.assertEqual(float, type(m.repeated_double[0]))
1256    self.assertEqual(bool, type(m.repeated_bool[0]))
1257    self.assertEqual(True, m.repeated_bool[0])
1258
1259
1260# Class to test proto2-only features (required, extensions, etc.)
1261@testing_refleaks.TestCase
1262class Proto2Test(unittest.TestCase):
1263
1264  def testFieldPresence(self):
1265    message = unittest_pb2.TestAllTypes()
1266
1267    self.assertFalse(message.HasField('optional_int32'))
1268    self.assertFalse(message.HasField('optional_bool'))
1269    self.assertFalse(message.HasField('optional_nested_message'))
1270
1271    with self.assertRaises(ValueError):
1272      message.HasField('field_doesnt_exist')
1273
1274    with self.assertRaises(ValueError):
1275      message.HasField('repeated_int32')
1276    with self.assertRaises(ValueError):
1277      message.HasField('repeated_nested_message')
1278
1279    self.assertEqual(0, message.optional_int32)
1280    self.assertEqual(False, message.optional_bool)
1281    self.assertEqual(0, message.optional_nested_message.bb)
1282
1283    # Fields are set even when setting the values to default values.
1284    message.optional_int32 = 0
1285    message.optional_bool = False
1286    message.optional_nested_message.bb = 0
1287    self.assertTrue(message.HasField('optional_int32'))
1288    self.assertTrue(message.HasField('optional_bool'))
1289    self.assertTrue(message.HasField('optional_nested_message'))
1290
1291    # Set the fields to non-default values.
1292    message.optional_int32 = 5
1293    message.optional_bool = True
1294    message.optional_nested_message.bb = 15
1295
1296    self.assertTrue(message.HasField(u'optional_int32'))
1297    self.assertTrue(message.HasField('optional_bool'))
1298    self.assertTrue(message.HasField('optional_nested_message'))
1299
1300    # Clearing the fields unsets them and resets their value to default.
1301    message.ClearField('optional_int32')
1302    message.ClearField(u'optional_bool')
1303    message.ClearField('optional_nested_message')
1304
1305    self.assertFalse(message.HasField('optional_int32'))
1306    self.assertFalse(message.HasField('optional_bool'))
1307    self.assertFalse(message.HasField('optional_nested_message'))
1308    self.assertEqual(0, message.optional_int32)
1309    self.assertEqual(False, message.optional_bool)
1310    self.assertEqual(0, message.optional_nested_message.bb)
1311
1312  def testAssignInvalidEnum(self):
1313    """Assigning an invalid enum number is not allowed in proto2."""
1314    m = unittest_pb2.TestAllTypes()
1315
1316    # Proto2 can not assign unknown enum.
1317    with self.assertRaises(ValueError) as _:
1318      m.optional_nested_enum = 1234567
1319    self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
1320    # Assignment is a different code path than append for the C++ impl.
1321    m.repeated_nested_enum.append(2)
1322    m.repeated_nested_enum[0] = 2
1323    with self.assertRaises(ValueError):
1324      m.repeated_nested_enum[0] = 123456
1325
1326    # Unknown enum value can be parsed but is ignored.
1327    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1328    m2.optional_nested_enum = 1234567
1329    m2.repeated_nested_enum.append(7654321)
1330    serialized = m2.SerializeToString()
1331
1332    m3 = unittest_pb2.TestAllTypes()
1333    m3.ParseFromString(serialized)
1334    self.assertFalse(m3.HasField('optional_nested_enum'))
1335    # 1 is the default value for optional_nested_enum.
1336    self.assertEqual(1, m3.optional_nested_enum)
1337    self.assertEqual(0, len(m3.repeated_nested_enum))
1338    m2.Clear()
1339    m2.ParseFromString(m3.SerializeToString())
1340    self.assertEqual(1234567, m2.optional_nested_enum)
1341    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1342
1343  def testUnknownEnumMap(self):
1344    m = map_proto2_unittest_pb2.TestEnumMap()
1345    m.known_map_field[123] = 0
1346    with self.assertRaises(ValueError):
1347      m.unknown_map_field[1] = 123
1348
1349  def testExtensionsErrors(self):
1350    msg = unittest_pb2.TestAllTypes()
1351    self.assertRaises(AttributeError, getattr, msg, 'Extensions')
1352
1353  def testMergeFromExtensions(self):
1354    msg1 = more_extensions_pb2.TopLevelMessage()
1355    msg2 = more_extensions_pb2.TopLevelMessage()
1356    # Cpp extension will lazily create a sub message which is immutable.
1357    self.assertEqual(
1358        0,
1359        msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
1360    self.assertFalse(msg1.HasField('submessage'))
1361    msg2.submessage.Extensions[more_extensions_pb2.optional_int_extension] = 123
1362    # Make sure cmessage and extensions pointing to a mutable message
1363    # after merge instead of the lazily created message.
1364    msg1.MergeFrom(msg2)
1365    self.assertEqual(
1366        123,
1367        msg1.submessage.Extensions[more_extensions_pb2.optional_int_extension])
1368
1369  def testGoldenExtensions(self):
1370    golden_data = test_util.GoldenFileData('golden_message')
1371    golden_message = unittest_pb2.TestAllExtensions()
1372    golden_message.ParseFromString(golden_data)
1373    all_set = unittest_pb2.TestAllExtensions()
1374    test_util.SetAllExtensions(all_set)
1375    self.assertEqual(all_set, golden_message)
1376    self.assertEqual(golden_data, golden_message.SerializeToString())
1377    golden_copy = copy.deepcopy(golden_message)
1378    self.assertEqual(golden_data, golden_copy.SerializeToString())
1379
1380  def testGoldenPackedExtensions(self):
1381    golden_data = test_util.GoldenFileData('golden_packed_fields_message')
1382    golden_message = unittest_pb2.TestPackedExtensions()
1383    golden_message.ParseFromString(golden_data)
1384    all_set = unittest_pb2.TestPackedExtensions()
1385    test_util.SetAllPackedExtensions(all_set)
1386    self.assertEqual(all_set, golden_message)
1387    self.assertEqual(golden_data, all_set.SerializeToString())
1388    golden_copy = copy.deepcopy(golden_message)
1389    self.assertEqual(golden_data, golden_copy.SerializeToString())
1390
1391  def testPickleIncompleteProto(self):
1392    golden_message = unittest_pb2.TestRequired(a=1)
1393    pickled_message = pickle.dumps(golden_message)
1394
1395    unpickled_message = pickle.loads(pickled_message)
1396    self.assertEqual(unpickled_message, golden_message)
1397    self.assertEqual(unpickled_message.a, 1)
1398    # This is still an incomplete proto - so serializing should fail
1399    self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
1400
1401  # TODO(haberman): this isn't really a proto2-specific test except that this
1402  # message has a required field in it.  Should probably be factored out so
1403  # that we can test the other parts with proto3.
1404  def testParsingMerge(self):
1405    """Check the merge behavior when a required or optional field appears
1406
1407    multiple times in the input.
1408    """
1409    messages = [
1410        unittest_pb2.TestAllTypes(),
1411        unittest_pb2.TestAllTypes(),
1412        unittest_pb2.TestAllTypes()
1413    ]
1414    messages[0].optional_int32 = 1
1415    messages[1].optional_int64 = 2
1416    messages[2].optional_int32 = 3
1417    messages[2].optional_string = 'hello'
1418
1419    merged_message = unittest_pb2.TestAllTypes()
1420    merged_message.optional_int32 = 3
1421    merged_message.optional_int64 = 2
1422    merged_message.optional_string = 'hello'
1423
1424    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
1425    generator.field1.extend(messages)
1426    generator.field2.extend(messages)
1427    generator.field3.extend(messages)
1428    generator.ext1.extend(messages)
1429    generator.ext2.extend(messages)
1430    generator.group1.add().field1.MergeFrom(messages[0])
1431    generator.group1.add().field1.MergeFrom(messages[1])
1432    generator.group1.add().field1.MergeFrom(messages[2])
1433    generator.group2.add().field1.MergeFrom(messages[0])
1434    generator.group2.add().field1.MergeFrom(messages[1])
1435    generator.group2.add().field1.MergeFrom(messages[2])
1436
1437    data = generator.SerializeToString()
1438    parsing_merge = unittest_pb2.TestParsingMerge()
1439    parsing_merge.ParseFromString(data)
1440
1441    # Required and optional fields should be merged.
1442    self.assertEqual(parsing_merge.required_all_types, merged_message)
1443    self.assertEqual(parsing_merge.optional_all_types, merged_message)
1444    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
1445                     merged_message)
1446    self.assertEqual(
1447        parsing_merge.Extensions[unittest_pb2.TestParsingMerge.optional_ext],
1448        merged_message)
1449
1450    # Repeated fields should not be merged.
1451    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
1452    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
1453    self.assertEqual(
1454        len(parsing_merge.Extensions[
1455            unittest_pb2.TestParsingMerge.repeated_ext]), 3)
1456
1457  def testPythonicInit(self):
1458    message = unittest_pb2.TestAllTypes(
1459        optional_int32=100,
1460        optional_fixed32=200,
1461        optional_float=300.5,
1462        optional_bytes=b'x',
1463        optionalgroup={'a': 400},
1464        optional_nested_message={'bb': 500},
1465        optional_foreign_message={},
1466        optional_nested_enum='BAZ',
1467        repeatedgroup=[{
1468            'a': 600
1469        }, {
1470            'a': 700
1471        }],
1472        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
1473        default_int32=800,
1474        oneof_string='y')
1475    self.assertIsInstance(message, unittest_pb2.TestAllTypes)
1476    self.assertEqual(100, message.optional_int32)
1477    self.assertEqual(200, message.optional_fixed32)
1478    self.assertEqual(300.5, message.optional_float)
1479    self.assertEqual(b'x', message.optional_bytes)
1480    self.assertEqual(400, message.optionalgroup.a)
1481    self.assertIsInstance(message.optional_nested_message,
1482                          unittest_pb2.TestAllTypes.NestedMessage)
1483    self.assertEqual(500, message.optional_nested_message.bb)
1484    self.assertTrue(message.HasField('optional_foreign_message'))
1485    self.assertEqual(message.optional_foreign_message,
1486                     unittest_pb2.ForeignMessage())
1487    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1488                     message.optional_nested_enum)
1489    self.assertEqual(2, len(message.repeatedgroup))
1490    self.assertEqual(600, message.repeatedgroup[0].a)
1491    self.assertEqual(700, message.repeatedgroup[1].a)
1492    self.assertEqual(2, len(message.repeated_nested_enum))
1493    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
1494                     message.repeated_nested_enum[0])
1495    self.assertEqual(unittest_pb2.TestAllTypes.BAR,
1496                     message.repeated_nested_enum[1])
1497    self.assertEqual(800, message.default_int32)
1498    self.assertEqual('y', message.oneof_string)
1499    self.assertFalse(message.HasField('optional_int64'))
1500    self.assertEqual(0, len(message.repeated_float))
1501    self.assertEqual(42, message.default_int64)
1502
1503    message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
1504    self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1505                     message.optional_nested_enum)
1506
1507    with self.assertRaises(ValueError):
1508      unittest_pb2.TestAllTypes(
1509          optional_nested_message={'INVALID_NESTED_FIELD': 17})
1510
1511    with self.assertRaises(TypeError):
1512      unittest_pb2.TestAllTypes(
1513          optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
1514
1515    with self.assertRaises(ValueError):
1516      unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
1517
1518    with self.assertRaises(ValueError):
1519      unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1520
1521  def testPythonicInitWithDict(self):
1522    # Both string/unicode field name keys should work.
1523    kwargs = {
1524        'optional_int32': 100,
1525        u'optional_fixed32': 200,
1526    }
1527    msg = unittest_pb2.TestAllTypes(**kwargs)
1528    self.assertEqual(100, msg.optional_int32)
1529    self.assertEqual(200, msg.optional_fixed32)
1530
1531
1532  def test_documentation(self):
1533    # Also used by the interactive help() function.
1534    doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
1535    self.assertIn('class TestAllTypes', doc)
1536    self.assertIn('SerializePartialToString', doc)
1537    self.assertIn('repeated_float', doc)
1538    base = unittest_pb2.TestAllTypes.__bases__[0]
1539    self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
1540
1541
1542# Class to test proto3-only features/behavior (updated field presence & enums)
1543@testing_refleaks.TestCase
1544class Proto3Test(unittest.TestCase):
1545
1546  # Utility method for comparing equality with a map.
1547  def assertMapIterEquals(self, map_iter, dict_value):
1548    # Avoid mutating caller's copy.
1549    dict_value = dict(dict_value)
1550
1551    for k, v in map_iter:
1552      self.assertEqual(v, dict_value[k])
1553      del dict_value[k]
1554
1555    self.assertEqual({}, dict_value)
1556
1557  def testFieldPresence(self):
1558    message = unittest_proto3_arena_pb2.TestAllTypes()
1559
1560    # We can't test presence of non-repeated, non-submessage fields.
1561    with self.assertRaises(ValueError):
1562      message.HasField('optional_int32')
1563    with self.assertRaises(ValueError):
1564      message.HasField('optional_float')
1565    with self.assertRaises(ValueError):
1566      message.HasField('optional_string')
1567    with self.assertRaises(ValueError):
1568      message.HasField('optional_bool')
1569
1570    # But we can still test presence of submessage fields.
1571    self.assertFalse(message.HasField('optional_nested_message'))
1572
1573    # As with proto2, we can't test presence of fields that don't exist, or
1574    # repeated fields.
1575    with self.assertRaises(ValueError):
1576      message.HasField('field_doesnt_exist')
1577
1578    with self.assertRaises(ValueError):
1579      message.HasField('repeated_int32')
1580    with self.assertRaises(ValueError):
1581      message.HasField('repeated_nested_message')
1582
1583    # Fields should default to their type-specific default.
1584    self.assertEqual(0, message.optional_int32)
1585    self.assertEqual(0, message.optional_float)
1586    self.assertEqual('', message.optional_string)
1587    self.assertEqual(False, message.optional_bool)
1588    self.assertEqual(0, message.optional_nested_message.bb)
1589
1590    # Setting a submessage should still return proper presence information.
1591    message.optional_nested_message.bb = 0
1592    self.assertTrue(message.HasField('optional_nested_message'))
1593
1594    # Set the fields to non-default values.
1595    message.optional_int32 = 5
1596    message.optional_float = 1.1
1597    message.optional_string = 'abc'
1598    message.optional_bool = True
1599    message.optional_nested_message.bb = 15
1600
1601    # Clearing the fields unsets them and resets their value to default.
1602    message.ClearField('optional_int32')
1603    message.ClearField('optional_float')
1604    message.ClearField('optional_string')
1605    message.ClearField('optional_bool')
1606    message.ClearField('optional_nested_message')
1607
1608    self.assertEqual(0, message.optional_int32)
1609    self.assertEqual(0, message.optional_float)
1610    self.assertEqual('', message.optional_string)
1611    self.assertEqual(False, message.optional_bool)
1612    self.assertEqual(0, message.optional_nested_message.bb)
1613
1614  def testProto3ParserDropDefaultScalar(self):
1615    message_proto2 = unittest_pb2.TestAllTypes()
1616    message_proto2.optional_int32 = 0
1617    message_proto2.optional_string = ''
1618    message_proto2.optional_bytes = b''
1619    self.assertEqual(len(message_proto2.ListFields()), 3)
1620
1621    message_proto3 = unittest_proto3_arena_pb2.TestAllTypes()
1622    message_proto3.ParseFromString(message_proto2.SerializeToString())
1623    self.assertEqual(len(message_proto3.ListFields()), 0)
1624
1625  def testProto3Optional(self):
1626    msg = test_proto3_optional_pb2.TestProto3Optional()
1627    self.assertFalse(msg.HasField('optional_int32'))
1628    self.assertFalse(msg.HasField('optional_float'))
1629    self.assertFalse(msg.HasField('optional_string'))
1630    self.assertFalse(msg.HasField('optional_nested_message'))
1631    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1632
1633    # Set fields.
1634    msg.optional_int32 = 1
1635    msg.optional_float = 1.0
1636    msg.optional_string = '123'
1637    msg.optional_nested_message.bb = 1
1638    self.assertTrue(msg.HasField('optional_int32'))
1639    self.assertTrue(msg.HasField('optional_float'))
1640    self.assertTrue(msg.HasField('optional_string'))
1641    self.assertTrue(msg.HasField('optional_nested_message'))
1642    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1643    # Set to default value does not clear the fields
1644    msg.optional_int32 = 0
1645    msg.optional_float = 0.0
1646    msg.optional_string = ''
1647    msg.optional_nested_message.bb = 0
1648    self.assertTrue(msg.HasField('optional_int32'))
1649    self.assertTrue(msg.HasField('optional_float'))
1650    self.assertTrue(msg.HasField('optional_string'))
1651    self.assertTrue(msg.HasField('optional_nested_message'))
1652    self.assertTrue(msg.optional_nested_message.HasField('bb'))
1653
1654    # Test serialize
1655    msg2 = test_proto3_optional_pb2.TestProto3Optional()
1656    msg2.ParseFromString(msg.SerializeToString())
1657    self.assertTrue(msg2.HasField('optional_int32'))
1658    self.assertTrue(msg2.HasField('optional_float'))
1659    self.assertTrue(msg2.HasField('optional_string'))
1660    self.assertTrue(msg2.HasField('optional_nested_message'))
1661    self.assertTrue(msg2.optional_nested_message.HasField('bb'))
1662
1663    self.assertEqual(msg.WhichOneof('_optional_int32'), 'optional_int32')
1664
1665    # Clear these fields.
1666    msg.ClearField('optional_int32')
1667    msg.ClearField('optional_float')
1668    msg.ClearField('optional_string')
1669    msg.ClearField('optional_nested_message')
1670    self.assertFalse(msg.HasField('optional_int32'))
1671    self.assertFalse(msg.HasField('optional_float'))
1672    self.assertFalse(msg.HasField('optional_string'))
1673    self.assertFalse(msg.HasField('optional_nested_message'))
1674    self.assertFalse(msg.optional_nested_message.HasField('bb'))
1675
1676    self.assertEqual(msg.WhichOneof('_optional_int32'), None)
1677
1678    # Test has presence:
1679    for field in test_proto3_optional_pb2.TestProto3Optional.DESCRIPTOR.fields:
1680      self.assertTrue(field.has_presence)
1681    for field in unittest_pb2.TestAllTypes.DESCRIPTOR.fields:
1682      if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1683        self.assertFalse(field.has_presence)
1684      else:
1685        self.assertTrue(field.has_presence)
1686    proto3_descriptor = unittest_proto3_arena_pb2.TestAllTypes.DESCRIPTOR
1687    repeated_field = proto3_descriptor.fields_by_name['repeated_int32']
1688    self.assertFalse(repeated_field.has_presence)
1689    singular_field = proto3_descriptor.fields_by_name['optional_int32']
1690    self.assertFalse(singular_field.has_presence)
1691    optional_field = proto3_descriptor.fields_by_name['proto3_optional_int32']
1692    self.assertTrue(optional_field.has_presence)
1693    message_field = proto3_descriptor.fields_by_name['optional_nested_message']
1694    self.assertTrue(message_field.has_presence)
1695    oneof_field = proto3_descriptor.fields_by_name['oneof_uint32']
1696    self.assertTrue(oneof_field.has_presence)
1697
1698  def testAssignUnknownEnum(self):
1699    """Assigning an unknown enum value is allowed and preserves the value."""
1700    m = unittest_proto3_arena_pb2.TestAllTypes()
1701
1702    # Proto3 can assign unknown enums.
1703    m.optional_nested_enum = 1234567
1704    self.assertEqual(1234567, m.optional_nested_enum)
1705    m.repeated_nested_enum.append(22334455)
1706    self.assertEqual(22334455, m.repeated_nested_enum[0])
1707    # Assignment is a different code path than append for the C++ impl.
1708    m.repeated_nested_enum[0] = 7654321
1709    self.assertEqual(7654321, m.repeated_nested_enum[0])
1710    serialized = m.SerializeToString()
1711
1712    m2 = unittest_proto3_arena_pb2.TestAllTypes()
1713    m2.ParseFromString(serialized)
1714    self.assertEqual(1234567, m2.optional_nested_enum)
1715    self.assertEqual(7654321, m2.repeated_nested_enum[0])
1716
1717  # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1718  # of google/protobuf/map_unittest.proto right now, so it's not easy to
1719  # test both with the same test like we do for the other proto2/proto3 tests.
1720  # (google/protobuf/map_proto2_unittest.proto is very different in the set
1721  # of messages and fields it contains).
1722  def testScalarMapDefaults(self):
1723    msg = map_unittest_pb2.TestMap()
1724
1725    # Scalars start out unset.
1726    self.assertFalse(-123 in msg.map_int32_int32)
1727    self.assertFalse(-2**33 in msg.map_int64_int64)
1728    self.assertFalse(123 in msg.map_uint32_uint32)
1729    self.assertFalse(2**33 in msg.map_uint64_uint64)
1730    self.assertFalse(123 in msg.map_int32_double)
1731    self.assertFalse(False in msg.map_bool_bool)
1732    self.assertFalse('abc' in msg.map_string_string)
1733    self.assertFalse(111 in msg.map_int32_bytes)
1734    self.assertFalse(888 in msg.map_int32_enum)
1735
1736    # Accessing an unset key returns the default.
1737    self.assertEqual(0, msg.map_int32_int32[-123])
1738    self.assertEqual(0, msg.map_int64_int64[-2**33])
1739    self.assertEqual(0, msg.map_uint32_uint32[123])
1740    self.assertEqual(0, msg.map_uint64_uint64[2**33])
1741    self.assertEqual(0.0, msg.map_int32_double[123])
1742    self.assertTrue(isinstance(msg.map_int32_double[123], float))
1743    self.assertEqual(False, msg.map_bool_bool[False])
1744    self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
1745    self.assertEqual('', msg.map_string_string['abc'])
1746    self.assertEqual(b'', msg.map_int32_bytes[111])
1747    self.assertEqual(0, msg.map_int32_enum[888])
1748
1749    # It also sets the value in the map
1750    self.assertTrue(-123 in msg.map_int32_int32)
1751    self.assertTrue(-2**33 in msg.map_int64_int64)
1752    self.assertTrue(123 in msg.map_uint32_uint32)
1753    self.assertTrue(2**33 in msg.map_uint64_uint64)
1754    self.assertTrue(123 in msg.map_int32_double)
1755    self.assertTrue(False in msg.map_bool_bool)
1756    self.assertTrue('abc' in msg.map_string_string)
1757    self.assertTrue(111 in msg.map_int32_bytes)
1758    self.assertTrue(888 in msg.map_int32_enum)
1759
1760    self.assertIsInstance(msg.map_string_string['abc'], str)
1761
1762    # Accessing an unset key still throws TypeError if the type of the key
1763    # is incorrect.
1764    with self.assertRaises(TypeError):
1765      msg.map_string_string[123]
1766
1767    with self.assertRaises(TypeError):
1768      123 in msg.map_string_string
1769
1770  def testMapGet(self):
1771    # Need to test that get() properly returns the default, even though the dict
1772    # has defaultdict-like semantics.
1773    msg = map_unittest_pb2.TestMap()
1774
1775    self.assertIsNone(msg.map_int32_int32.get(5))
1776    self.assertEqual(10, msg.map_int32_int32.get(5, 10))
1777    self.assertEqual(10, msg.map_int32_int32.get(key=5, default=10))
1778    self.assertIsNone(msg.map_int32_int32.get(5))
1779
1780    msg.map_int32_int32[5] = 15
1781    self.assertEqual(15, msg.map_int32_int32.get(5))
1782    self.assertEqual(15, msg.map_int32_int32.get(5))
1783    with self.assertRaises(TypeError):
1784      msg.map_int32_int32.get('')
1785
1786    self.assertIsNone(msg.map_int32_foreign_message.get(5))
1787    self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
1788    self.assertEqual(10, msg.map_int32_foreign_message.get(key=5, default=10))
1789
1790    submsg = msg.map_int32_foreign_message[5]
1791    self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
1792    with self.assertRaises(TypeError):
1793      msg.map_int32_foreign_message.get('')
1794
1795  def testScalarMap(self):
1796    msg = map_unittest_pb2.TestMap()
1797
1798    self.assertEqual(0, len(msg.map_int32_int32))
1799    self.assertFalse(5 in msg.map_int32_int32)
1800
1801    msg.map_int32_int32[-123] = -456
1802    msg.map_int64_int64[-2**33] = -2**34
1803    msg.map_uint32_uint32[123] = 456
1804    msg.map_uint64_uint64[2**33] = 2**34
1805    msg.map_int32_float[2] = 1.2
1806    msg.map_int32_double[1] = 3.3
1807    msg.map_string_string['abc'] = '123'
1808    msg.map_bool_bool[True] = True
1809    msg.map_int32_enum[888] = 2
1810    # Unknown numeric enum is supported in proto3.
1811    msg.map_int32_enum[123] = 456
1812
1813    self.assertEqual([], msg.FindInitializationErrors())
1814
1815    self.assertEqual(1, len(msg.map_string_string))
1816
1817    # Bad key.
1818    with self.assertRaises(TypeError):
1819      msg.map_string_string[123] = '123'
1820
1821    # Verify that trying to assign a bad key doesn't actually add a member to
1822    # the map.
1823    self.assertEqual(1, len(msg.map_string_string))
1824
1825    # Bad value.
1826    with self.assertRaises(TypeError):
1827      msg.map_string_string['123'] = 123
1828
1829    serialized = msg.SerializeToString()
1830    msg2 = map_unittest_pb2.TestMap()
1831    msg2.ParseFromString(serialized)
1832
1833    # Bad key.
1834    with self.assertRaises(TypeError):
1835      msg2.map_string_string[123] = '123'
1836
1837    # Bad value.
1838    with self.assertRaises(TypeError):
1839      msg2.map_string_string['123'] = 123
1840
1841    self.assertEqual(-456, msg2.map_int32_int32[-123])
1842    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
1843    self.assertEqual(456, msg2.map_uint32_uint32[123])
1844    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
1845    self.assertAlmostEqual(1.2, msg.map_int32_float[2])
1846    self.assertEqual(3.3, msg.map_int32_double[1])
1847    self.assertEqual('123', msg2.map_string_string['abc'])
1848    self.assertEqual(True, msg2.map_bool_bool[True])
1849    self.assertEqual(2, msg2.map_int32_enum[888])
1850    self.assertEqual(456, msg2.map_int32_enum[123])
1851    self.assertEqual('{-123: -456}', str(msg2.map_int32_int32))
1852
1853  def testMapEntryAlwaysSerialized(self):
1854    msg = map_unittest_pb2.TestMap()
1855    msg.map_int32_int32[0] = 0
1856    msg.map_string_string[''] = ''
1857    self.assertEqual(msg.ByteSize(), 12)
1858    self.assertEqual(b'\n\x04\x08\x00\x10\x00r\x04\n\x00\x12\x00',
1859                     msg.SerializeToString())
1860
1861  def testStringUnicodeConversionInMap(self):
1862    msg = map_unittest_pb2.TestMap()
1863
1864    unicode_obj = u'\u1234'
1865    bytes_obj = unicode_obj.encode('utf8')
1866
1867    msg.map_string_string[bytes_obj] = bytes_obj
1868
1869    (key, value) = list(msg.map_string_string.items())[0]
1870
1871    self.assertEqual(key, unicode_obj)
1872    self.assertEqual(value, unicode_obj)
1873
1874    self.assertIsInstance(key, str)
1875    self.assertIsInstance(value, str)
1876
1877  def testMessageMap(self):
1878    msg = map_unittest_pb2.TestMap()
1879
1880    self.assertEqual(0, len(msg.map_int32_foreign_message))
1881    self.assertFalse(5 in msg.map_int32_foreign_message)
1882
1883    msg.map_int32_foreign_message[123]
1884    # get_or_create() is an alias for getitem.
1885    msg.map_int32_foreign_message.get_or_create(-456)
1886
1887    self.assertEqual(2, len(msg.map_int32_foreign_message))
1888    self.assertIn(123, msg.map_int32_foreign_message)
1889    self.assertIn(-456, msg.map_int32_foreign_message)
1890    self.assertEqual(2, len(msg.map_int32_foreign_message))
1891
1892    # Bad key.
1893    with self.assertRaises(TypeError):
1894      msg.map_int32_foreign_message['123']
1895
1896    # Can't assign directly to submessage.
1897    with self.assertRaises(ValueError):
1898      msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
1899
1900    # Verify that trying to assign a bad key doesn't actually add a member to
1901    # the map.
1902    self.assertEqual(2, len(msg.map_int32_foreign_message))
1903
1904    serialized = msg.SerializeToString()
1905    msg2 = map_unittest_pb2.TestMap()
1906    msg2.ParseFromString(serialized)
1907
1908    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1909    self.assertIn(123, msg2.map_int32_foreign_message)
1910    self.assertIn(-456, msg2.map_int32_foreign_message)
1911    self.assertEqual(2, len(msg2.map_int32_foreign_message))
1912    msg2.map_int32_foreign_message[123].c = 1
1913    # TODO(jieluo): Fix text format for message map.
1914    self.assertIn(
1915        str(msg2.map_int32_foreign_message),
1916        ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
1917
1918  def testNestedMessageMapItemDelete(self):
1919    msg = map_unittest_pb2.TestMap()
1920    msg.map_int32_all_types[1].optional_nested_message.bb = 1
1921    del msg.map_int32_all_types[1]
1922    msg.map_int32_all_types[2].optional_nested_message.bb = 2
1923    self.assertEqual(1, len(msg.map_int32_all_types))
1924    msg.map_int32_all_types[1].optional_nested_message.bb = 1
1925    self.assertEqual(2, len(msg.map_int32_all_types))
1926
1927    serialized = msg.SerializeToString()
1928    msg2 = map_unittest_pb2.TestMap()
1929    msg2.ParseFromString(serialized)
1930    keys = [1, 2]
1931    # The loop triggers PyErr_Occurred() in c extension.
1932    for key in keys:
1933      del msg2.map_int32_all_types[key]
1934
1935  def testMapByteSize(self):
1936    msg = map_unittest_pb2.TestMap()
1937    msg.map_int32_int32[1] = 1
1938    size = msg.ByteSize()
1939    msg.map_int32_int32[1] = 128
1940    self.assertEqual(msg.ByteSize(), size + 1)
1941
1942    msg.map_int32_foreign_message[19].c = 1
1943    size = msg.ByteSize()
1944    msg.map_int32_foreign_message[19].c = 128
1945    self.assertEqual(msg.ByteSize(), size + 1)
1946
1947  def testMergeFrom(self):
1948    msg = map_unittest_pb2.TestMap()
1949    msg.map_int32_int32[12] = 34
1950    msg.map_int32_int32[56] = 78
1951    msg.map_int64_int64[22] = 33
1952    msg.map_int32_foreign_message[111].c = 5
1953    msg.map_int32_foreign_message[222].c = 10
1954
1955    msg2 = map_unittest_pb2.TestMap()
1956    msg2.map_int32_int32[12] = 55
1957    msg2.map_int64_int64[88] = 99
1958    msg2.map_int32_foreign_message[222].c = 15
1959    msg2.map_int32_foreign_message[222].d = 20
1960    old_map_value = msg2.map_int32_foreign_message[222]
1961
1962    msg2.MergeFrom(msg)
1963    # Compare with expected message instead of call
1964    # msg2.map_int32_foreign_message[222] to make sure MergeFrom does not
1965    # sync with repeated field and there is no duplicated keys.
1966    expected_msg = map_unittest_pb2.TestMap()
1967    expected_msg.CopyFrom(msg)
1968    expected_msg.map_int64_int64[88] = 99
1969    self.assertEqual(msg2, expected_msg)
1970
1971    self.assertEqual(34, msg2.map_int32_int32[12])
1972    self.assertEqual(78, msg2.map_int32_int32[56])
1973    self.assertEqual(33, msg2.map_int64_int64[22])
1974    self.assertEqual(99, msg2.map_int64_int64[88])
1975    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
1976    self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
1977    self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
1978    if api_implementation.Type() != 'cpp':
1979      # During the call to MergeFrom(), the C++ implementation will have
1980      # deallocated the underlying message, but this is very difficult to detect
1981      # properly. The line below is likely to cause a segmentation fault.
1982      # With the Python implementation, old_map_value is just 'detached' from
1983      # the main message. Using it will not crash of course, but since it still
1984      # have a reference to the parent message I'm sure we can find interesting
1985      # ways to cause inconsistencies.
1986      self.assertEqual(15, old_map_value.c)
1987
1988    # Verify that there is only one entry per key, even though the MergeFrom
1989    # may have internally created multiple entries for a single key in the
1990    # list representation.
1991    as_dict = {}
1992    for key in msg2.map_int32_foreign_message:
1993      self.assertFalse(key in as_dict)
1994      as_dict[key] = msg2.map_int32_foreign_message[key].c
1995
1996    self.assertEqual({111: 5, 222: 10}, as_dict)
1997
1998    # Special case: test that delete of item really removes the item, even if
1999    # there might have physically been duplicate keys due to the previous merge.
2000    # This is only a special case for the C++ implementation which stores the
2001    # map as an array.
2002    del msg2.map_int32_int32[12]
2003    self.assertFalse(12 in msg2.map_int32_int32)
2004
2005    del msg2.map_int32_foreign_message[222]
2006    self.assertFalse(222 in msg2.map_int32_foreign_message)
2007    with self.assertRaises(TypeError):
2008      del msg2.map_int32_foreign_message['']
2009
2010  def testMapMergeFrom(self):
2011    msg = map_unittest_pb2.TestMap()
2012    msg.map_int32_int32[12] = 34
2013    msg.map_int32_int32[56] = 78
2014    msg.map_int64_int64[22] = 33
2015    msg.map_int32_foreign_message[111].c = 5
2016    msg.map_int32_foreign_message[222].c = 10
2017
2018    msg2 = map_unittest_pb2.TestMap()
2019    msg2.map_int32_int32[12] = 55
2020    msg2.map_int64_int64[88] = 99
2021    msg2.map_int32_foreign_message[222].c = 15
2022    msg2.map_int32_foreign_message[222].d = 20
2023
2024    msg2.map_int32_int32.MergeFrom(msg.map_int32_int32)
2025    self.assertEqual(34, msg2.map_int32_int32[12])
2026    self.assertEqual(78, msg2.map_int32_int32[56])
2027
2028    msg2.map_int64_int64.MergeFrom(msg.map_int64_int64)
2029    self.assertEqual(33, msg2.map_int64_int64[22])
2030    self.assertEqual(99, msg2.map_int64_int64[88])
2031
2032    msg2.map_int32_foreign_message.MergeFrom(msg.map_int32_foreign_message)
2033    # Compare with expected message instead of call
2034    # msg.map_int32_foreign_message[222] to make sure MergeFrom does not
2035    # sync with repeated field and no duplicated keys.
2036    expected_msg = map_unittest_pb2.TestMap()
2037    expected_msg.CopyFrom(msg)
2038    expected_msg.map_int64_int64[88] = 99
2039    self.assertEqual(msg2, expected_msg)
2040
2041    # Test when cpp extension cache a map.
2042    m1 = map_unittest_pb2.TestMap()
2043    m2 = map_unittest_pb2.TestMap()
2044    self.assertEqual(m1.map_int32_foreign_message, m1.map_int32_foreign_message)
2045    m2.map_int32_foreign_message[123].c = 10
2046    m1.MergeFrom(m2)
2047    self.assertEqual(10, m2.map_int32_foreign_message[123].c)
2048
2049    # Test merge maps within different message types.
2050    m1 = map_unittest_pb2.TestMap()
2051    m2 = map_unittest_pb2.TestMessageMap()
2052    m2.map_int32_message[123].optional_int32 = 10
2053    m1.map_int32_all_types.MergeFrom(m2.map_int32_message)
2054    self.assertEqual(10, m1.map_int32_all_types[123].optional_int32)
2055
2056    # Test overwrite message value map
2057    msg = map_unittest_pb2.TestMap()
2058    msg.map_int32_foreign_message[222].c = 123
2059    msg2 = map_unittest_pb2.TestMap()
2060    msg2.map_int32_foreign_message[222].d = 20
2061    msg.MergeFromString(msg2.SerializeToString())
2062    self.assertEqual(msg.map_int32_foreign_message[222].d, 20)
2063    self.assertNotEqual(msg.map_int32_foreign_message[222].c, 123)
2064
2065    # Merge a dict to map field is not accepted
2066    with self.assertRaises(AttributeError):
2067      m1.map_int32_all_types.MergeFrom(
2068          {1: unittest_proto3_arena_pb2.TestAllTypes()})
2069
2070  def testMergeFromBadType(self):
2071    msg = map_unittest_pb2.TestMap()
2072    with self.assertRaisesRegex(
2073        TypeError,
2074        r'Parameter to MergeFrom\(\) must be instance of same class: expected '
2075        r'.+TestMap got int\.'):
2076      msg.MergeFrom(1)
2077
2078  def testCopyFromBadType(self):
2079    msg = map_unittest_pb2.TestMap()
2080    with self.assertRaisesRegex(
2081        TypeError,
2082        r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
2083        r'expected .+TestMap got int\.'):
2084      msg.CopyFrom(1)
2085
2086  def testIntegerMapWithLongs(self):
2087    msg = map_unittest_pb2.TestMap()
2088    msg.map_int32_int32[int(-123)] = int(-456)
2089    msg.map_int64_int64[int(-2**33)] = int(-2**34)
2090    msg.map_uint32_uint32[int(123)] = int(456)
2091    msg.map_uint64_uint64[int(2**33)] = int(2**34)
2092
2093    serialized = msg.SerializeToString()
2094    msg2 = map_unittest_pb2.TestMap()
2095    msg2.ParseFromString(serialized)
2096
2097    self.assertEqual(-456, msg2.map_int32_int32[-123])
2098    self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
2099    self.assertEqual(456, msg2.map_uint32_uint32[123])
2100    self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
2101
2102  def testMapAssignmentCausesPresence(self):
2103    msg = map_unittest_pb2.TestMapSubmessage()
2104    msg.test_map.map_int32_int32[123] = 456
2105
2106    serialized = msg.SerializeToString()
2107    msg2 = map_unittest_pb2.TestMapSubmessage()
2108    msg2.ParseFromString(serialized)
2109
2110    self.assertEqual(msg, msg2)
2111
2112    # Now test that various mutations of the map properly invalidate the
2113    # cached size of the submessage.
2114    msg.test_map.map_int32_int32[888] = 999
2115    serialized = msg.SerializeToString()
2116    msg2.ParseFromString(serialized)
2117    self.assertEqual(msg, msg2)
2118
2119    msg.test_map.map_int32_int32.clear()
2120    serialized = msg.SerializeToString()
2121    msg2.ParseFromString(serialized)
2122    self.assertEqual(msg, msg2)
2123
2124  def testMapAssignmentCausesPresenceForSubmessages(self):
2125    msg = map_unittest_pb2.TestMapSubmessage()
2126    msg.test_map.map_int32_foreign_message[123].c = 5
2127
2128    serialized = msg.SerializeToString()
2129    msg2 = map_unittest_pb2.TestMapSubmessage()
2130    msg2.ParseFromString(serialized)
2131
2132    self.assertEqual(msg, msg2)
2133
2134    # Now test that various mutations of the map properly invalidate the
2135    # cached size of the submessage.
2136    msg.test_map.map_int32_foreign_message[888].c = 7
2137    serialized = msg.SerializeToString()
2138    msg2.ParseFromString(serialized)
2139    self.assertEqual(msg, msg2)
2140
2141    msg.test_map.map_int32_foreign_message[888].MergeFrom(
2142        msg.test_map.map_int32_foreign_message[123])
2143    serialized = msg.SerializeToString()
2144    msg2.ParseFromString(serialized)
2145    self.assertEqual(msg, msg2)
2146
2147    msg.test_map.map_int32_foreign_message.clear()
2148    serialized = msg.SerializeToString()
2149    msg2.ParseFromString(serialized)
2150    self.assertEqual(msg, msg2)
2151
2152  def testModifyMapWhileIterating(self):
2153    msg = map_unittest_pb2.TestMap()
2154
2155    string_string_iter = iter(msg.map_string_string)
2156    int32_foreign_iter = iter(msg.map_int32_foreign_message)
2157
2158    msg.map_string_string['abc'] = '123'
2159    msg.map_int32_foreign_message[5].c = 5
2160
2161    with self.assertRaises(RuntimeError):
2162      for key in string_string_iter:
2163        pass
2164
2165    with self.assertRaises(RuntimeError):
2166      for key in int32_foreign_iter:
2167        pass
2168
2169  def testModifyMapEntryWhileIterating(self):
2170    msg = map_unittest_pb2.TestMap()
2171
2172    msg.map_string_string['abc'] = '123'
2173    msg.map_string_string['def'] = '456'
2174    msg.map_string_string['ghi'] = '789'
2175
2176    msg.map_int32_foreign_message[5].c = 5
2177    msg.map_int32_foreign_message[6].c = 6
2178    msg.map_int32_foreign_message[7].c = 7
2179
2180    string_string_keys = list(msg.map_string_string.keys())
2181    int32_foreign_keys = list(msg.map_int32_foreign_message.keys())
2182
2183    keys = []
2184    for key in msg.map_string_string:
2185      keys.append(key)
2186      msg.map_string_string[key] = '000'
2187    self.assertEqual(keys, string_string_keys)
2188    self.assertEqual(keys, list(msg.map_string_string.keys()))
2189
2190    keys = []
2191    for key in msg.map_int32_foreign_message:
2192      keys.append(key)
2193      msg.map_int32_foreign_message[key].c = 0
2194    self.assertEqual(keys, int32_foreign_keys)
2195    self.assertEqual(keys, list(msg.map_int32_foreign_message.keys()))
2196
2197  def testSubmessageMap(self):
2198    msg = map_unittest_pb2.TestMap()
2199
2200    submsg = msg.map_int32_foreign_message[111]
2201    self.assertIs(submsg, msg.map_int32_foreign_message[111])
2202    self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
2203
2204    submsg.c = 5
2205
2206    serialized = msg.SerializeToString()
2207    msg2 = map_unittest_pb2.TestMap()
2208    msg2.ParseFromString(serialized)
2209
2210    self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
2211
2212    # Doesn't allow direct submessage assignment.
2213    with self.assertRaises(ValueError):
2214      msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
2215
2216  def testMapIteration(self):
2217    msg = map_unittest_pb2.TestMap()
2218
2219    for k, v in msg.map_int32_int32.items():
2220      # Should not be reached.
2221      self.assertTrue(False)
2222
2223    msg.map_int32_int32[2] = 4
2224    msg.map_int32_int32[3] = 6
2225    msg.map_int32_int32[4] = 8
2226    self.assertEqual(3, len(msg.map_int32_int32))
2227
2228    matching_dict = {2: 4, 3: 6, 4: 8}
2229    self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
2230
2231  def testMapItems(self):
2232    # Map items used to have strange behaviors when use c extension. Because
2233    # [] may reorder the map and invalidate any existing iterators.
2234    # TODO(jieluo): Check if [] reordering the map is a bug or intended
2235    # behavior.
2236    msg = map_unittest_pb2.TestMap()
2237    msg.map_string_string['local_init_op'] = ''
2238    msg.map_string_string['trainable_variables'] = ''
2239    msg.map_string_string['variables'] = ''
2240    msg.map_string_string['init_op'] = ''
2241    msg.map_string_string['summaries'] = ''
2242    items1 = msg.map_string_string.items()
2243    items2 = msg.map_string_string.items()
2244    self.assertEqual(items1, items2)
2245
2246  def testMapDeterministicSerialization(self):
2247    golden_data = (b'r\x0c\n\x07init_op\x12\x01d'
2248                   b'r\n\n\x05item1\x12\x01e'
2249                   b'r\n\n\x05item2\x12\x01f'
2250                   b'r\n\n\x05item3\x12\x01g'
2251                   b'r\x0b\n\x05item4\x12\x02QQ'
2252                   b'r\x12\n\rlocal_init_op\x12\x01a'
2253                   b'r\x0e\n\tsummaries\x12\x01e'
2254                   b'r\x18\n\x13trainable_variables\x12\x01b'
2255                   b'r\x0e\n\tvariables\x12\x01c')
2256    msg = map_unittest_pb2.TestMap()
2257    msg.map_string_string['local_init_op'] = 'a'
2258    msg.map_string_string['trainable_variables'] = 'b'
2259    msg.map_string_string['variables'] = 'c'
2260    msg.map_string_string['init_op'] = 'd'
2261    msg.map_string_string['summaries'] = 'e'
2262    msg.map_string_string['item1'] = 'e'
2263    msg.map_string_string['item2'] = 'f'
2264    msg.map_string_string['item3'] = 'g'
2265    msg.map_string_string['item4'] = 'QQ'
2266
2267    # If deterministic serialization is not working correctly, this will be
2268    # "flaky" depending on the exact python dict hash seed.
2269    #
2270    # Fortunately, there are enough items in this map that it is extremely
2271    # unlikely to ever hit the "right" in-order combination, so the test
2272    # itself should fail reliably.
2273    self.assertEqual(golden_data, msg.SerializeToString(deterministic=True))
2274
2275  def testMapIterationClearMessage(self):
2276    # Iterator needs to work even if message and map are deleted.
2277    msg = map_unittest_pb2.TestMap()
2278
2279    msg.map_int32_int32[2] = 4
2280    msg.map_int32_int32[3] = 6
2281    msg.map_int32_int32[4] = 8
2282
2283    it = msg.map_int32_int32.items()
2284    del msg
2285
2286    matching_dict = {2: 4, 3: 6, 4: 8}
2287    self.assertMapIterEquals(it, matching_dict)
2288
2289  def testMapConstruction(self):
2290    msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
2291    self.assertEqual(2, msg.map_int32_int32[1])
2292    self.assertEqual(4, msg.map_int32_int32[3])
2293
2294    msg = map_unittest_pb2.TestMap(
2295        map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
2296    self.assertEqual(5, msg.map_int32_foreign_message[3].c)
2297
2298  def testMapScalarFieldConstruction(self):
2299    msg1 = map_unittest_pb2.TestMap()
2300    msg1.map_int32_int32[1] = 42
2301    msg2 = map_unittest_pb2.TestMap(map_int32_int32=msg1.map_int32_int32)
2302    self.assertEqual(42, msg2.map_int32_int32[1])
2303
2304  def testMapMessageFieldConstruction(self):
2305    msg1 = map_unittest_pb2.TestMap()
2306    msg1.map_string_foreign_message['test'].c = 42
2307    msg2 = map_unittest_pb2.TestMap(
2308        map_string_foreign_message=msg1.map_string_foreign_message)
2309    self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
2310
2311  def testMapFieldRaisesCorrectError(self):
2312    # Should raise a TypeError when given a non-iterable.
2313    with self.assertRaises(TypeError):
2314      map_unittest_pb2.TestMap(map_string_foreign_message=1)
2315
2316  def testMapValidAfterFieldCleared(self):
2317    # Map needs to work even if field is cleared.
2318    # For the C++ implementation this tests the correctness of
2319    # MapContainer::Release()
2320    msg = map_unittest_pb2.TestMap()
2321    int32_map = msg.map_int32_int32
2322
2323    int32_map[2] = 4
2324    int32_map[3] = 6
2325    int32_map[4] = 8
2326
2327    msg.ClearField('map_int32_int32')
2328    self.assertEqual(b'', msg.SerializeToString())
2329    matching_dict = {2: 4, 3: 6, 4: 8}
2330    self.assertMapIterEquals(int32_map.items(), matching_dict)
2331
2332  def testMessageMapValidAfterFieldCleared(self):
2333    # Map needs to work even if field is cleared.
2334    # For the C++ implementation this tests the correctness of
2335    # MapContainer::Release()
2336    msg = map_unittest_pb2.TestMap()
2337    int32_foreign_message = msg.map_int32_foreign_message
2338
2339    int32_foreign_message[2].c = 5
2340
2341    msg.ClearField('map_int32_foreign_message')
2342    self.assertEqual(b'', msg.SerializeToString())
2343    self.assertTrue(2 in int32_foreign_message.keys())
2344
  def testMessageMapItemValidAfterTopMessageCleared(self):
    # A message value taken out of a map must keep its contents after the
    # owning top-level message is cleared with Clear().
    # For the C++ implementation this tests the correctness of
    # MapContainer::Release()
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[2].optional_string = 'bar'

    if api_implementation.Type() == 'cpp':
      # Need to keep the map reference because of b/27942626.
      # NOTE(review): presumably this extra reference has to exist before
      # msg.Clear() below; verify against that bug before moving it.
      # TODO(jieluo): Remove it.
      unused_map = msg.map_int32_all_types  # pylint: disable=unused-variable
    # Hold on to the entry for key 2 across the Clear() call.
    msg_value = msg.map_int32_all_types[2]
    msg.Clear()

    # Reset to trigger sync between repeated field and map in c++.
    # Writing a new entry after Clear() must not disturb the held msg_value.
    msg.map_int32_all_types[3].optional_string = 'foo'
    self.assertEqual(msg_value.optional_string, 'bar')
2362
2363  def testMapIterInvalidatedByClearField(self):
2364    # Map iterator is invalidated when field is cleared.
2365    # But this case does need to not crash the interpreter.
2366    # For the C++ implementation this tests the correctness of
2367    # ScalarMapContainer::Release()
2368    msg = map_unittest_pb2.TestMap()
2369
2370    it = iter(msg.map_int32_int32)
2371
2372    msg.ClearField('map_int32_int32')
2373    with self.assertRaises(RuntimeError):
2374      for _ in it:
2375        pass
2376
2377    it = iter(msg.map_int32_foreign_message)
2378    msg.ClearField('map_int32_foreign_message')
2379    with self.assertRaises(RuntimeError):
2380      for _ in it:
2381        pass
2382
2383  def testMapDelete(self):
2384    msg = map_unittest_pb2.TestMap()
2385
2386    self.assertEqual(0, len(msg.map_int32_int32))
2387
2388    msg.map_int32_int32[4] = 6
2389    self.assertEqual(1, len(msg.map_int32_int32))
2390
2391    with self.assertRaises(KeyError):
2392      del msg.map_int32_int32[88]
2393
2394    del msg.map_int32_int32[4]
2395    self.assertEqual(0, len(msg.map_int32_int32))
2396
2397    with self.assertRaises(KeyError):
2398      del msg.map_int32_all_types[32]
2399
2400  def testMapsAreMapping(self):
2401    msg = map_unittest_pb2.TestMap()
2402    self.assertIsInstance(msg.map_int32_int32, collections.abc.Mapping)
2403    self.assertIsInstance(msg.map_int32_int32, collections.abc.MutableMapping)
2404    self.assertIsInstance(msg.map_int32_foreign_message,
2405                          collections.abc.Mapping)
2406    self.assertIsInstance(msg.map_int32_foreign_message,
2407                          collections.abc.MutableMapping)
2408
2409  def testMapsCompare(self):
2410    msg = map_unittest_pb2.TestMap()
2411    msg.map_int32_int32[-123] = -456
2412    self.assertEqual(msg.map_int32_int32, msg.map_int32_int32)
2413    self.assertEqual(msg.map_int32_foreign_message,
2414                     msg.map_int32_foreign_message)
2415    self.assertNotEqual(msg.map_int32_int32, 0)
2416
2417  def testMapFindInitializationErrorsSmokeTest(self):
2418    msg = map_unittest_pb2.TestMap()
2419    msg.map_string_string['abc'] = '123'
2420    msg.map_int32_int32[35] = 64
2421    msg.map_string_foreign_message['foo'].c = 5
2422    self.assertEqual(0, len(msg.FindInitializationErrors()))
2423
2424  @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
2425  def testStrictUtf8Check(self):
2426    # Test u'\ud801' is rejected at parser in both python2 and python3.
2427    serialized = (b'r\x03\xed\xa0\x81')
2428    msg = unittest_proto3_arena_pb2.TestAllTypes()
2429    with self.assertRaises(Exception) as context:
2430      msg.MergeFromString(serialized)
2431    if api_implementation.Type() == 'python':
2432      self.assertIn('optional_string', str(context.exception))
2433    else:
2434      self.assertIn('Error parsing message', str(context.exception))
2435
2436    # Test optional_string=u'��' is accepted.
2437    serialized = unittest_proto3_arena_pb2.TestAllTypes(
2438        optional_string=u'��').SerializeToString()
2439    msg2 = unittest_proto3_arena_pb2.TestAllTypes()
2440    msg2.MergeFromString(serialized)
2441    self.assertEqual(msg2.optional_string, u'��')
2442
2443    msg = unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud001')
2444    self.assertEqual(msg.optional_string, u'\ud001')
2445
2446  def testSurrogatesInPython3(self):
2447    # Surrogates are rejected at setters in Python3.
2448    with self.assertRaises(ValueError):
2449      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\udc01')
2450    with self.assertRaises(ValueError):
2451      unittest_proto3_arena_pb2.TestAllTypes(optional_string=b'\xed\xa0\x81')
2452    with self.assertRaises(ValueError):
2453      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801')
2454    with self.assertRaises(ValueError):
2455      unittest_proto3_arena_pb2.TestAllTypes(optional_string=u'\ud801\ud801')
2456
2457
2458
2459
@testing_refleaks.TestCase
class ValidTypeNamesTest(unittest.TestCase):

  def assertImportFromName(self, msg, base_name):
    """Asserts that a repeated container's type is named and importable.

    Args:
      msg: a repeated-field container instance.
      base_name: the container flavour expected in the class name, e.g.
        'Scalar' or 'Composite'.
    """
    # str(type(msg)) looks like "<class 'some.module.ClassName'>"; extract
    # the dotted name between the quotes.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(
        tp_name.endswith(valid_names),
        '%r does not end with any of %r' % (tp_name, valid_names))

    # Everything before the last dot is the module; import it the same way
    # an unpickler would have to.
    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
2482
2483
@testing_refleaks.TestCase
class PackedFieldTest(unittest.TestCase):

  def setMessage(self, message):
    # Store exactly one element (1, 1.0 or True) in every repeated field,
    # in field-number order.
    integer_kinds = ('int32', 'int64', 'uint32', 'uint64', 'sint32', 'sint64',
                     'fixed32', 'fixed64', 'sfixed32', 'sfixed64')
    for kind in integer_kinds:
      getattr(message, 'repeated_' + kind).append(1)
    message.repeated_float.append(1.0)
    message.repeated_double.append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    message = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(message)
    # Packed: each field is one length-delimited record (wire type 2, e.g.
    # tag 0x0A for field 1) holding its single element. The sint fields are
    # ZigZag-encoded, so 1 becomes 0x02.
    expected = (b'\x0A\x01\x01'
                b'\x12\x01\x01'
                b'\x1A\x01\x01'
                b'\x22\x01\x01'
                b'\x2A\x01\x02'
                b'\x32\x01\x02'
                b'\x3A\x04\x01\x00\x00\x00'
                b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x4A\x04\x01\x00\x00\x00'
                b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x5A\x04\x00\x00\x80\x3f'
                b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                b'\x6A\x01\x01'
                b'\x72\x01\x01')
    self.assertEqual(expected, message.SerializeToString())

  def testUnpackedFields(self):
    message = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(message)
    # Unpacked: each element carries its own tag with the element's natural
    # wire type (varint, fixed32 = 5, fixed64 = 1) and no length prefix.
    expected = (b'\x08\x01'
                b'\x10\x01'
                b'\x18\x01'
                b'\x20\x01'
                b'\x28\x02'
                b'\x30\x02'
                b'\x3D\x01\x00\x00\x00'
                b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x4D\x01\x00\x00\x00'
                b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x5D\x00\x00\x80\x3f'
                b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                b'\x68\x01'
                b'\x70\x01')
    self.assertEqual(expected, message.SerializeToString())
2540
2541
@unittest.skipIf(api_implementation.Type() != 'cpp',
                 'explicit tests of the C++ implementation')
@testing_refleaks.TestCase
class OversizeProtosTest(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    # The message class is built once per test class: reference cycles
    # between DescriptorPool and Message classes are not detected, so these
    # objects are never freed, and rebuilding them for every test would trip
    # ReferenceLeakChecker.
    file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    file_proto = text_format.Parse(file_desc,
                                   descriptor_pb2.FileDescriptorProto())
    pool = descriptor_pool.DescriptorPool()
    pool.Add(file_proto)
    factory = message_factory.MessageFactory(pool)
    cls.proto_cls = factory.GetPrototype(pool.FindMessageTypeByName('f.msg2'))

  def setUp(self):
    # The payload is one byte over 64 MiB, presumably chosen to exceed the
    # C++ parser's default size limit -- TODO confirm.
    self.p = self.proto_cls()
    self.p.field.payload = 'c' * (64 * 1024 * 1024 + 1)
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    # With oversize protos disallowed, a parse failure must surface as a
    # DecodeError carrying the generic message.
    # NOTE(review): if ParseFromString does not raise, this test still
    # passes; consider assertRaises once the current parser limit is
    # confirmed.
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(False)
    parsed = self.proto_cls()
    try:
      parsed.ParseFromString(self.p_serialized)
    except message.DecodeError as err:
      self.assertEqual(str(err), 'Error parsing message')

  def testSucceedOversizeProto(self):
    # With oversize protos allowed, the large payload round-trips intact.
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(True)
    parsed = self.proto_cls()
    parsed.ParseFromString(self.p_serialized)
    self.assertEqual(self.p.field.payload, parsed.field.payload)
2602
2603
if __name__ == '__main__':
  # Run every test case defined in this module when executed as a script.
  unittest.main()
2606