• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4#
5# Use of this source code is governed by a BSD-style
6# license that can be found in the LICENSE file or at
7# https://developers.google.com/open-source/licenses/bsd
8
9"""Unittest for reflection.py, which also indirectly tests the output of the
10pure-Python protocol compiler.
11"""
12
13import copy
14import gc
15import operator
16import struct
17import sys
18import unittest
19import warnings
20
21from google.protobuf import descriptor
22from google.protobuf import descriptor_pb2
23from google.protobuf import message
24from google.protobuf import message_factory
25from google.protobuf import reflection
26from google.protobuf import text_format
27from google.protobuf.internal import api_implementation
28from google.protobuf.internal import decoder
29from google.protobuf.internal import message_set_extensions_pb2
30from google.protobuf.internal import more_extensions_pb2
31from google.protobuf.internal import more_messages_pb2
32from google.protobuf.internal import test_util
33from google.protobuf.internal import testing_refleaks
34from google.protobuf.internal import wire_format
35from google.protobuf.internal import _parameterized
36from google.protobuf import unittest_import_pb2
37from google.protobuf import unittest_mset_pb2
38from google.protobuf import unittest_pb2
39from google.protobuf import unittest_proto3_arena_pb2
40
41
# Escalate every DeprecationWarning raised while these tests run into an error,
# so use of a deprecated API fails loudly instead of scrolling past.
warnings.simplefilter(action='error', category=DeprecationWarning)
43
44
45class _MiniDecoder(object):
46  """Decodes a stream of values from a string.
47
48  Once upon a time we actually had a class called decoder.Decoder.  Then we
49  got rid of it during a redesign that made decoding much, much faster overall.
50  But a couple tests in this file used it to check that the serialized form of
51  a message was correct.  So, this class implements just the methods that were
52  used by said tests, so that we don't have to rewrite the tests.
53  """
54
55  def __init__(self, bytes):
56    self._bytes = bytes
57    self._pos = 0
58
59  def ReadVarint(self):
60    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
61    return result
62
63  ReadInt32 = ReadVarint
64  ReadInt64 = ReadVarint
65  ReadUInt32 = ReadVarint
66  ReadUInt64 = ReadVarint
67
68  def ReadSInt64(self):
69    return wire_format.ZigZagDecode(self.ReadVarint())
70
71  ReadSInt32 = ReadSInt64
72
73  def ReadFieldNumberAndWireType(self):
74    return wire_format.UnpackTag(self.ReadVarint())
75
76  def ReadFloat(self):
77    result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
78    self._pos += 4
79    return result
80
81  def ReadDouble(self):
82    result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
83    self._pos += 8
84    return result
85
86  def EndOfStream(self):
87    return self._pos == len(self._bytes)
88
89
90@_parameterized.named_parameters(
91    ('_proto2', unittest_pb2),
92    ('_proto3', unittest_proto3_arena_pb2))
93@testing_refleaks.TestCase
94class ReflectionTest(unittest.TestCase):
95
96  def assertListsEqual(self, values, others):
97    self.assertEqual(len(values), len(others))
98    for i in range(len(values)):
99      self.assertEqual(values[i], others[i])
100
101  def testScalarConstructor(self, message_module):
102    # Constructor with only scalar types should succeed.
103    proto = message_module.TestAllTypes(
104        optional_int32=24,
105        optional_double=54.321,
106        optional_string='optional_string',
107        optional_float=None)
108
109    self.assertEqual(24, proto.optional_int32)
110    self.assertEqual(54.321, proto.optional_double)
111    self.assertEqual('optional_string', proto.optional_string)
112    if message_module is unittest_pb2:
113      self.assertFalse(proto.HasField("optional_float"))
114
115  def testRepeatedScalarConstructor(self, message_module):
116    # Constructor with only repeated scalar types should succeed.
117    proto = message_module.TestAllTypes(
118        repeated_int32=[1, 2, 3, 4],
119        repeated_double=[1.23, 54.321],
120        repeated_bool=[True, False, False],
121        repeated_string=["optional_string"],
122        repeated_float=None)
123
124    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
125    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
126    self.assertEqual([True, False, False], list(proto.repeated_bool))
127    self.assertEqual(["optional_string"], list(proto.repeated_string))
128    self.assertEqual([], list(proto.repeated_float))
129
130  def testMixedConstructor(self, message_module):
131    # Constructor with only mixed types should succeed.
132    proto = message_module.TestAllTypes(
133        optional_int32=24,
134        optional_string='optional_string',
135        repeated_double=[1.23, 54.321],
136        repeated_bool=[True, False, False],
137        repeated_nested_message=[
138            message_module.TestAllTypes.NestedMessage(
139                bb=message_module.TestAllTypes.FOO),
140            message_module.TestAllTypes.NestedMessage(
141                bb=message_module.TestAllTypes.BAR)],
142        repeated_foreign_message=[
143            message_module.ForeignMessage(c=-43),
144            message_module.ForeignMessage(c=45324),
145            message_module.ForeignMessage(c=12)],
146        optional_nested_message=None)
147
148    self.assertEqual(24, proto.optional_int32)
149    self.assertEqual('optional_string', proto.optional_string)
150    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
151    self.assertEqual([True, False, False], list(proto.repeated_bool))
152    self.assertEqual(
153        [message_module.TestAllTypes.NestedMessage(
154            bb=message_module.TestAllTypes.FOO),
155         message_module.TestAllTypes.NestedMessage(
156             bb=message_module.TestAllTypes.BAR)],
157        list(proto.repeated_nested_message))
158    self.assertEqual(
159        [message_module.ForeignMessage(c=-43),
160         message_module.ForeignMessage(c=45324),
161         message_module.ForeignMessage(c=12)],
162        list(proto.repeated_foreign_message))
163    self.assertFalse(proto.HasField("optional_nested_message"))
164
165  def testConstructorTypeError(self, message_module):
166    self.assertRaises(
167        TypeError, message_module.TestAllTypes, optional_int32='foo')
168    self.assertRaises(
169        TypeError, message_module.TestAllTypes, optional_string=1234)
170    self.assertRaises(
171        TypeError, message_module.TestAllTypes, optional_nested_message=1234)
172    self.assertRaises(
173        TypeError, message_module.TestAllTypes, repeated_int32=1234)
174    self.assertRaises(
175        TypeError, message_module.TestAllTypes, repeated_int32=['foo'])
176    self.assertRaises(
177        TypeError, message_module.TestAllTypes, repeated_string=1234)
178    self.assertRaises(
179        TypeError, message_module.TestAllTypes, repeated_string=[1234])
180    self.assertRaises(
181        TypeError, message_module.TestAllTypes, repeated_nested_message=1234)
182    self.assertRaises(
183        TypeError, message_module.TestAllTypes, repeated_nested_message=[1234])
184
185  def testConstructorInvalidatesCachedByteSize(self, message_module):
186    message = message_module.TestAllTypes(optional_int32=12)
187    self.assertEqual(2, message.ByteSize())
188
189    message = message_module.TestAllTypes(
190        optional_nested_message=message_module.TestAllTypes.NestedMessage())
191    self.assertEqual(3, message.ByteSize())
192
193    message = message_module.TestAllTypes(repeated_int32=[12])
194    # TODO: Add this test back for proto3
195    if message_module is unittest_pb2:
196      self.assertEqual(3, message.ByteSize())
197
198    message = message_module.TestAllTypes(
199        repeated_nested_message=[message_module.TestAllTypes.NestedMessage()])
200    self.assertEqual(3, message.ByteSize())
201
202  def testReferencesToNestedMessage(self, message_module):
203    proto = message_module.TestAllTypes()
204    nested = proto.optional_nested_message
205    del proto
206    # A previous version had a bug where this would raise an exception when
207    # hitting a now-dead weak reference.
208    nested.bb = 23
209
210  def testOneOf(self, message_module):
211    proto = message_module.TestAllTypes()
212    proto.oneof_uint32 = 10
213    proto.oneof_nested_message.bb = 11
214    self.assertEqual(11, proto.oneof_nested_message.bb)
215    self.assertFalse(proto.HasField('oneof_uint32'))
216    nested = proto.oneof_nested_message
217    proto.oneof_string = 'abc'
218    self.assertEqual('abc', proto.oneof_string)
219    self.assertEqual(11, nested.bb)
220    self.assertFalse(proto.HasField('oneof_nested_message'))
221
222  def testGetDefaultMessageAfterDisconnectingDefaultMessage(
223      self, message_module):
224    proto = message_module.TestAllTypes()
225    nested = proto.optional_nested_message
226    proto.ClearField('optional_nested_message')
227    del proto
228    del nested
229    # Force a garbage collect so that the underlying CMessages are freed along
230    # with the Messages they point to. This is to make sure we're not deleting
231    # default message instances.
232    gc.collect()
233    proto = message_module.TestAllTypes()
234    nested = proto.optional_nested_message
235
236  def testDisconnectingNestedMessageAfterSettingField(self, message_module):
237    proto = message_module.TestAllTypes()
238    nested = proto.optional_nested_message
239    nested.bb = 5
240    self.assertTrue(proto.HasField('optional_nested_message'))
241    proto.ClearField('optional_nested_message')  # Should disconnect from parent
242    self.assertEqual(5, nested.bb)
243    self.assertEqual(0, proto.optional_nested_message.bb)
244    self.assertIsNot(nested, proto.optional_nested_message)
245    nested.bb = 23
246    self.assertFalse(proto.HasField('optional_nested_message'))
247    self.assertEqual(0, proto.optional_nested_message.bb)
248
249  def testDisconnectingNestedMessageBeforeGettingField(self, message_module):
250    proto = message_module.TestAllTypes()
251    self.assertFalse(proto.HasField('optional_nested_message'))
252    proto.ClearField('optional_nested_message')
253    self.assertFalse(proto.HasField('optional_nested_message'))
254
255  def testDisconnectingNestedMessageAfterMerge(self, message_module):
256    # This test exercises the code path that does not use ReleaseMessage().
257    # The underlying fear is that if we use ReleaseMessage() incorrectly,
258    # we will have memory leaks.  It's hard to check that that doesn't happen,
259    # but at least we can exercise that code path to make sure it works.
260    proto1 = message_module.TestAllTypes()
261    proto2 = message_module.TestAllTypes()
262    proto2.optional_nested_message.bb = 5
263    proto1.MergeFrom(proto2)
264    self.assertTrue(proto1.HasField('optional_nested_message'))
265    proto1.ClearField('optional_nested_message')
266    self.assertFalse(proto1.HasField('optional_nested_message'))
267
268  def testDisconnectingLazyNestedMessage(self, message_module):
269    # This test exercises releasing a nested message that is lazy. This test
270    # only exercises real code in the C++ implementation as Python does not
271    # support lazy parsing, but the current C++ implementation results in
272    # memory corruption and a crash.
273    if api_implementation.Type() != 'python':
274      return
275    proto = message_module.TestAllTypes()
276    proto.optional_lazy_message.bb = 5
277    proto.ClearField('optional_lazy_message')
278    del proto
279    gc.collect()
280
281  def testSingularListFields(self, message_module):
282    proto = message_module.TestAllTypes()
283    proto.optional_fixed32 = 1
284    proto.optional_int32 = 5
285    proto.optional_string = 'foo'
286    # Access sub-message but don't set it yet.
287    nested_message = proto.optional_nested_message
288    self.assertEqual(
289      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
290        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
291        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
292      proto.ListFields())
293
294    proto.optional_nested_message.bb = 123
295    self.assertEqual(
296      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
297        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
298        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
299        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
300             nested_message) ],
301      proto.ListFields())
302
303  def testRepeatedListFields(self, message_module):
304    proto = message_module.TestAllTypes()
305    proto.repeated_fixed32.append(1)
306    proto.repeated_int32.append(5)
307    proto.repeated_int32.append(11)
308    proto.repeated_string.extend(['foo', 'bar'])
309    proto.repeated_string.extend([])
310    proto.repeated_string.append('baz')
311    proto.repeated_string.extend(str(x) for x in range(2))
312    proto.optional_int32 = 21
313    proto.repeated_bool  # Access but don't set anything; should not be listed.
314    self.assertEqual(
315      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
316        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
317        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
318        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
319          ['foo', 'bar', 'baz', '0', '1']) ],
320      proto.ListFields())
321
322  def testClearFieldWithUnknownFieldName(self, message_module):
323    proto = message_module.TestAllTypes()
324    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
325    self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field')
326
327  def testDisallowedAssignments(self, message_module):
328    # It's illegal to assign values directly to repeated fields
329    # or to nonrepeated composite fields.  Ensure that this fails.
330    proto = message_module.TestAllTypes()
331    # Repeated fields.
332    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
333    # Lists shouldn't work, either.
334    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
335    # Composite fields.
336    self.assertRaises(AttributeError, setattr, proto,
337                      'optional_nested_message', 23)
338    # Assignment to a repeated nested message field without specifying
339    # the index in the array of nested messages.
340    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
341                      'bb', 34)
342    # Assignment to an attribute of a repeated field.
343    self.assertRaises(AttributeError, setattr, proto.repeated_float,
344                      'some_attribute', 34)
345    # proto.nonexistent_field = 23 should fail as well.
346    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
347
348  def testSingleScalarTypeSafety(self, message_module):
349    proto = message_module.TestAllTypes()
350    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
351    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
352    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
353    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
354    self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
355    self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
356    self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
357    # TODO: Fix type checking difference for python and c extension
358    if (api_implementation.Type() == 'python' or
359        (sys.version_info.major, sys.version_info.minor) >= (3, 10)):
360      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
361    else:
362      proto.optional_bool = 1.1
363
364  def assertIntegerTypes(self, integer_fn, message_module):
365    """Verifies setting of scalar integers.
366
367    Args:
368      integer_fn: A function to wrap the integers that will be assigned.
369      message_module: unittest_pb2 or unittest_proto3_arena_pb2
370    """
371    def TestGetAndDeserialize(field_name, value, expected_type):
372      proto = message_module.TestAllTypes()
373      value = integer_fn(value)
374      setattr(proto, field_name, value)
375      self.assertIsInstance(getattr(proto, field_name), expected_type)
376      proto2 = message_module.TestAllTypes()
377      proto2.ParseFromString(proto.SerializeToString())
378      self.assertIsInstance(getattr(proto2, field_name), expected_type)
379
380    TestGetAndDeserialize('optional_int32', 1, int)
381    TestGetAndDeserialize('optional_int32', 1 << 30, int)
382    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
383    integer_64 = int
384    if struct.calcsize('L') == 4:
385      # Python only has signed ints, so 32-bit python can't fit an uint32
386      # in an int.
387      TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
388    else:
389      # 64-bit python can fit uint32 inside an int
390      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
391    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
392    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
393    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
394    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
395
396  def testIntegerTypes(self, message_module):
397    self.assertIntegerTypes(lambda x: x, message_module)
398
399  def testNonStandardIntegerTypes(self, message_module):
400    self.assertIntegerTypes(test_util.NonStandardInteger, message_module)
401
402  def testIllegalValuesForIntegers(self, message_module):
403    pb = message_module.TestAllTypes()
404
405    # Strings are illegal, even when the represent an integer.
406    with self.assertRaises(TypeError):
407      pb.optional_uint64 = '2'
408
409    # The exact error should propagate with a poorly written custom integer.
410    with self.assertRaisesRegex(RuntimeError, 'my_error'):
411      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
412
413  def assetIntegerBoundsChecking(self, integer_fn, message_module):
414    """Verifies bounds checking for scalar integer fields.
415
416    Args:
417      integer_fn: A function to wrap the integers that will be assigned.
418      message_module: unittest_pb2 or unittest_proto3_arena_pb2
419    """
420    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
421      pb = message_module.TestAllTypes()
422      expected_min = integer_fn(expected_min)
423      expected_max = integer_fn(expected_max)
424      setattr(pb, field_name, expected_min)
425      self.assertEqual(expected_min, getattr(pb, field_name))
426      setattr(pb, field_name, expected_max)
427      self.assertEqual(expected_max, getattr(pb, field_name))
428      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
429                        expected_min - 1)
430      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
431                        expected_max + 1)
432
433    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
434    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
435    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
436    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
437    # A bit of white-box testing since -1 is an int and not a long in C++ and
438    # so goes down a different path.
439    pb = message_module.TestAllTypes()
440    with self.assertRaises((ValueError, TypeError)):
441      pb.optional_uint64 = integer_fn(-(1 << 63))
442
443    pb = message_module.TestAllTypes()
444    pb.optional_nested_enum = integer_fn(1)
445    self.assertEqual(1, pb.optional_nested_enum)
446
447  def testSingleScalarBoundsChecking(self, message_module):
448    self.assetIntegerBoundsChecking(lambda x: x, message_module)
449
450  def testNonStandardSingleScalarBoundsChecking(self, message_module):
451    self.assetIntegerBoundsChecking(
452        test_util.NonStandardInteger, message_module)
453
454  def testRepeatedScalarTypeSafety(self, message_module):
455    proto = message_module.TestAllTypes()
456    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
457    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
458    self.assertRaises(TypeError, proto.repeated_string, 10)
459    self.assertRaises(TypeError, proto.repeated_bytes, 10)
460
461    proto.repeated_int32.append(10)
462    proto.repeated_int32[0] = 23
463    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
464    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
465    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
466    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
467                      'index', 23)
468
469    proto.repeated_string.append('2')
470    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
471
472    # Repeated enums tests.
473    # proto.repeated_nested_enum.append(0)
474
475  def testSingleScalarGettersAndSetters(self, message_module):
476    proto = message_module.TestAllTypes()
477    self.assertEqual(0, proto.optional_int32)
478    proto.optional_int32 = 1
479    self.assertEqual(1, proto.optional_int32)
480
481    proto.optional_uint64 = 0xffffffffffff
482    self.assertEqual(0xffffffffffff, proto.optional_uint64)
483    proto.optional_uint64 = 0xffffffffffffffff
484    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
485    # TODO: Test all other scalar field types.
486
487  def testEnums(self, message_module):
488    proto = message_module.TestAllTypes()
489    self.assertEqual(1, proto.FOO)
490    self.assertEqual(1, message_module.TestAllTypes.FOO)
491    self.assertEqual(2, proto.BAR)
492    self.assertEqual(2, message_module.TestAllTypes.BAR)
493    self.assertEqual(3, proto.BAZ)
494    self.assertEqual(3, message_module.TestAllTypes.BAZ)
495
496  def testEnum_Name(self, message_module):
497    self.assertEqual(
498        'FOREIGN_FOO',
499        message_module.ForeignEnum.Name(message_module.FOREIGN_FOO))
500    self.assertEqual(
501        'FOREIGN_BAR',
502        message_module.ForeignEnum.Name(message_module.FOREIGN_BAR))
503    self.assertEqual(
504        'FOREIGN_BAZ',
505        message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ))
506    self.assertRaises(ValueError,
507                      message_module.ForeignEnum.Name, 11312)
508
509    proto = message_module.TestAllTypes()
510    self.assertEqual('FOO',
511                     proto.NestedEnum.Name(proto.FOO))
512    self.assertEqual('FOO',
513                     message_module.TestAllTypes.NestedEnum.Name(proto.FOO))
514    self.assertEqual('BAR',
515                     proto.NestedEnum.Name(proto.BAR))
516    self.assertEqual('BAR',
517                     message_module.TestAllTypes.NestedEnum.Name(proto.BAR))
518    self.assertEqual('BAZ',
519                     proto.NestedEnum.Name(proto.BAZ))
520    self.assertEqual('BAZ',
521                     message_module.TestAllTypes.NestedEnum.Name(proto.BAZ))
522    self.assertRaises(ValueError,
523                      proto.NestedEnum.Name, 11312)
524    self.assertRaises(ValueError,
525                      message_module.TestAllTypes.NestedEnum.Name, 11312)
526
527    # Check some coercion cases.
528    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
529                      11312.0)
530    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
531                      None)
532    self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True))
533
534  def testEnum_Value(self, message_module):
535    self.assertEqual(message_module.FOREIGN_FOO,
536                     message_module.ForeignEnum.Value('FOREIGN_FOO'))
537    self.assertEqual(message_module.FOREIGN_FOO,
538                     message_module.ForeignEnum.FOREIGN_FOO)
539
540    self.assertEqual(message_module.FOREIGN_BAR,
541                     message_module.ForeignEnum.Value('FOREIGN_BAR'))
542    self.assertEqual(message_module.FOREIGN_BAR,
543                     message_module.ForeignEnum.FOREIGN_BAR)
544
545    self.assertEqual(message_module.FOREIGN_BAZ,
546                     message_module.ForeignEnum.Value('FOREIGN_BAZ'))
547    self.assertEqual(message_module.FOREIGN_BAZ,
548                     message_module.ForeignEnum.FOREIGN_BAZ)
549
550    self.assertRaises(ValueError,
551                      message_module.ForeignEnum.Value, 'FO')
552    with self.assertRaises(AttributeError):
553      message_module.ForeignEnum.FO
554
555    proto = message_module.TestAllTypes()
556    self.assertEqual(proto.FOO,
557                     proto.NestedEnum.Value('FOO'))
558    self.assertEqual(proto.FOO,
559                     proto.NestedEnum.FOO)
560
561    self.assertEqual(proto.FOO,
562                     message_module.TestAllTypes.NestedEnum.Value('FOO'))
563    self.assertEqual(proto.FOO,
564                     message_module.TestAllTypes.NestedEnum.FOO)
565
566    self.assertEqual(proto.BAR,
567                     proto.NestedEnum.Value('BAR'))
568    self.assertEqual(proto.BAR,
569                     proto.NestedEnum.BAR)
570
571    self.assertEqual(proto.BAR,
572                     message_module.TestAllTypes.NestedEnum.Value('BAR'))
573    self.assertEqual(proto.BAR,
574                     message_module.TestAllTypes.NestedEnum.BAR)
575
576    self.assertEqual(proto.BAZ,
577                     proto.NestedEnum.Value('BAZ'))
578    self.assertEqual(proto.BAZ,
579                     proto.NestedEnum.BAZ)
580
581    self.assertEqual(proto.BAZ,
582                     message_module.TestAllTypes.NestedEnum.Value('BAZ'))
583    self.assertEqual(proto.BAZ,
584                     message_module.TestAllTypes.NestedEnum.BAZ)
585
586    self.assertRaises(ValueError,
587                      proto.NestedEnum.Value, 'Foo')
588    with self.assertRaises(AttributeError):
589      proto.NestedEnum.Value.Foo
590
591    self.assertRaises(ValueError,
592                      message_module.TestAllTypes.NestedEnum.Value, 'Foo')
593    with self.assertRaises(AttributeError):
594      message_module.TestAllTypes.NestedEnum.Value.Foo
595
596  def testEnum_KeysAndValues(self, message_module):
597    if message_module == unittest_pb2:
598      keys = [
599          'FOREIGN_FOO',
600          'FOREIGN_BAR',
601          'FOREIGN_BAZ',
602          'FOREIGN_BAX',
603          'FOREIGN_LARGE',
604      ]
605      values = [4, 5, 6, 32, 123456]
606      items = [
607          ('FOREIGN_FOO', 4),
608          ('FOREIGN_BAR', 5),
609          ('FOREIGN_BAZ', 6),
610          ('FOREIGN_BAX', 32),
611          ('FOREIGN_LARGE', 123456),
612      ]
613    else:
614      keys = [
615          'FOREIGN_ZERO',
616          'FOREIGN_FOO',
617          'FOREIGN_BAR',
618          'FOREIGN_BAZ',
619          'FOREIGN_LARGE',
620      ]
621      values = [0, 4, 5, 6, 123456]
622      items = [
623          ('FOREIGN_ZERO', 0),
624          ('FOREIGN_FOO', 4),
625          ('FOREIGN_BAR', 5),
626          ('FOREIGN_BAZ', 6),
627          ('FOREIGN_LARGE', 123456),
628      ]
629    self.assertEqual(keys,
630                     list(message_module.ForeignEnum.keys()))
631    self.assertEqual(values,
632                     list(message_module.ForeignEnum.values()))
633    self.assertEqual(items,
634                     list(message_module.ForeignEnum.items()))
635
636    proto = message_module.TestAllTypes()
637    if message_module == unittest_pb2:
638      keys = ['FOO', 'BAR', 'BAZ', 'NEG']
639      values = [1, 2, 3, -1]
640      items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
641    else:
642      keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG']
643      values = [0, 1, 2, 3, -1]
644      items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
645    self.assertEqual(keys, list(proto.NestedEnum.keys()))
646    self.assertEqual(values, list(proto.NestedEnum.values()))
647    self.assertEqual(items,
648                     list(proto.NestedEnum.items()))
649
650  def testStaticParseFrom(self, message_module):
651    proto1 = message_module.TestAllTypes()
652    test_util.SetAllFields(proto1)
653
654    string1 = proto1.SerializeToString()
655    proto2 = message_module.TestAllTypes.FromString(string1)
656
657    # Messages should be equal.
658    self.assertEqual(proto2, proto1)
659
660  def testMergeFromSingularField(self, message_module):
661    # Test merge with just a singular field.
662    proto1 = message_module.TestAllTypes()
663    proto1.optional_int32 = 1
664
665    proto2 = message_module.TestAllTypes()
666    # This shouldn't get overwritten.
667    proto2.optional_string = 'value'
668
669    proto2.MergeFrom(proto1)
670    self.assertEqual(1, proto2.optional_int32)
671    self.assertEqual('value', proto2.optional_string)
672
673  def testMergeFromRepeatedField(self, message_module):
674    # Test merge with just a repeated field.
675    proto1 = message_module.TestAllTypes()
676    proto1.repeated_int32.append(1)
677    proto1.repeated_int32.append(2)
678
679    proto2 = message_module.TestAllTypes()
680    proto2.repeated_int32.append(0)
681    proto2.MergeFrom(proto1)
682
683    self.assertEqual(0, proto2.repeated_int32[0])
684    self.assertEqual(1, proto2.repeated_int32[1])
685    self.assertEqual(2, proto2.repeated_int32[2])
686
687  def testMergeFromRepeatedNestedMessage(self, message_module):
688    # Test merge with a repeated nested message.
689    proto1 = message_module.TestAllTypes()
690    m = proto1.repeated_nested_message.add()
691    m.bb = 123
692    m = proto1.repeated_nested_message.add()
693    m.bb = 321
694
695    proto2 = message_module.TestAllTypes()
696    m = proto2.repeated_nested_message.add()
697    m.bb = 999
698    proto2.MergeFrom(proto1)
699    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
700    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
701    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
702
703    proto3 = message_module.TestAllTypes()
704    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
705    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
706    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
707    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
708
709  def testMergeFromAllFields(self, message_module):
710    # With all fields set.
711    proto1 = message_module.TestAllTypes()
712    test_util.SetAllFields(proto1)
713    proto2 = message_module.TestAllTypes()
714    proto2.MergeFrom(proto1)
715
716    # Messages should be equal.
717    self.assertEqual(proto2, proto1)
718
719    # Serialized string should be equal too.
720    string1 = proto1.SerializeToString()
721    string2 = proto2.SerializeToString()
722    self.assertEqual(string1, string2)
723
724  def testMergeFromBug(self, message_module):
725    message1 = message_module.TestAllTypes()
726    message2 = message_module.TestAllTypes()
727
728    # Cause optional_nested_message to be instantiated within message1, even
729    # though it is not considered to be "present".
730    message1.optional_nested_message
731    self.assertFalse(message1.HasField('optional_nested_message'))
732
733    # Merge into message2.  This should not instantiate the field is message2.
734    message2.MergeFrom(message1)
735    self.assertFalse(message2.HasField('optional_nested_message'))
736
737  def testCopyFromSingularField(self, message_module):
738    # Test copy with just a singular field.
739    proto1 = message_module.TestAllTypes()
740    proto1.optional_int32 = 1
741    proto1.optional_string = 'important-text'
742
743    proto2 = message_module.TestAllTypes()
744    proto2.optional_string = 'value'
745
746    proto2.CopyFrom(proto1)
747    self.assertEqual(1, proto2.optional_int32)
748    self.assertEqual('important-text', proto2.optional_string)
749
750  def testCopyFromRepeatedField(self, message_module):
751    # Test copy with a repeated field.
752    proto1 = message_module.TestAllTypes()
753    proto1.repeated_int32.append(1)
754    proto1.repeated_int32.append(2)
755
756    proto2 = message_module.TestAllTypes()
757    proto2.repeated_int32.append(0)
758    proto2.CopyFrom(proto1)
759
760    self.assertEqual(1, proto2.repeated_int32[0])
761    self.assertEqual(2, proto2.repeated_int32[1])
762
763  def testCopyFromAllFields(self, message_module):
764    # With all fields set.
765    proto1 = message_module.TestAllTypes()
766    test_util.SetAllFields(proto1)
767    proto2 = message_module.TestAllTypes()
768    proto2.CopyFrom(proto1)
769
770    # Messages should be equal.
771    self.assertEqual(proto2, proto1)
772
773    # Serialized string should be equal too.
774    string1 = proto1.SerializeToString()
775    string2 = proto2.SerializeToString()
776    self.assertEqual(string1, string2)
777
778  def testCopyFromSelf(self, message_module):
779    proto1 = message_module.TestAllTypes()
780    proto1.repeated_int32.append(1)
781    proto1.optional_int32 = 2
782    proto1.optional_string = 'important-text'
783
784    proto1.CopyFrom(proto1)
785    self.assertEqual(1, proto1.repeated_int32[0])
786    self.assertEqual(2, proto1.optional_int32)
787    self.assertEqual('important-text', proto1.optional_string)
788
789  def testDeepCopy(self, message_module):
790    proto1 = message_module.TestAllTypes()
791    proto1.optional_int32 = 1
792    proto2 = copy.deepcopy(proto1)
793    self.assertEqual(1, proto2.optional_int32)
794
795    proto1.repeated_int32.append(2)
796    proto1.repeated_int32.append(3)
797    container = copy.deepcopy(proto1.repeated_int32)
798    self.assertEqual([2, 3], container)
799    container.remove(container[0])
800    self.assertEqual([3], container)
801
802    message1 = proto1.repeated_nested_message.add()
803    message1.bb = 1
804    messages = copy.deepcopy(proto1.repeated_nested_message)
805    self.assertEqual(proto1.repeated_nested_message, messages)
806    message1.bb = 2
807    self.assertNotEqual(proto1.repeated_nested_message, messages)
808    messages.remove(messages[0])
809    self.assertEqual(len(messages), 0)
810
811  def testEmptyDeepCopy(self, message_module):
812    proto1 = message_module.TestAllTypes()
813    nested2 = copy.deepcopy(proto1.optional_nested_message)
814    self.assertEqual(0, nested2.bb)
815
816    # TODO: Implement deepcopy for extension dict
817
818  def testDisconnectingBeforeClear(self, message_module):
819    proto = message_module.TestAllTypes()
820    nested = proto.optional_nested_message
821    proto.Clear()
822    self.assertIsNot(nested, proto.optional_nested_message)
823    nested.bb = 23
824    self.assertFalse(proto.HasField('optional_nested_message'))
825    self.assertEqual(0, proto.optional_nested_message.bb)
826
827    proto = message_module.TestAllTypes()
828    nested = proto.optional_nested_message
829    nested.bb = 5
830    foreign = proto.optional_foreign_message
831    foreign.c = 6
832    proto.Clear()
833    self.assertIsNot(nested, proto.optional_nested_message)
834    self.assertIsNot(foreign, proto.optional_foreign_message)
835    self.assertEqual(5, nested.bb)
836    self.assertEqual(6, foreign.c)
837    nested.bb = 15
838    foreign.c = 16
839    self.assertFalse(proto.HasField('optional_nested_message'))
840    self.assertEqual(0, proto.optional_nested_message.bb)
841    self.assertFalse(proto.HasField('optional_foreign_message'))
842    self.assertEqual(0, proto.optional_foreign_message.c)
843
844  def testStringUTF8Encoding(self, message_module):
845    proto = message_module.TestAllTypes()
846
847    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
848    self.assertRaises(TypeError,
849                      setattr, proto, 'optional_bytes', u'unicode object')
850
851    # Check that the default value is of python's 'unicode' type.
852    self.assertEqual(type(proto.optional_string), str)
853
854    proto.optional_string = str('Testing')
855    self.assertEqual(proto.optional_string, str('Testing'))
856
857    # Assign a value of type 'str' which can be encoded in UTF-8.
858    proto.optional_string = str('Testing')
859    self.assertEqual(proto.optional_string, str('Testing'))
860
861    # Try to assign a 'bytes' object which contains non-UTF-8.
862    self.assertRaises(ValueError,
863                      setattr, proto, 'optional_string', b'a\x80a')
864    # No exception: Assign already encoded UTF-8 bytes to a string field.
865    utf8_bytes = u'Тест'.encode('utf-8')
866    proto.optional_string = utf8_bytes
867    # No exception: Assign the a non-ascii unicode object.
868    proto.optional_string = u'Тест'
869    # No exception thrown (normal str assignment containing ASCII).
870    proto.optional_string = 'abc'
871
872  def testBytesInTextFormat(self, message_module):
873    proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
874    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', str(proto))
875
876  def testEmptyNestedMessage(self, message_module):
877    proto = message_module.TestAllTypes()
878    proto.optional_nested_message.MergeFrom(
879        message_module.TestAllTypes.NestedMessage())
880    self.assertTrue(proto.HasField('optional_nested_message'))
881
882    proto = message_module.TestAllTypes()
883    proto.optional_nested_message.CopyFrom(
884        message_module.TestAllTypes.NestedMessage())
885    self.assertTrue(proto.HasField('optional_nested_message'))
886
887    proto = message_module.TestAllTypes()
888    bytes_read = proto.optional_nested_message.MergeFromString(b'')
889    self.assertEqual(0, bytes_read)
890    self.assertTrue(proto.HasField('optional_nested_message'))
891
892    proto = message_module.TestAllTypes()
893    proto.optional_nested_message.ParseFromString(b'')
894    self.assertTrue(proto.HasField('optional_nested_message'))
895
896    serialized = proto.SerializeToString()
897    proto2 = message_module.TestAllTypes()
898    self.assertEqual(
899        len(serialized),
900        proto2.MergeFromString(serialized))
901    self.assertTrue(proto2.HasField('optional_nested_message'))
902
903
904# Class to test proto2-only features (required, extensions, etc.)
905@testing_refleaks.TestCase
906class Proto2ReflectionTest(unittest.TestCase):
907
908  def testRepeatedCompositeConstructor(self):
909    # Constructor with only repeated composite types should succeed.
910    proto = unittest_pb2.TestAllTypes(
911        repeated_nested_message=[
912            unittest_pb2.TestAllTypes.NestedMessage(
913                bb=unittest_pb2.TestAllTypes.FOO),
914            unittest_pb2.TestAllTypes.NestedMessage(
915                bb=unittest_pb2.TestAllTypes.BAR)],
916        repeated_foreign_message=[
917            unittest_pb2.ForeignMessage(c=-43),
918            unittest_pb2.ForeignMessage(c=45324),
919            unittest_pb2.ForeignMessage(c=12)],
920        repeatedgroup=[
921            unittest_pb2.TestAllTypes.RepeatedGroup(),
922            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
923            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
924
925    self.assertEqual(
926        [unittest_pb2.TestAllTypes.NestedMessage(
927            bb=unittest_pb2.TestAllTypes.FOO),
928         unittest_pb2.TestAllTypes.NestedMessage(
929             bb=unittest_pb2.TestAllTypes.BAR)],
930        list(proto.repeated_nested_message))
931    self.assertEqual(
932        [unittest_pb2.ForeignMessage(c=-43),
933         unittest_pb2.ForeignMessage(c=45324),
934         unittest_pb2.ForeignMessage(c=12)],
935        list(proto.repeated_foreign_message))
936    self.assertEqual(
937        [unittest_pb2.TestAllTypes.RepeatedGroup(),
938         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
939         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
940        list(proto.repeatedgroup))
941
942  def assertListsEqual(self, values, others):
943    self.assertEqual(len(values), len(others))
944    for i in range(len(values)):
945      self.assertEqual(values[i], others[i])
946
947  def testSimpleHasBits(self):
948    # Test a scalar.
949    proto = unittest_pb2.TestAllTypes()
950    self.assertFalse(proto.HasField('optional_int32'))
951    self.assertEqual(0, proto.optional_int32)
952    # HasField() shouldn't be true if all we've done is
953    # read the default value.
954    self.assertFalse(proto.HasField('optional_int32'))
955    proto.optional_int32 = 1
956    # Setting a value however *should* set the "has" bit.
957    self.assertTrue(proto.HasField('optional_int32'))
958    proto.ClearField('optional_int32')
959    # And clearing that value should unset the "has" bit.
960    self.assertFalse(proto.HasField('optional_int32'))
961
962  def testHasBitsWithSinglyNestedScalar(self):
963    # Helper used to test foreign messages and groups.
964    #
965    # composite_field_name should be the name of a non-repeated
966    # composite (i.e., foreign or group) field in TestAllTypes,
967    # and scalar_field_name should be the name of an integer-valued
968    # scalar field within that composite.
969    #
970    # I never thought I'd miss C++ macros and templates so much. :(
971    # This helper is semantically just:
972    #
973    #   assert proto.composite_field.scalar_field == 0
974    #   assert not proto.composite_field.HasField('scalar_field')
975    #   assert not proto.HasField('composite_field')
976    #
977    #   proto.composite_field.scalar_field = 10
978    #   old_composite_field = proto.composite_field
979    #
980    #   assert proto.composite_field.scalar_field == 10
981    #   assert proto.composite_field.HasField('scalar_field')
982    #   assert proto.HasField('composite_field')
983    #
984    #   proto.ClearField('composite_field')
985    #
986    #   assert not proto.composite_field.HasField('scalar_field')
987    #   assert not proto.HasField('composite_field')
988    #   assert proto.composite_field.scalar_field == 0
989    #
990    #   # Now ensure that ClearField('composite_field') disconnected
991    #   # the old field object from the object tree...
992    #   assert old_composite_field is not proto.composite_field
993    #   old_composite_field.scalar_field = 20
994    #   assert not proto.composite_field.HasField('scalar_field')
995    #   assert not proto.HasField('composite_field')
996    def TestCompositeHasBits(composite_field_name, scalar_field_name):
997      proto = unittest_pb2.TestAllTypes()
998      # First, check that we can get the scalar value, and see that it's the
999      # default (0), but that proto.HasField('omposite') and
1000      # proto.composite.HasField('scalar') will still return False.
1001      composite_field = getattr(proto, composite_field_name)
1002      original_scalar_value = getattr(composite_field, scalar_field_name)
1003      self.assertEqual(0, original_scalar_value)
1004      # Assert that the composite object does not "have" the scalar.
1005      self.assertFalse(composite_field.HasField(scalar_field_name))
1006      # Assert that proto does not "have" the composite field.
1007      self.assertFalse(proto.HasField(composite_field_name))
1008
1009      # Now set the scalar within the composite field.  Ensure that the setting
1010      # is reflected, and that proto.HasField('composite') and
1011      # proto.composite.HasField('scalar') now both return True.
1012      new_val = 20
1013      setattr(composite_field, scalar_field_name, new_val)
1014      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
1015      # Hold on to a reference to the current composite_field object.
1016      old_composite_field = composite_field
1017      # Assert that the has methods now return true.
1018      self.assertTrue(composite_field.HasField(scalar_field_name))
1019      self.assertTrue(proto.HasField(composite_field_name))
1020
1021      # Now call the clear method...
1022      proto.ClearField(composite_field_name)
1023
1024      # ...and ensure that the "has" bits are all back to False...
1025      composite_field = getattr(proto, composite_field_name)
1026      self.assertFalse(composite_field.HasField(scalar_field_name))
1027      self.assertFalse(proto.HasField(composite_field_name))
1028      # ...and ensure that the scalar field has returned to its default.
1029      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1030
1031      self.assertIsNot(old_composite_field, composite_field)
1032      setattr(old_composite_field, scalar_field_name, new_val)
1033      self.assertFalse(composite_field.HasField(scalar_field_name))
1034      self.assertFalse(proto.HasField(composite_field_name))
1035      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1036
1037    # Test simple, single-level nesting when we set a scalar.
1038    TestCompositeHasBits('optionalgroup', 'a')
1039    TestCompositeHasBits('optional_nested_message', 'bb')
1040    TestCompositeHasBits('optional_foreign_message', 'c')
1041    TestCompositeHasBits('optional_import_message', 'd')
1042
1043  def testHasBitsWhenModifyingRepeatedFields(self):
1044    # Test nesting when we add an element to a repeated field in a submessage.
1045    proto = unittest_pb2.TestNestedMessageHasBits()
1046    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
1047    self.assertEqual(
1048        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
1049    self.assertTrue(proto.HasField('optional_nested_message'))
1050
1051    # Do the same test, but with a repeated composite field within the
1052    # submessage.
1053    proto.ClearField('optional_nested_message')
1054    self.assertFalse(proto.HasField('optional_nested_message'))
1055    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
1056    self.assertTrue(proto.HasField('optional_nested_message'))
1057
1058  def testHasBitsForManyLevelsOfNesting(self):
1059    # Test nesting many levels deep.
1060    recursive_proto = unittest_pb2.TestMutualRecursionA()
1061    self.assertFalse(recursive_proto.HasField('bb'))
1062    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
1063    self.assertFalse(recursive_proto.HasField('bb'))
1064    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
1065    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
1066    self.assertTrue(recursive_proto.HasField('bb'))
1067    self.assertTrue(recursive_proto.bb.HasField('a'))
1068    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
1069    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
1070    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
1071    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
1072    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
1073
1074  def testSingularListExtensions(self):
1075    proto = unittest_pb2.TestAllExtensions()
1076    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
1077    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
1078    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
1079    self.assertEqual(
1080      [ (unittest_pb2.optional_int32_extension  , 5),
1081        (unittest_pb2.optional_fixed32_extension, 1),
1082        (unittest_pb2.optional_string_extension , 'foo') ],
1083      proto.ListFields())
1084    del proto.Extensions[unittest_pb2.optional_fixed32_extension]
1085    self.assertEqual(
1086        [(unittest_pb2.optional_int32_extension, 5),
1087         (unittest_pb2.optional_string_extension, 'foo')],
1088        proto.ListFields())
1089
1090  def testRepeatedListExtensions(self):
1091    proto = unittest_pb2.TestAllExtensions()
1092    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
1093    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
1094    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
1095    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
1096    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
1097    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
1098    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
1099    self.assertEqual(
1100      [ (unittest_pb2.optional_int32_extension  , 21),
1101        (unittest_pb2.repeated_int32_extension  , [5, 11]),
1102        (unittest_pb2.repeated_fixed32_extension, [1]),
1103        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
1104      proto.ListFields())
1105    del proto.Extensions[unittest_pb2.repeated_int32_extension]
1106    del proto.Extensions[unittest_pb2.repeated_string_extension]
1107    self.assertEqual(
1108        [(unittest_pb2.optional_int32_extension, 21),
1109         (unittest_pb2.repeated_fixed32_extension, [1])],
1110        proto.ListFields())
1111
1112  def testListFieldsAndExtensions(self):
1113    proto = unittest_pb2.TestFieldOrderings()
1114    test_util.SetAllFieldsAndExtensions(proto)
1115    unittest_pb2.my_extension_int
1116    self.assertEqual(
1117      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
1118        (unittest_pb2.my_extension_int               , 23),
1119        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
1120        (unittest_pb2.my_extension_string            , 'bar'),
1121        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
1122      proto.ListFields())
1123
1124  def testDefaultValues(self):
1125    proto = unittest_pb2.TestAllTypes()
1126    self.assertEqual(0, proto.optional_int32)
1127    self.assertEqual(0, proto.optional_int64)
1128    self.assertEqual(0, proto.optional_uint32)
1129    self.assertEqual(0, proto.optional_uint64)
1130    self.assertEqual(0, proto.optional_sint32)
1131    self.assertEqual(0, proto.optional_sint64)
1132    self.assertEqual(0, proto.optional_fixed32)
1133    self.assertEqual(0, proto.optional_fixed64)
1134    self.assertEqual(0, proto.optional_sfixed32)
1135    self.assertEqual(0, proto.optional_sfixed64)
1136    self.assertEqual(0.0, proto.optional_float)
1137    self.assertEqual(0.0, proto.optional_double)
1138    self.assertEqual(False, proto.optional_bool)
1139    self.assertEqual('', proto.optional_string)
1140    self.assertEqual(b'', proto.optional_bytes)
1141
1142    self.assertEqual(41, proto.default_int32)
1143    self.assertEqual(42, proto.default_int64)
1144    self.assertEqual(43, proto.default_uint32)
1145    self.assertEqual(44, proto.default_uint64)
1146    self.assertEqual(-45, proto.default_sint32)
1147    self.assertEqual(46, proto.default_sint64)
1148    self.assertEqual(47, proto.default_fixed32)
1149    self.assertEqual(48, proto.default_fixed64)
1150    self.assertEqual(49, proto.default_sfixed32)
1151    self.assertEqual(-50, proto.default_sfixed64)
1152    self.assertEqual(51.5, proto.default_float)
1153    self.assertEqual(52e3, proto.default_double)
1154    self.assertEqual(True, proto.default_bool)
1155    self.assertEqual('hello', proto.default_string)
1156    self.assertEqual(b'world', proto.default_bytes)
1157    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
1158    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
1159    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
1160                     proto.default_import_enum)
1161
1162    proto = unittest_pb2.TestExtremeDefaultValues()
1163    self.assertEqual(u'\u1234', proto.utf8_string)
1164
1165  def testHasFieldWithUnknownFieldName(self):
1166    proto = unittest_pb2.TestAllTypes()
1167    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
1168
1169  def testClearRemovesChildren(self):
1170    # Make sure there aren't any implementation bugs that are only partially
1171    # clearing the message (which can happen in the more complex C++
1172    # implementation which has parallel message lists).
1173    proto = unittest_pb2.TestRequiredForeign()
1174    for i in range(10):
1175      proto.repeated_message.add()
1176    proto2 = unittest_pb2.TestRequiredForeign()
1177    proto.CopyFrom(proto2)
1178    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
1179
1180  def testSingleScalarClearField(self):
1181    proto = unittest_pb2.TestAllTypes()
1182    # Should be allowed to clear something that's not there (a no-op).
1183    proto.ClearField('optional_int32')
1184    proto.optional_int32 = 1
1185    self.assertTrue(proto.HasField('optional_int32'))
1186    proto.ClearField('optional_int32')
1187    self.assertEqual(0, proto.optional_int32)
1188    self.assertFalse(proto.HasField('optional_int32'))
1189    # TODO: Test all other scalar field types.
1190
1191  def testRepeatedScalars(self):
1192    proto = unittest_pb2.TestAllTypes()
1193
1194    self.assertFalse(proto.repeated_int32)
1195    self.assertEqual(0, len(proto.repeated_int32))
1196    proto.repeated_int32.append(5)
1197    proto.repeated_int32.append(10)
1198    proto.repeated_int32.append(15)
1199    self.assertTrue(proto.repeated_int32)
1200    self.assertEqual(3, len(proto.repeated_int32))
1201
1202    self.assertEqual([5, 10, 15], proto.repeated_int32)
1203
1204    # Test single retrieval.
1205    self.assertEqual(5, proto.repeated_int32[0])
1206    self.assertEqual(15, proto.repeated_int32[-1])
1207    # Test out-of-bounds indices.
1208    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
1209    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
1210    # Test incorrect types passed to __getitem__.
1211    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
1212    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
1213
1214    # Test single assignment.
1215    proto.repeated_int32[1] = 20
1216    self.assertEqual([5, 20, 15], proto.repeated_int32)
1217
1218    # Test insertion.
1219    proto.repeated_int32.insert(1, 25)
1220    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
1221
1222    # Test slice retrieval.
1223    proto.repeated_int32.append(30)
1224    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
1225    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
1226
1227    # Test slice assignment with an iterator
1228    proto.repeated_int32[1:4] = (i for i in range(3))
1229    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
1230
1231    # Test slice assignment.
1232    proto.repeated_int32[1:4] = [35, 40, 45]
1233    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
1234
1235    # Test that we can use the field as an iterator.
1236    result = []
1237    for i in proto.repeated_int32:
1238      result.append(i)
1239    self.assertEqual([5, 35, 40, 45, 30], result)
1240
1241    # Test single deletion.
1242    del proto.repeated_int32[2]
1243    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
1244
1245    # Test slice deletion.
1246    del proto.repeated_int32[2:]
1247    self.assertEqual([5, 35], proto.repeated_int32)
1248
1249    # Test extending.
1250    proto.repeated_int32.extend([3, 13])
1251    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
1252
1253    # Test clearing.
1254    proto.ClearField('repeated_int32')
1255    self.assertFalse(proto.repeated_int32)
1256    self.assertEqual(0, len(proto.repeated_int32))
1257
1258    proto.repeated_int32.append(1)
1259    self.assertEqual(1, proto.repeated_int32[-1])
1260    # Test assignment to a negative index.
1261    proto.repeated_int32[-1] = 2
1262    self.assertEqual(2, proto.repeated_int32[-1])
1263
1264    # Test deletion at negative indices.
1265    proto.repeated_int32[:] = [0, 1, 2, 3]
1266    del proto.repeated_int32[-1]
1267    self.assertEqual([0, 1, 2], proto.repeated_int32)
1268
1269    del proto.repeated_int32[-2]
1270    self.assertEqual([0, 2], proto.repeated_int32)
1271
1272    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
1273    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
1274
1275    del proto.repeated_int32[-2:-1]
1276    self.assertEqual([2], proto.repeated_int32)
1277
1278    del proto.repeated_int32[100:10000]
1279    self.assertEqual([2], proto.repeated_int32)
1280
1281  def testRepeatedScalarsRemove(self):
1282    proto = unittest_pb2.TestAllTypes()
1283
1284    self.assertFalse(proto.repeated_int32)
1285    self.assertEqual(0, len(proto.repeated_int32))
1286    proto.repeated_int32.append(5)
1287    proto.repeated_int32.append(10)
1288    proto.repeated_int32.append(5)
1289    proto.repeated_int32.append(5)
1290
1291    self.assertEqual(4, len(proto.repeated_int32))
1292    proto.repeated_int32.remove(5)
1293    self.assertEqual(3, len(proto.repeated_int32))
1294    self.assertEqual(10, proto.repeated_int32[0])
1295    self.assertEqual(5, proto.repeated_int32[1])
1296    self.assertEqual(5, proto.repeated_int32[2])
1297
1298    proto.repeated_int32.remove(5)
1299    self.assertEqual(2, len(proto.repeated_int32))
1300    self.assertEqual(10, proto.repeated_int32[0])
1301    self.assertEqual(5, proto.repeated_int32[1])
1302
1303    proto.repeated_int32.remove(10)
1304    self.assertEqual(1, len(proto.repeated_int32))
1305    self.assertEqual(5, proto.repeated_int32[0])
1306
1307    # Remove a non-existent element.
1308    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
1309
1310  def testRepeatedScalarsReverse_Empty(self):
1311    proto = unittest_pb2.TestAllTypes()
1312
1313    self.assertFalse(proto.repeated_int32)
1314    self.assertEqual(0, len(proto.repeated_int32))
1315
1316    self.assertIsNone(proto.repeated_int32.reverse())
1317
1318    self.assertFalse(proto.repeated_int32)
1319    self.assertEqual(0, len(proto.repeated_int32))
1320
1321  def testRepeatedScalarsReverse_NonEmpty(self):
1322    proto = unittest_pb2.TestAllTypes()
1323
1324    self.assertFalse(proto.repeated_int32)
1325    self.assertEqual(0, len(proto.repeated_int32))
1326
1327    proto.repeated_int32.append(1)
1328    proto.repeated_int32.append(2)
1329    proto.repeated_int32.append(3)
1330    proto.repeated_int32.append(4)
1331
1332    self.assertEqual(4, len(proto.repeated_int32))
1333
1334    self.assertIsNone(proto.repeated_int32.reverse())
1335
1336    self.assertEqual(4, len(proto.repeated_int32))
1337    self.assertEqual(4, proto.repeated_int32[0])
1338    self.assertEqual(3, proto.repeated_int32[1])
1339    self.assertEqual(2, proto.repeated_int32[2])
1340    self.assertEqual(1, proto.repeated_int32[3])
1341
1342  def testRepeatedComposites(self):
1343    proto = unittest_pb2.TestAllTypes()
1344    self.assertFalse(proto.repeated_nested_message)
1345    self.assertEqual(0, len(proto.repeated_nested_message))
1346    m0 = proto.repeated_nested_message.add()
1347    m1 = proto.repeated_nested_message.add()
1348    self.assertTrue(proto.repeated_nested_message)
1349    self.assertEqual(2, len(proto.repeated_nested_message))
1350    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1351    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
1352
1353    # Test out-of-bounds indices.
1354    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1355                      1234)
1356    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1357                      -1234)
1358
1359    # Test incorrect types passed to __getitem__.
1360    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1361                      'foo')
1362    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1363                      None)
1364
1365    # Test slice retrieval.
1366    m2 = proto.repeated_nested_message.add()
1367    m3 = proto.repeated_nested_message.add()
1368    m4 = proto.repeated_nested_message.add()
1369    self.assertListsEqual(
1370        [m1, m2, m3], proto.repeated_nested_message[1:4])
1371    self.assertListsEqual(
1372        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
1373    self.assertListsEqual(
1374        [m0, m1], proto.repeated_nested_message[:2])
1375    self.assertListsEqual(
1376        [m2, m3, m4], proto.repeated_nested_message[2:])
1377    self.assertEqual(
1378        m0, proto.repeated_nested_message[0])
1379    self.assertListsEqual(
1380        [m0], proto.repeated_nested_message[:1])
1381
1382    # Test that we can use the field as an iterator.
1383    result = []
1384    for i in proto.repeated_nested_message:
1385      result.append(i)
1386    self.assertListsEqual([m0, m1, m2, m3, m4], result)
1387
1388    # Test single deletion.
1389    del proto.repeated_nested_message[2]
1390    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
1391
1392    # Test slice deletion.
1393    del proto.repeated_nested_message[2:]
1394    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1395
1396    # Test extending.
1397    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
1398    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
1399    proto.repeated_nested_message.extend([n1,n2])
1400    self.assertEqual(4, len(proto.repeated_nested_message))
1401    self.assertEqual(n1, proto.repeated_nested_message[2])
1402    self.assertEqual(n2, proto.repeated_nested_message[3])
1403    self.assertRaises(TypeError,
1404                      proto.repeated_nested_message.extend, n1)
1405    self.assertRaises(TypeError,
1406                      proto.repeated_nested_message.extend, [0])
1407    wrong_message_type = unittest_pb2.TestAllTypes()
1408    self.assertRaises(TypeError,
1409                      proto.repeated_nested_message.extend,
1410                      [wrong_message_type])
1411
1412    # Test clearing.
1413    proto.ClearField('repeated_nested_message')
1414    self.assertFalse(proto.repeated_nested_message)
1415    self.assertEqual(0, len(proto.repeated_nested_message))
1416
1417    # Test constructing an element while adding it.
1418    proto.repeated_nested_message.add(bb=23)
1419    self.assertEqual(1, len(proto.repeated_nested_message))
1420    self.assertEqual(23, proto.repeated_nested_message[0].bb)
1421    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
1422    with self.assertRaises(Exception):
1423      proto.repeated_nested_message[0] = 23
1424
1425  def testRepeatedCompositeRemove(self):
1426    proto = unittest_pb2.TestAllTypes()
1427
1428    self.assertEqual(0, len(proto.repeated_nested_message))
1429    m0 = proto.repeated_nested_message.add()
1430    # Need to set some differentiating variable so m0 != m1 != m2:
1431    m0.bb = len(proto.repeated_nested_message)
1432    m1 = proto.repeated_nested_message.add()
1433    m1.bb = len(proto.repeated_nested_message)
1434    self.assertTrue(m0 != m1)
1435    m2 = proto.repeated_nested_message.add()
1436    m2.bb = len(proto.repeated_nested_message)
1437    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1438
1439    self.assertEqual(3, len(proto.repeated_nested_message))
1440    proto.repeated_nested_message.remove(m0)
1441    self.assertEqual(2, len(proto.repeated_nested_message))
1442    self.assertEqual(m1, proto.repeated_nested_message[0])
1443    self.assertEqual(m2, proto.repeated_nested_message[1])
1444
1445    # Removing m0 again or removing None should raise error
1446    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
1447    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
1448    self.assertEqual(2, len(proto.repeated_nested_message))
1449
1450    proto.repeated_nested_message.remove(m2)
1451    self.assertEqual(1, len(proto.repeated_nested_message))
1452    self.assertEqual(m1, proto.repeated_nested_message[0])
1453
1454  def testRepeatedCompositeReverse_Empty(self):
1455    proto = unittest_pb2.TestAllTypes()
1456
1457    self.assertFalse(proto.repeated_nested_message)
1458    self.assertEqual(0, len(proto.repeated_nested_message))
1459
1460    self.assertIsNone(proto.repeated_nested_message.reverse())
1461
1462    self.assertFalse(proto.repeated_nested_message)
1463    self.assertEqual(0, len(proto.repeated_nested_message))
1464
1465  def testRepeatedCompositeReverse_NonEmpty(self):
1466    proto = unittest_pb2.TestAllTypes()
1467
1468    self.assertFalse(proto.repeated_nested_message)
1469    self.assertEqual(0, len(proto.repeated_nested_message))
1470
1471    m0 = proto.repeated_nested_message.add()
1472    m0.bb = len(proto.repeated_nested_message)
1473    m1 = proto.repeated_nested_message.add()
1474    m1.bb = len(proto.repeated_nested_message)
1475    m2 = proto.repeated_nested_message.add()
1476    m2.bb = len(proto.repeated_nested_message)
1477    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1478
1479    self.assertIsNone(proto.repeated_nested_message.reverse())
1480
1481    self.assertListsEqual([m2, m1, m0], proto.repeated_nested_message)
1482
1483  def testHandWrittenReflection(self):
1484    # Hand written extensions are only supported by the pure-Python
1485    # implementation of the API.
1486    if api_implementation.Type() != 'python':
1487      return
1488
1489    file = descriptor.FileDescriptor(name='foo.proto', package='')
1490    FieldDescriptor = descriptor.FieldDescriptor
1491    foo_field_descriptor = FieldDescriptor(
1492        name='foo_field', full_name='MyProto.foo_field',
1493        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1494        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1495        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1496        containing_type=None, message_type=None, enum_type=None,
1497        is_extension=False, extension_scope=None,
1498        options=descriptor_pb2.FieldOptions(), file=file,
1499        # pylint: disable=protected-access
1500        create_key=descriptor._internal_create_key)
1501    mydescriptor = descriptor.Descriptor(
1502        name='MyProto', full_name='MyProto', filename='ignored',
1503        containing_type=None, nested_types=[], enum_types=[],
1504        fields=[foo_field_descriptor], extensions=[],
1505        options=descriptor_pb2.MessageOptions(),
1506        file=file,
1507        # pylint: disable=protected-access
1508        create_key=descriptor._internal_create_key)
1509
1510    class MyProtoClass(
1511        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
1512      DESCRIPTOR = mydescriptor
1513    myproto_instance = MyProtoClass()
1514    self.assertEqual(0, myproto_instance.foo_field)
1515    self.assertFalse(myproto_instance.HasField('foo_field'))
1516    myproto_instance.foo_field = 23
1517    self.assertEqual(23, myproto_instance.foo_field)
1518    self.assertTrue(myproto_instance.HasField('foo_field'))
1519
1520  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
1521  def testDescriptorProtoSupport(self):
1522    # Hand written descriptors/reflection are only supported by the pure-Python
1523    # implementation of the API.
1524    if api_implementation.Type() != 'python':
1525      return
1526
1527    def AddDescriptorField(proto, field_name, field_type):
1528      AddDescriptorField.field_index += 1
1529      new_field = proto.field.add()
1530      new_field.name = field_name
1531      new_field.type = field_type
1532      new_field.number = AddDescriptorField.field_index
1533      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1534
1535    AddDescriptorField.field_index = 0
1536
1537    desc_proto = descriptor_pb2.DescriptorProto()
1538    desc_proto.name = 'Car'
1539    fdp = descriptor_pb2.FieldDescriptorProto
1540    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1541    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1542    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1543    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1544    # Add a repeated field
1545    AddDescriptorField.field_index += 1
1546    new_field = desc_proto.field.add()
1547    new_field.name = 'owners'
1548    new_field.type = fdp.TYPE_STRING
1549    new_field.number = AddDescriptorField.field_index
1550    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1551
1552    desc = descriptor.MakeDescriptor(desc_proto)
1553    self.assertTrue('name' in desc.fields_by_name)
1554    self.assertTrue('year' in desc.fields_by_name)
1555    self.assertTrue('automatic' in desc.fields_by_name)
1556    self.assertTrue('price' in desc.fields_by_name)
1557    self.assertTrue('owners' in desc.fields_by_name)
1558
1559    class CarMessage(
1560        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
1561      DESCRIPTOR = desc
1562
1563    prius = CarMessage()
1564    prius.name = 'prius'
1565    prius.year = 2010
1566    prius.automatic = True
1567    prius.price = 25134.75
1568    prius.owners.extend(['bob', 'susan'])
1569
1570    serialized_prius = prius.SerializeToString()
1571    new_prius = message_factory.GetMessageClass(desc)()
1572    new_prius.ParseFromString(serialized_prius)
1573    self.assertIsNot(new_prius, prius)
1574    self.assertEqual(prius, new_prius)
1575
1576    # these are unnecessary assuming message equality works as advertised but
1577    # explicitly check to be safe since we're mucking about in metaclass foo
1578    self.assertEqual(prius.name, new_prius.name)
1579    self.assertEqual(prius.year, new_prius.year)
1580    self.assertEqual(prius.automatic, new_prius.automatic)
1581    self.assertEqual(prius.price, new_prius.price)
1582    self.assertEqual(prius.owners, new_prius.owners)
1583
1584  def testExtensionDelete(self):
1585    extendee_proto = more_extensions_pb2.ExtendedMessage()
1586
1587    extension_int32 = more_extensions_pb2.optional_int_extension
1588    extendee_proto.Extensions[extension_int32] = 23
1589
1590    extension_repeated = more_extensions_pb2.repeated_int_extension
1591    extendee_proto.Extensions[extension_repeated].append(11)
1592
1593    extension_msg = more_extensions_pb2.optional_message_extension
1594    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1595
1596    self.assertEqual(len(extendee_proto.Extensions), 3)
1597    del extendee_proto.Extensions[extension_msg]
1598    self.assertEqual(len(extendee_proto.Extensions), 2)
1599    del extendee_proto.Extensions[extension_repeated]
1600    self.assertEqual(len(extendee_proto.Extensions), 1)
1601    # Delete a none exist extension. It is OK to "del m.Extensions[ext]"
1602    # even if the extension is not present in the message; we don't
1603    # raise KeyError. This is consistent with "m.Extensions[ext]"
1604    # returning a default value even if we did not set anything.
1605    del extendee_proto.Extensions[extension_repeated]
1606    self.assertEqual(len(extendee_proto.Extensions), 1)
1607    del extendee_proto.Extensions[extension_int32]
1608    self.assertEqual(len(extendee_proto.Extensions), 0)
1609
1610  def testExtensionIter(self):
1611    extendee_proto = more_extensions_pb2.ExtendedMessage()
1612
1613    extension_int32 = more_extensions_pb2.optional_int_extension
1614    extendee_proto.Extensions[extension_int32] = 23
1615
1616    extension_repeated = more_extensions_pb2.repeated_int_extension
1617    extendee_proto.Extensions[extension_repeated].append(11)
1618
1619    extension_msg = more_extensions_pb2.optional_message_extension
1620    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1621
1622    # Set some normal fields.
1623    extendee_proto.optional_int32 = 1
1624    extendee_proto.repeated_string.append('hi')
1625
1626    expected = (extension_int32, extension_msg, extension_repeated)
1627    count = 0
1628    for item in extendee_proto.Extensions:
1629      self.assertEqual(item.name, expected[count].name)
1630      self.assertIn(item, extendee_proto.Extensions)
1631      count += 1
1632    self.assertEqual(count, 3)
1633
1634  def testExtensionContainsError(self):
1635    extendee_proto = more_extensions_pb2.ExtendedMessage()
1636    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0)
1637
1638    field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[
1639        'optional_int32']
1640    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field)
1641
1642  def testTopLevelExtensionsForOptionalScalar(self):
1643    extendee_proto = unittest_pb2.TestAllExtensions()
1644    extension = unittest_pb2.optional_int32_extension
1645    self.assertFalse(extendee_proto.HasExtension(extension))
1646    self.assertNotIn(extension, extendee_proto.Extensions)
1647    self.assertEqual(0, extendee_proto.Extensions[extension])
1648    # As with normal scalar fields, just doing a read doesn't actually set the
1649    # "has" bit.
1650    self.assertFalse(extendee_proto.HasExtension(extension))
1651    self.assertNotIn(extension, extendee_proto.Extensions)
1652    # Actually set the thing.
1653    extendee_proto.Extensions[extension] = 23
1654    self.assertEqual(23, extendee_proto.Extensions[extension])
1655    self.assertTrue(extendee_proto.HasExtension(extension))
1656    self.assertIn(extension, extendee_proto.Extensions)
1657    # Ensure that clearing works as well.
1658    extendee_proto.ClearExtension(extension)
1659    self.assertEqual(0, extendee_proto.Extensions[extension])
1660    self.assertFalse(extendee_proto.HasExtension(extension))
1661    self.assertNotIn(extension, extendee_proto.Extensions)
1662
1663  def testTopLevelExtensionsForRepeatedScalar(self):
1664    extendee_proto = unittest_pb2.TestAllExtensions()
1665    extension = unittest_pb2.repeated_string_extension
1666    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1667    self.assertNotIn(extension, extendee_proto.Extensions)
1668    extendee_proto.Extensions[extension].append('foo')
1669    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1670    self.assertIn(extension, extendee_proto.Extensions)
1671    string_list = extendee_proto.Extensions[extension]
1672    extendee_proto.ClearExtension(extension)
1673    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1674    self.assertNotIn(extension, extendee_proto.Extensions)
1675    self.assertIsNot(string_list, extendee_proto.Extensions[extension])
1676    # Shouldn't be allowed to do Extensions[extension] = 'a'
1677    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1678                      extension, 'a')
1679
1680  def testTopLevelExtensionsForOptionalMessage(self):
1681    extendee_proto = unittest_pb2.TestAllExtensions()
1682    extension = unittest_pb2.optional_foreign_message_extension
1683    self.assertFalse(extendee_proto.HasExtension(extension))
1684    self.assertNotIn(extension, extendee_proto.Extensions)
1685    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1686    # As with normal (non-extension) fields, merely reading from the
1687    # thing shouldn't set the "has" bit.
1688    self.assertFalse(extendee_proto.HasExtension(extension))
1689    self.assertNotIn(extension, extendee_proto.Extensions)
1690    extendee_proto.Extensions[extension].c = 23
1691    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1692    self.assertTrue(extendee_proto.HasExtension(extension))
1693    self.assertIn(extension, extendee_proto.Extensions)
1694    # Save a reference here.
1695    foreign_message = extendee_proto.Extensions[extension]
1696    extendee_proto.ClearExtension(extension)
1697    self.assertIsNot(foreign_message, extendee_proto.Extensions[extension])
1698    # Setting a field on foreign_message now shouldn't set
1699    # any "has" bits on extendee_proto.
1700    foreign_message.c = 42
1701    self.assertEqual(42, foreign_message.c)
1702    self.assertTrue(foreign_message.HasField('c'))
1703    self.assertFalse(extendee_proto.HasExtension(extension))
1704    self.assertNotIn(extension, extendee_proto.Extensions)
1705    # Shouldn't be allowed to do Extensions[extension] = 'a'
1706    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1707                      extension, 'a')
1708
1709  def testTopLevelExtensionsForRepeatedMessage(self):
1710    extendee_proto = unittest_pb2.TestAllExtensions()
1711    extension = unittest_pb2.repeatedgroup_extension
1712    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1713    group = extendee_proto.Extensions[extension].add()
1714    group.a = 23
1715    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1716    group.a = 42
1717    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1718    group_list = extendee_proto.Extensions[extension]
1719    extendee_proto.ClearExtension(extension)
1720    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1721    self.assertIsNot(group_list, extendee_proto.Extensions[extension])
1722    # Shouldn't be allowed to do Extensions[extension] = 'a'
1723    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1724                      extension, 'a')
1725
1726  def testNestedExtensions(self):
1727    extendee_proto = unittest_pb2.TestAllExtensions()
1728    extension = unittest_pb2.TestRequired.single
1729
1730    # We just test the non-repeated case.
1731    self.assertFalse(extendee_proto.HasExtension(extension))
1732    self.assertNotIn(extension, extendee_proto.Extensions)
1733    required = extendee_proto.Extensions[extension]
1734    self.assertEqual(0, required.a)
1735    self.assertFalse(extendee_proto.HasExtension(extension))
1736    self.assertNotIn(extension, extendee_proto.Extensions)
1737    required.a = 23
1738    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1739    self.assertTrue(extendee_proto.HasExtension(extension))
1740    self.assertIn(extension, extendee_proto.Extensions)
1741    extendee_proto.ClearExtension(extension)
1742    self.assertIsNot(required, extendee_proto.Extensions[extension])
1743    self.assertFalse(extendee_proto.HasExtension(extension))
1744    self.assertNotIn(extension, extendee_proto.Extensions)
1745
1746  def testRegisteredExtensions(self):
1747    pool = unittest_pb2.DESCRIPTOR.pool
1748    self.assertTrue(
1749        pool.FindExtensionByNumber(
1750            unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
1751    self.assertIs(
1752        pool.FindExtensionByName(
1753            'protobuf_unittest.optional_int32_extension').containing_type,
1754        unittest_pb2.TestAllExtensions.DESCRIPTOR)
1755    # Make sure extensions haven't been registered into types that shouldn't
1756    # have any.
1757    self.assertEqual(0, len(
1758        pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
1759
1760  # If message A directly contains message B, and
1761  # a.HasField('b') is currently False, then mutating any
1762  # extension in B should change a.HasField('b') to True
1763  # (and so on up the object tree).
1764  def testHasBitsForAncestorsOfExtendedMessage(self):
1765    # Optional scalar extension.
1766    toplevel = more_extensions_pb2.TopLevelMessage()
1767    self.assertFalse(toplevel.HasField('submessage'))
1768    self.assertEqual(0, toplevel.submessage.Extensions[
1769        more_extensions_pb2.optional_int_extension])
1770    self.assertFalse(toplevel.HasField('submessage'))
1771    toplevel.submessage.Extensions[
1772        more_extensions_pb2.optional_int_extension] = 23
1773    self.assertEqual(23, toplevel.submessage.Extensions[
1774        more_extensions_pb2.optional_int_extension])
1775    self.assertTrue(toplevel.HasField('submessage'))
1776
1777    # Repeated scalar extension.
1778    toplevel = more_extensions_pb2.TopLevelMessage()
1779    self.assertFalse(toplevel.HasField('submessage'))
1780    self.assertEqual([], toplevel.submessage.Extensions[
1781        more_extensions_pb2.repeated_int_extension])
1782    self.assertFalse(toplevel.HasField('submessage'))
1783    toplevel.submessage.Extensions[
1784        more_extensions_pb2.repeated_int_extension].append(23)
1785    self.assertEqual([23], toplevel.submessage.Extensions[
1786        more_extensions_pb2.repeated_int_extension])
1787    self.assertTrue(toplevel.HasField('submessage'))
1788
1789    # Optional message extension.
1790    toplevel = more_extensions_pb2.TopLevelMessage()
1791    self.assertFalse(toplevel.HasField('submessage'))
1792    self.assertEqual(0, toplevel.submessage.Extensions[
1793        more_extensions_pb2.optional_message_extension].foreign_message_int)
1794    self.assertFalse(toplevel.HasField('submessage'))
1795    toplevel.submessage.Extensions[
1796        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1797    self.assertEqual(23, toplevel.submessage.Extensions[
1798        more_extensions_pb2.optional_message_extension].foreign_message_int)
1799    self.assertTrue(toplevel.HasField('submessage'))
1800
1801    # Repeated message extension.
1802    toplevel = more_extensions_pb2.TopLevelMessage()
1803    self.assertFalse(toplevel.HasField('submessage'))
1804    self.assertEqual(0, len(toplevel.submessage.Extensions[
1805        more_extensions_pb2.repeated_message_extension]))
1806    self.assertFalse(toplevel.HasField('submessage'))
1807    foreign = toplevel.submessage.Extensions[
1808        more_extensions_pb2.repeated_message_extension].add()
1809    self.assertEqual(foreign, toplevel.submessage.Extensions[
1810        more_extensions_pb2.repeated_message_extension][0])
1811    self.assertTrue(toplevel.HasField('submessage'))
1812
1813  def testDisconnectionAfterClearingEmptyMessage(self):
1814    toplevel = more_extensions_pb2.TopLevelMessage()
1815    extendee_proto = toplevel.submessage
1816    extension = more_extensions_pb2.optional_message_extension
1817    extension_proto = extendee_proto.Extensions[extension]
1818    extendee_proto.ClearExtension(extension)
1819    extension_proto.foreign_message_int = 23
1820
1821    self.assertIsNot(extension_proto, extendee_proto.Extensions[extension])
1822
1823  def testExtensionFailureModes(self):
1824    extendee_proto = unittest_pb2.TestAllExtensions()
1825
1826    # Try non-extension-handle arguments to HasExtension,
1827    # ClearExtension(), and Extensions[]...
1828    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1829    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1830    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1831    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1832
1833    # Try something that *is* an extension handle, just not for
1834    # this message...
1835    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1836                           more_extensions_pb2.optional_message_extension,
1837                           more_extensions_pb2.repeated_int_extension,
1838                           more_extensions_pb2.repeated_message_extension):
1839      self.assertRaises(KeyError, extendee_proto.HasExtension,
1840                        unknown_handle)
1841      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1842                        unknown_handle)
1843      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1844                        unknown_handle)
1845      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1846                        unknown_handle, 5)
1847
1848    # Try call HasExtension() with a valid handle, but for a
1849    # *repeated* field.  (Just as with non-extension repeated
1850    # fields, Has*() isn't supported for extension repeated fields).
1851    self.assertRaises(KeyError, extendee_proto.HasExtension,
1852                      unittest_pb2.repeated_string_extension)
1853
1854  def testMergeFromOptionalGroup(self):
1855    # Test merge with an optional group.
1856    proto1 = unittest_pb2.TestAllTypes()
1857    proto1.optionalgroup.a = 12
1858    proto2 = unittest_pb2.TestAllTypes()
1859    proto2.MergeFrom(proto1)
1860    self.assertEqual(12, proto2.optionalgroup.a)
1861
1862  def testMergeFromExtensionsSingular(self):
1863    proto1 = unittest_pb2.TestAllExtensions()
1864    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1865
1866    proto2 = unittest_pb2.TestAllExtensions()
1867    proto2.MergeFrom(proto1)
1868    self.assertEqual(
1869        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1870
1871  def testMergeFromExtensionsRepeated(self):
1872    proto1 = unittest_pb2.TestAllExtensions()
1873    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1874    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1875
1876    proto2 = unittest_pb2.TestAllExtensions()
1877    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1878    proto2.MergeFrom(proto1)
1879    self.assertEqual(
1880        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1881    self.assertEqual(
1882        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1883    self.assertEqual(
1884        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1885    self.assertEqual(
1886        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1887
1888  def testMergeFromExtensionsNestedMessage(self):
1889    proto1 = unittest_pb2.TestAllExtensions()
1890    ext1 = proto1.Extensions[
1891        unittest_pb2.repeated_nested_message_extension]
1892    m = ext1.add()
1893    m.bb = 222
1894    m = ext1.add()
1895    m.bb = 333
1896
1897    proto2 = unittest_pb2.TestAllExtensions()
1898    ext2 = proto2.Extensions[
1899        unittest_pb2.repeated_nested_message_extension]
1900    m = ext2.add()
1901    m.bb = 111
1902
1903    proto2.MergeFrom(proto1)
1904    ext2 = proto2.Extensions[
1905        unittest_pb2.repeated_nested_message_extension]
1906    self.assertEqual(3, len(ext2))
1907    self.assertEqual(111, ext2[0].bb)
1908    self.assertEqual(222, ext2[1].bb)
1909    self.assertEqual(333, ext2[2].bb)
1910
1911  def testCopyFromBadType(self):
1912    # The python implementation doesn't raise an exception in this
1913    # case. In theory it should.
1914    if api_implementation.Type() == 'python':
1915      return
1916    proto1 = unittest_pb2.TestAllTypes()
1917    proto2 = unittest_pb2.TestAllExtensions()
1918    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1919
1920  def testClear(self):
1921    proto = unittest_pb2.TestAllTypes()
1922    # C++ implementation does not support lazy fields right now so leave it
1923    # out for now.
1924    if api_implementation.Type() == 'python':
1925      test_util.SetAllFields(proto)
1926    else:
1927      test_util.SetAllNonLazyFields(proto)
1928    # Clear the message.
1929    proto.Clear()
1930    self.assertEqual(proto.ByteSize(), 0)
1931    empty_proto = unittest_pb2.TestAllTypes()
1932    self.assertEqual(proto, empty_proto)
1933
1934    # Test if extensions which were set are cleared.
1935    proto = unittest_pb2.TestAllExtensions()
1936    test_util.SetAllExtensions(proto)
1937    # Clear the message.
1938    proto.Clear()
1939    self.assertEqual(proto.ByteSize(), 0)
1940    empty_proto = unittest_pb2.TestAllExtensions()
1941    self.assertEqual(proto, empty_proto)
1942
1943  def testDisconnectingInOneof(self):
1944    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
1945    m.foo_message.moo_int = 5
1946    sub_message = m.foo_message
1947    # Accessing another message's field does not clear the first one
1948    self.assertEqual(m.foo_lazy_message.moo_int, 0)
1949    self.assertEqual(m.foo_message.moo_int, 5)
1950    # But mutating another message in the oneof detaches the first one.
1951    m.foo_lazy_message.moo_int = 6
1952    self.assertEqual(m.foo_message.moo_int, 0)
1953    # The reference we got above was detached and is still valid.
1954    self.assertEqual(sub_message.moo_int, 5)
1955    sub_message.moo_int = 7
1956
1957  def assertInitialized(self, proto):
1958    self.assertTrue(proto.IsInitialized())
1959    # Neither method should raise an exception.
1960    proto.SerializeToString()
1961    proto.SerializePartialToString()
1962
1963  def assertNotInitialized(self, proto, error_size=None):
1964    errors = []
1965    self.assertFalse(proto.IsInitialized())
1966    self.assertFalse(proto.IsInitialized(errors))
1967    self.assertEqual(error_size, len(errors))
1968    self.assertRaises(message.EncodeError, proto.SerializeToString)
1969    # "Partial" serialization doesn't care if message is uninitialized.
1970    proto.SerializePartialToString()
1971
1972  def testIsInitialized(self):
1973    # Trivial cases - all optional fields and extensions.
1974    proto = unittest_pb2.TestAllTypes()
1975    self.assertInitialized(proto)
1976    proto = unittest_pb2.TestAllExtensions()
1977    self.assertInitialized(proto)
1978
1979    # The case of uninitialized required fields.
1980    proto = unittest_pb2.TestRequired()
1981    self.assertNotInitialized(proto, 3)
1982    proto.a = proto.b = proto.c = 2
1983    self.assertInitialized(proto)
1984
1985    # The case of uninitialized submessage.
1986    proto = unittest_pb2.TestRequiredForeign()
1987    self.assertInitialized(proto)
1988    proto.optional_message.a = 1
1989    self.assertNotInitialized(proto, 2)
1990    proto.optional_message.b = 0
1991    proto.optional_message.c = 0
1992    self.assertInitialized(proto)
1993
1994    # Uninitialized repeated submessage.
1995    message1 = proto.repeated_message.add()
1996    self.assertNotInitialized(proto, 3)
1997    message1.a = message1.b = message1.c = 0
1998    self.assertInitialized(proto)
1999
2000    # Uninitialized repeated group in an extension.
2001    proto = unittest_pb2.TestAllExtensions()
2002    extension = unittest_pb2.TestRequired.multi
2003    message1 = proto.Extensions[extension].add()
2004    message2 = proto.Extensions[extension].add()
2005    self.assertNotInitialized(proto, 6)
2006    message1.a = 1
2007    message1.b = 1
2008    message1.c = 1
2009    self.assertNotInitialized(proto, 3)
2010    message2.a = 2
2011    message2.b = 2
2012    message2.c = 2
2013    self.assertInitialized(proto)
2014
2015    # Uninitialized nonrepeated message in an extension.
2016    proto = unittest_pb2.TestAllExtensions()
2017    extension = unittest_pb2.TestRequired.single
2018    proto.Extensions[extension].a = 1
2019    self.assertNotInitialized(proto, 2)
2020    proto.Extensions[extension].b = 2
2021    proto.Extensions[extension].c = 3
2022    self.assertInitialized(proto)
2023
2024    # Try passing an errors list.
2025    errors = []
2026    proto = unittest_pb2.TestRequired()
2027    self.assertFalse(proto.IsInitialized(errors))
2028    self.assertEqual(errors, ['a', 'b', 'c'])
2029    self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3)
2030
2031  @unittest.skipIf(
2032      api_implementation.Type() == 'python',
2033      'Errors are only available from the most recent C++ implementation.')
2034  def testFileDescriptorErrors(self):
2035    file_name = 'test_file_descriptor_errors.proto'
2036    package_name = 'test_file_descriptor_errors.proto'
2037    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
2038    file_descriptor_proto.name = file_name
2039    file_descriptor_proto.package = package_name
2040    m1 = file_descriptor_proto.message_type.add()
2041    m1.name = 'msg1'
2042    # Compiles the proto into the C++ descriptor pool
2043    descriptor.FileDescriptor(
2044        file_name,
2045        package_name,
2046        serialized_pb=file_descriptor_proto.SerializeToString())
2047    # Add a FileDescriptorProto that has duplicate symbols
2048    another_file_name = 'another_test_file_descriptor_errors.proto'
2049    file_descriptor_proto.name = another_file_name
2050    m2 = file_descriptor_proto.message_type.add()
2051    m2.name = 'msg2'
2052    with self.assertRaises(TypeError) as cm:
2053      descriptor.FileDescriptor(
2054          another_file_name,
2055          package_name,
2056          serialized_pb=file_descriptor_proto.SerializeToString())
2057      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
2058                      getattr(cm.expected, '__name__', cm.expected))
2059      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
2060      # Error message will say something about this definition being a
2061      # duplicate, though we don't check the message exactly to avoid a
2062      # dependency on the C++ logging code.
2063      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
2064
2065  def testDescriptorProtoHasFileOptions(self):
2066    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
2067    self.assertEqual(
2068        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
2069        'com.google.protobuf',
2070    )
2071
2072  def testDescriptorProtoHasFieldOptions(self):
2073    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
2074    self.assertEqual(
2075        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
2076        'com.google.protobuf',
2077    )
2078    packed_desc = (
2079        descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get(
2080            'Location'
2081        ).fields_by_name.get('path')
2082    )
2083    self.assertTrue(packed_desc.has_options)
2084    self.assertTrue(packed_desc.GetOptions().packed)
2085
2086  def testDescriptorProtoHasFeatureOptions(self):
2087    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
2088    self.assertEqual(
2089        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
2090        'com.google.protobuf',
2091    )
2092    presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get(
2093        'field_presence'
2094    )
2095    self.assertTrue(presence_desc.has_options)
2096    self.assertEqual(
2097        presence_desc.GetOptions().retention,
2098        descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME,
2099    )
2100    self.assertListsEqual(
2101        presence_desc.GetOptions().targets,
2102        [
2103            descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD,
2104            descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE,
2105        ],
2106    )
2107
2108  def testStringUTF8Serialization(self):
2109    proto = message_set_extensions_pb2.TestMessageSet()
2110    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
2111    extension = extension_message.message_set_extension
2112
2113    test_utf8 = u'Тест'
2114    test_utf8_bytes = test_utf8.encode('utf-8')
2115
2116    # 'Test' in another language, using UTF-8 charset.
2117    proto.Extensions[extension].str = test_utf8
2118
2119    # Serialize using the MessageSet wire format (this is specified in the
2120    # .proto file).
2121    serialized = proto.SerializeToString()
2122
2123    # Check byte size.
2124    self.assertEqual(proto.ByteSize(), len(serialized))
2125
2126    raw = unittest_mset_pb2.RawMessageSet()
2127    bytes_read = raw.MergeFromString(serialized)
2128    self.assertEqual(len(serialized), bytes_read)
2129
2130    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2131
2132    self.assertEqual(1, len(raw.item))
2133    # Check that the type_id is the same as the tag ID in the .proto file.
2134    self.assertEqual(raw.item[0].type_id, 98418634)
2135
2136    # Check the actual bytes on the wire.
2137    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
2138    bytes_read = message2.MergeFromString(raw.item[0].message)
2139    self.assertEqual(len(raw.item[0].message), bytes_read)
2140
2141    self.assertEqual(type(message2.str), str)
2142    self.assertEqual(message2.str, test_utf8)
2143
2144    # The pure Python API throws an exception on MergeFromString(),
2145    # if any of the string fields of the message can't be UTF-8 decoded.
2146    # The C++ implementation of the API has no way to check that on
2147    # MergeFromString and thus has no way to throw the exception.
2148    #
2149    # The pure Python API always returns objects of type 'unicode' (UTF-8
2150    # encoded), or 'bytes' (in 7 bit ASCII).
2151    badbytes = raw.item[0].message.replace(
2152        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
2153
2154    unicode_decode_failed = False
2155    try:
2156      message2.MergeFromString(badbytes)
2157    except UnicodeDecodeError:
2158      unicode_decode_failed = True
2159    string_field = message2.str
2160    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
2161
2162  def testSetInParent(self):
2163    proto = unittest_pb2.TestAllTypes()
2164    self.assertFalse(proto.HasField('optionalgroup'))
2165    proto.optionalgroup.SetInParent()
2166    self.assertTrue(proto.HasField('optionalgroup'))
2167
2168  def testPackageInitializationImport(self):
2169    """Test that we can import nested messages from their __init__.py.
2170
2171    Such setup is not trivial since at the time of processing of __init__.py one
2172    can't refer to its submodules by name in code, so expressions like
2173    google.protobuf.internal.import_test_package.inner_pb2
2174    don't work. They do work in imports, so we have assign an alias at import
2175    and then use that alias in generated code.
2176    """
2177    # We import here since it's the import that used to fail, and we want
2178    # the failure to have the right context.
2179    # pylint: disable=g-import-not-at-top
2180    from google.protobuf.internal import import_test_package
2181    # pylint: enable=g-import-not-at-top
2182    msg = import_test_package.myproto.Outer()
2183    # Just check the default value.
2184    self.assertEqual(57, msg.inner.value)
2185
2186#  Since we had so many tests for protocol buffer equality, we broke these out
2187#  into separate TestCase classes.
2188
2189
@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing checks on default-constructed TestAllTypes."""

  def setUp(self):
    self.first_proto, self.second_proto = (
        unittest_pb2.TestAllTypes(), unittest_pb2.TestAllTypes())

  def testNotHashable(self):
    # Messages are expected to be unhashable.
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testSelfEquality(self):
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)
2205
2206
@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):
  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    populated = []
    for _ in range(2):
      proto = unittest_pb2.TestAllTypes()
      test_util.SetAllFields(proto)
      populated.append(proto)
    self.first_proto, self.second_proto = populated

  def testNotHashable(self):
    with self.assertRaises(TypeError):
      hash(self.first_proto)

  def testNoneNotEqual(self):
    # Compare against None with the message on either side of the operator.
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    other = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, other)
    self.assertNotEqual(other, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Changing a singular scalar value breaks equality...
    first.optional_int32 += 1
    self.assertNotEqual(first, second)
    # ...and so does clearing it.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    nested = first.optional_nested_message
    # Modify a field inside a singular sub-message, then undo the change.
    nested.bb += 1
    self.assertNotEqual(first, second)
    nested.bb -= 1
    self.assertEqual(first, second)
    # Clear a field of the sub-message, then restore it from the other proto.
    nested.ClearField('bb')
    self.assertNotEqual(first, second)
    nested.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # Dropping the whole sub-message also breaks equality.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Appending to a repeated scalar breaks equality...
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    # ...and so does emptying it.
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    element = first.repeated_nested_message[0]
    # Modify one element in place, then undo the change.
    element.bb += 1
    self.assertNotEqual(first, second)
    element.bb -= 1
    self.assertEqual(first, second)
    # An extra element on only one side breaks equality; adding a matching
    # element on the other side restores it.
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    # Presence takes part in equality for singular scalars: an explicitly
    # set 0 is not equal to an unset field.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Presence takes part in equality for singular sub-messages too: a
    # present-but-empty sub-message is not equal to an absent one.
    first, second = self.first_proto, self.second_proto
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # Setting and then clearing a child field leaves the now-empty
    # sub-message present, which matches `second`.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
2292
2293
@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality checks for messages that carry extensions."""

  def testExtensionEquality(self):
    ext = unittest_pb2.optional_int32_extension
    first = unittest_pb2.TestAllExtensions()
    second = unittest_pb2.TestAllExtensions()

    # Empty messages compare equal. Populating one side breaks equality until
    # the other side is populated the same way.
    self.assertEqual(first, second)
    test_util.SetAllExtensions(first)
    self.assertNotEqual(first, second)
    test_util.SetAllExtensions(second)
    self.assertEqual(first, second)

    # Extension values are compared.
    first.Extensions[ext] += 1
    self.assertNotEqual(first, second)
    first.Extensions[ext] -= 1
    self.assertEqual(first, second)

    # Extension presence is compared as well: unset != explicitly set to 0.
    first.ClearExtension(ext)
    second.Extensions[ext] = 0
    self.assertNotEqual(first, second)
    first.Extensions[ext] = 0
    self.assertEqual(first, second)

    # Reading an unset extension (which yields its default) on one side only
    # does not affect equality when neither side has it set.
    first = unittest_pb2.TestAllExtensions()
    second = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, first.Extensions[ext])
    self.assertEqual(first, second)
2326
2327
@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality through mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    lhs, rhs = (unittest_pb2.TestMutualRecursionA(),
                unittest_pb2.TestMutualRecursionA())
    self.assertEqual(lhs, rhs)
    # A difference several sub-message levels deep must be detected, and
    # mirroring it on the other side restores equality.
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
2339
2340
@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Returns the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64=1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one byte for the field tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value that fits in a k-byte varint.
    for num_bytes, bits in enumerate(range(7, 63, 7), start=1):
      Test((1 << bits) - 1, num_bytes)
    # Negative int64 values always take the maximum 10 varint bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    # Each case starts from a fresh message so that only one field is set.
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())

  def testComposites(self):
    # bb = 1 << 14 needs a 3-byte varint.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # a = 1 << 21 needs a 4-byte varint.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_message_1 = self.proto.repeated_nested_message.add()
    nested_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    # The reference is kept (though otherwise unused) so the element is still
    # referenced from Python when it is deleted below.
    nested_message_0 = self.proto.repeated_nested_message.add()  # pylint: disable=unused-variable
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_message_1 = self.proto.repeated_nested_message.add()
    nested_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
    # Deep copy of the two-element container, used at the end to test
    # deletion on a copied container.
    repeated_nested_message = copy.deepcopy(
        self.proto.repeated_nested_message)

    # Deleting the empty message leaves only the bb=9 one:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    nested_message_2 = self.proto.repeated_nested_message.add()
    nested_message_2.bb = 12

    # Two messages, each:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Deleting the second message leaves a single one:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

    self.assertEqual(2, len(repeated_nested_message))
    del repeated_nested_message[0:1]
    # TODO: Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(1, len(repeated_nested_message))
    del repeated_nested_message[-1]
    # TODO: Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(0, len(repeated_nested_message))

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group_1 = self.proto.repeatedgroup.add()
    group_1.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())
    # A regular (non-extension) field descriptor is rejected as a key.
    field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32']
    with self.assertRaises(KeyError):
      proto.Extensions[field] = 23

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    if api_implementation.Type() == 'python':
      # This is only possible in pure-Python implementation of the API.
      child = self.proto.optional_foreign_message
      self.proto.ClearField('optional_foreign_message')
      child.c = 128
      self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testPackedRepeatedScalars(self):
    self.assertEqual(0, self.packed_proto.ByteSize())

    self.packed_proto.packed_int32.append(10)   # 1 byte.
    self.packed_proto.packed_int32.append(128)  # 2 bytes.
    # The tag is 2 bytes (the field number is 90), and the varint
    # storing the length is 1 byte.
    int_size = 1 + 2 + 3
    self.assertEqual(int_size, self.packed_proto.ByteSize())

    self.packed_proto.packed_double.append(4.2)   # 8 bytes
    self.packed_proto.packed_double.append(3.25)  # 8 bytes
    # 2 more tag bytes, 1 more length byte.
    double_size = 8 + 8 + 3
    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())

    self.packed_proto.ClearField('packed_int32')
    self.assertEqual(double_size, self.packed_proto.ByteSize())

  def testPackedExtensions(self):
    self.assertEqual(0, self.packed_extended_proto.ByteSize())
    values = self.packed_extended_proto.Extensions[
        unittest_pb2.packed_fixed32_extension]
    values.extend([1, 2, 3, 4])   # 16 bytes
    # Plus 3 bytes of overhead: a 2-byte tag and a 1-byte length prefix.
    self.assertEqual(19, self.packed_extended_proto.ByteSize())
2641
2642
2643# Issues to be sure to cover include:
2644#   * Handling of unrecognized tags ("uninterpreted_bytes").
2645#   * Handling of MessageSets.
2646#   * Consistent ordering of tags in the wire format,
2647#     including ordering between extensions and non-extension
2648#     fields.
2649#   * Consistent serialization of negative numbers, especially
2650#     negative int32s.
2651#   * Handling of empty submessages (with and without "has"
2652#     bits set).
2653
2654@testing_refleaks.TestCase
2655class SerializationTest(unittest.TestCase):
2656
2657  def testSerializeEmptyMessage(self):
2658    first_proto = unittest_pb2.TestAllTypes()
2659    second_proto = unittest_pb2.TestAllTypes()
2660    serialized = first_proto.SerializeToString()
2661    self.assertEqual(first_proto.ByteSize(), len(serialized))
2662    self.assertEqual(
2663        len(serialized),
2664        second_proto.MergeFromString(serialized))
2665    self.assertEqual(first_proto, second_proto)
2666
2667  def testSerializeAllFields(self):
2668    first_proto = unittest_pb2.TestAllTypes()
2669    second_proto = unittest_pb2.TestAllTypes()
2670    test_util.SetAllFields(first_proto)
2671    serialized = first_proto.SerializeToString()
2672    self.assertEqual(first_proto.ByteSize(), len(serialized))
2673    self.assertEqual(
2674        len(serialized),
2675        second_proto.MergeFromString(serialized))
2676    self.assertEqual(first_proto, second_proto)
2677
2678  def testSerializeAllExtensions(self):
2679    first_proto = unittest_pb2.TestAllExtensions()
2680    second_proto = unittest_pb2.TestAllExtensions()
2681    test_util.SetAllExtensions(first_proto)
2682    serialized = first_proto.SerializeToString()
2683    self.assertEqual(
2684        len(serialized),
2685        second_proto.MergeFromString(serialized))
2686    self.assertEqual(first_proto, second_proto)
2687
2688  def testSerializeWithOptionalGroup(self):
2689    first_proto = unittest_pb2.TestAllTypes()
2690    second_proto = unittest_pb2.TestAllTypes()
2691    first_proto.optionalgroup.a = 242
2692    serialized = first_proto.SerializeToString()
2693    self.assertEqual(
2694        len(serialized),
2695        second_proto.MergeFromString(serialized))
2696    self.assertEqual(first_proto, second_proto)
2697
2698  def testSerializeNegativeValues(self):
2699    first_proto = unittest_pb2.TestAllTypes()
2700
2701    first_proto.optional_int32 = -1
2702    first_proto.optional_int64 = -(2 << 40)
2703    first_proto.optional_sint32 = -3
2704    first_proto.optional_sint64 = -(4 << 40)
2705    first_proto.optional_sfixed32 = -5
2706    first_proto.optional_sfixed64 = -(6 << 40)
2707
2708    second_proto = unittest_pb2.TestAllTypes.FromString(
2709        first_proto.SerializeToString())
2710
2711    self.assertEqual(first_proto, second_proto)
2712
2713  def testParseTruncated(self):
2714    # This test is only applicable for the Python implementation of the API.
2715    if api_implementation.Type() != 'python':
2716      return
2717
2718    first_proto = unittest_pb2.TestAllTypes()
2719    test_util.SetAllFields(first_proto)
2720    serialized = memoryview(first_proto.SerializeToString())
2721
2722    for truncation_point in range(len(serialized) + 1):
2723      try:
2724        second_proto = unittest_pb2.TestAllTypes()
2725        unknown_fields = unittest_pb2.TestEmptyMessage()
2726        pos = second_proto._InternalParse(serialized, 0, truncation_point)
2727        # If we didn't raise an error then we read exactly the amount expected.
2728        self.assertEqual(truncation_point, pos)
2729
2730        # Parsing to unknown fields should not throw if parsing to known fields
2731        # did not.
2732        try:
2733          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
2734          self.assertEqual(truncation_point, pos2)
2735        except message.DecodeError:
2736          self.fail('Parsing unknown fields failed when parsing known fields '
2737                    'did not.')
2738      except message.DecodeError:
2739        # Parsing unknown fields should also fail.
2740        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
2741                          serialized, 0, truncation_point)
2742
2743  def testCanonicalSerializationOrder(self):
2744    proto = more_messages_pb2.OutOfOrderFields()
2745    # These are also their tag numbers.  Even though we're setting these in
2746    # reverse-tag order AND they're listed in reverse tag-order in the .proto
2747    # file, they should nonetheless be serialized in tag order.
2748    proto.optional_sint32 = 5
2749    proto.Extensions[more_messages_pb2.optional_uint64] = 4
2750    proto.optional_uint32 = 3
2751    proto.Extensions[more_messages_pb2.optional_int64] = 2
2752    proto.optional_int32 = 1
2753    serialized = proto.SerializeToString()
2754    self.assertEqual(proto.ByteSize(), len(serialized))
2755    d = _MiniDecoder(serialized)
2756    ReadTag = d.ReadFieldNumberAndWireType
2757    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
2758    self.assertEqual(1, d.ReadInt32())
2759    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
2760    self.assertEqual(2, d.ReadInt64())
2761    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
2762    self.assertEqual(3, d.ReadUInt32())
2763    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
2764    self.assertEqual(4, d.ReadUInt64())
2765    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
2766    self.assertEqual(5, d.ReadSInt32())
2767
2768  def testCanonicalSerializationOrderSameAsCpp(self):
2769    # Copy of the same test we use for C++.
2770    proto = unittest_pb2.TestFieldOrderings()
2771    test_util.SetAllFieldsAndExtensions(proto)
2772    serialized = proto.SerializeToString()
2773    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
2774
2775  def testMergeFromStringWhenFieldsAlreadySet(self):
2776    first_proto = unittest_pb2.TestAllTypes()
2777    first_proto.repeated_string.append('foobar')
2778    first_proto.optional_int32 = 23
2779    first_proto.optional_nested_message.bb = 42
2780    serialized = first_proto.SerializeToString()
2781
2782    second_proto = unittest_pb2.TestAllTypes()
2783    second_proto.repeated_string.append('baz')
2784    second_proto.optional_int32 = 100
2785    second_proto.optional_nested_message.bb = 999
2786
2787    bytes_parsed = second_proto.MergeFromString(serialized)
2788    self.assertEqual(len(serialized), bytes_parsed)
2789
2790    # Ensure that we append to repeated fields.
2791    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
2792    # Ensure that we overwrite nonrepeatd scalars.
2793    self.assertEqual(23, second_proto.optional_int32)
2794    # Ensure that we recursively call MergeFromString() on
2795    # submessages.
2796    self.assertEqual(42, second_proto.optional_nested_message.bb)
2797
2798  def testMessageSetWireFormat(self):
2799    proto = message_set_extensions_pb2.TestMessageSet()
2800    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2801    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
2802    extension1 = extension_message1.message_set_extension
2803    extension2 = extension_message2.message_set_extension
2804    extension3 = message_set_extensions_pb2.message_set_extension3
2805    proto.Extensions[extension1].i = 123
2806    proto.Extensions[extension2].str = 'foo'
2807    proto.Extensions[extension3].text = 'bar'
2808
2809    # Serialize using the MessageSet wire format (this is specified in the
2810    # .proto file).
2811    serialized = proto.SerializeToString()
2812
2813    raw = unittest_mset_pb2.RawMessageSet()
2814    self.assertEqual(False,
2815                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
2816    self.assertEqual(
2817        len(serialized),
2818        raw.MergeFromString(serialized))
2819    self.assertEqual(3, len(raw.item))
2820
2821    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2822    self.assertEqual(
2823        len(raw.item[0].message),
2824        message1.MergeFromString(raw.item[0].message))
2825    self.assertEqual(123, message1.i)
2826
2827    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2828    self.assertEqual(
2829        len(raw.item[1].message),
2830        message2.MergeFromString(raw.item[1].message))
2831    self.assertEqual('foo', message2.str)
2832
2833    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
2834    self.assertEqual(
2835        len(raw.item[2].message),
2836        message3.MergeFromString(raw.item[2].message))
2837    self.assertEqual('bar', message3.text)
2838
2839    # Deserialize using the MessageSet wire format.
2840    proto2 = message_set_extensions_pb2.TestMessageSet()
2841    self.assertEqual(
2842        len(serialized),
2843        proto2.MergeFromString(serialized))
2844    self.assertEqual(123, proto2.Extensions[extension1].i)
2845    self.assertEqual('foo', proto2.Extensions[extension2].str)
2846    self.assertEqual('bar', proto2.Extensions[extension3].text)
2847
2848    # Check byte size.
2849    self.assertEqual(proto2.ByteSize(), len(serialized))
2850    self.assertEqual(proto.ByteSize(), len(serialized))
2851
2852  def testMessageSetWireFormatUnknownExtension(self):
2853    # Create a message using the message set wire format with an unknown
2854    # message.
2855    raw = unittest_mset_pb2.RawMessageSet()
2856
2857    # Add an item.
2858    item = raw.item.add()
2859    item.type_id = 98418603
2860    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2861    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2862    message1.i = 12345
2863    item.message = message1.SerializeToString()
2864
2865    # Add a second, unknown extension.
2866    item = raw.item.add()
2867    item.type_id = 98418604
2868    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2869    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2870    message1.i = 12346
2871    item.message = message1.SerializeToString()
2872
2873    # Add another unknown extension.
2874    item = raw.item.add()
2875    item.type_id = 98418605
2876    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
2877    message1.str = 'foo'
2878    item.message = message1.SerializeToString()
2879
2880    serialized = raw.SerializeToString()
2881
2882    # Parse message using the message set wire format.
2883    proto = message_set_extensions_pb2.TestMessageSet()
2884    self.assertEqual(
2885        len(serialized),
2886        proto.MergeFromString(serialized))
2887
2888    # Check that the message parsed well.
2889    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2890    extension1 = extension_message1.message_set_extension
2891    self.assertEqual(12345, proto.Extensions[extension1].i)
2892
2893  def testUnknownFields(self):
2894    proto = unittest_pb2.TestAllTypes()
2895    test_util.SetAllFields(proto)
2896
2897    serialized = proto.SerializeToString()
2898
2899    # The empty message should be parsable with all of the fields
2900    # unknown.
2901    proto2 = unittest_pb2.TestEmptyMessage()
2902
2903    # Parsing this message should succeed.
2904    self.assertEqual(
2905        len(serialized),
2906        proto2.MergeFromString(serialized))
2907
2908    # Now test with a int64 field set.
2909    proto = unittest_pb2.TestAllTypes()
2910    proto.optional_int64 = 0x0fffffffffffffff
2911    serialized = proto.SerializeToString()
2912    # The empty message should be parsable with all of the fields
2913    # unknown.
2914    proto2 = unittest_pb2.TestEmptyMessage()
2915    # Parsing this message should succeed.
2916    self.assertEqual(
2917        len(serialized),
2918        proto2.MergeFromString(serialized))
2919
2920  def _CheckRaises(self, exc_class, callable_obj, exception):
2921    """This method checks if the exception type and message are as expected."""
2922    try:
2923      callable_obj()
2924    except exc_class as ex:
2925      # Check if the exception message is the right one.
2926      self.assertEqual(exception, str(ex))
2927      return
2928    else:
2929      raise self.failureException('%s not raised' % str(exc_class))
2930
2931  def testSerializeUninitialized(self):
2932    proto = unittest_pb2.TestRequired()
2933    self._CheckRaises(
2934        message.EncodeError,
2935        proto.SerializeToString,
2936        'Message protobuf_unittest.TestRequired is missing required fields: '
2937        'a,b,c')
2938    # Shouldn't raise exceptions.
2939    partial = proto.SerializePartialToString()
2940
2941    proto2 = unittest_pb2.TestRequired()
2942    self.assertFalse(proto2.HasField('a'))
2943    # proto2 ParseFromString does not check that required fields are set.
2944    proto2.ParseFromString(partial)
2945    self.assertFalse(proto2.HasField('a'))
2946
2947    proto.a = 1
2948    self._CheckRaises(
2949        message.EncodeError,
2950        proto.SerializeToString,
2951        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
2952    # Shouldn't raise exceptions.
2953    partial = proto.SerializePartialToString()
2954
2955    proto.b = 2
2956    self._CheckRaises(
2957        message.EncodeError,
2958        proto.SerializeToString,
2959        'Message protobuf_unittest.TestRequired is missing required fields: c')
2960    # Shouldn't raise exceptions.
2961    partial = proto.SerializePartialToString()
2962
2963    proto.c = 3
2964    serialized = proto.SerializeToString()
2965    # Shouldn't raise exceptions.
2966    partial = proto.SerializePartialToString()
2967
2968    proto2 = unittest_pb2.TestRequired()
2969    self.assertEqual(
2970        len(serialized),
2971        proto2.MergeFromString(serialized))
2972    self.assertEqual(1, proto2.a)
2973    self.assertEqual(2, proto2.b)
2974    self.assertEqual(3, proto2.c)
2975    self.assertEqual(
2976        len(partial),
2977        proto2.MergeFromString(partial))
2978    self.assertEqual(1, proto2.a)
2979    self.assertEqual(2, proto2.b)
2980    self.assertEqual(3, proto2.c)
2981
2982  def testSerializeUninitializedSubMessage(self):
2983    proto = unittest_pb2.TestRequiredForeign()
2984
2985    # Sub-message doesn't exist yet, so this succeeds.
2986    proto.SerializeToString()
2987
2988    proto.optional_message.a = 1
2989    self._CheckRaises(
2990        message.EncodeError,
2991        proto.SerializeToString,
2992        'Message protobuf_unittest.TestRequiredForeign '
2993        'is missing required fields: '
2994        'optional_message.b,optional_message.c')
2995
2996    proto.optional_message.b = 2
2997    proto.optional_message.c = 3
2998    proto.SerializeToString()
2999
3000    proto.repeated_message.add().a = 1
3001    proto.repeated_message.add().b = 2
3002    self._CheckRaises(
3003        message.EncodeError,
3004        proto.SerializeToString,
3005        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
3006        'repeated_message[0].b,repeated_message[0].c,'
3007        'repeated_message[1].a,repeated_message[1].c')
3008
3009    proto.repeated_message[0].b = 2
3010    proto.repeated_message[0].c = 3
3011    proto.repeated_message[1].a = 1
3012    proto.repeated_message[1].c = 3
3013    proto.SerializeToString()
3014
3015  def testSerializeAllPackedFields(self):
3016    first_proto = unittest_pb2.TestPackedTypes()
3017    second_proto = unittest_pb2.TestPackedTypes()
3018    test_util.SetAllPackedFields(first_proto)
3019    serialized = first_proto.SerializeToString()
3020    self.assertEqual(first_proto.ByteSize(), len(serialized))
3021    bytes_read = second_proto.MergeFromString(serialized)
3022    self.assertEqual(second_proto.ByteSize(), bytes_read)
3023    self.assertEqual(first_proto, second_proto)
3024
3025  def testSerializeAllPackedExtensions(self):
3026    first_proto = unittest_pb2.TestPackedExtensions()
3027    second_proto = unittest_pb2.TestPackedExtensions()
3028    test_util.SetAllPackedExtensions(first_proto)
3029    serialized = first_proto.SerializeToString()
3030    bytes_read = second_proto.MergeFromString(serialized)
3031    self.assertEqual(second_proto.ByteSize(), bytes_read)
3032    self.assertEqual(first_proto, second_proto)
3033
3034  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
3035    first_proto = unittest_pb2.TestPackedTypes()
3036    first_proto.packed_int32.extend([1, 2])
3037    first_proto.packed_double.append(3.0)
3038    serialized = first_proto.SerializeToString()
3039
3040    second_proto = unittest_pb2.TestPackedTypes()
3041    second_proto.packed_int32.append(3)
3042    second_proto.packed_double.extend([1.0, 2.0])
3043    second_proto.packed_sint32.append(4)
3044
3045    self.assertEqual(
3046        len(serialized),
3047        second_proto.MergeFromString(serialized))
3048    self.assertEqual([3, 1, 2], second_proto.packed_int32)
3049    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
3050    self.assertEqual([4], second_proto.packed_sint32)
3051
  def testPackedFieldsWireFormat(self):
    """Checks the exact byte layout produced for packed repeated fields.

    Each packed field is expected as one length-delimited record: its tag,
    a length prefix, then the concatenated element encodings. Records are
    expected in field-number order (90, 100, 101), even though packed_double
    was populated before packed_float.
    """
    proto = unittest_pb2.TestPackedTypes()
    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
    proto.packed_float.append(2.0)             # 4 bytes, will be before double
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = _MiniDecoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    # packed_int32 (field 90): 5-byte payload, since 150 needs two bytes.
    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(1+1+1+2, d.ReadInt32())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual(2, d.ReadInt32())
    self.assertEqual(150, d.ReadInt32())
    self.assertEqual(3, d.ReadInt32())
    # packed_float (field 100): 4-byte payload holding one float.
    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(4, d.ReadInt32())
    self.assertEqual(2.0, d.ReadFloat())
    # packed_double (field 101): 16-byte payload holding two doubles.
    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
    self.assertEqual(8+8, d.ReadInt32())
    self.assertEqual(1.0, d.ReadDouble())
    self.assertEqual(1000.0, d.ReadDouble())
    # Nothing else may follow the three records.
    self.assertTrue(d.EndOfStream())
3075
3076  def testParsePackedFromUnpacked(self):
3077    unpacked = unittest_pb2.TestUnpackedTypes()
3078    test_util.SetAllUnpackedFields(unpacked)
3079    packed = unittest_pb2.TestPackedTypes()
3080    serialized = unpacked.SerializeToString()
3081    self.assertEqual(
3082        len(serialized),
3083        packed.MergeFromString(serialized))
3084    expected = unittest_pb2.TestPackedTypes()
3085    test_util.SetAllPackedFields(expected)
3086    self.assertEqual(expected, packed)
3087
3088  def testParseUnpackedFromPacked(self):
3089    packed = unittest_pb2.TestPackedTypes()
3090    test_util.SetAllPackedFields(packed)
3091    unpacked = unittest_pb2.TestUnpackedTypes()
3092    serialized = packed.SerializeToString()
3093    self.assertEqual(
3094        len(serialized),
3095        unpacked.MergeFromString(serialized))
3096    expected = unittest_pb2.TestUnpackedTypes()
3097    test_util.SetAllUnpackedFields(expected)
3098    self.assertEqual(expected, unpacked)
3099
3100  def testFieldNumbers(self):
3101    proto = unittest_pb2.TestAllTypes()
3102    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
3103    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
3104    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
3105    self.assertEqual(
3106      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
3107    self.assertEqual(
3108      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
3109    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
3110    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
3111    self.assertEqual(
3112      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
3113    self.assertEqual(
3114      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
3115
3116  def testExtensionFieldNumbers(self):
3117    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
3118    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
3119    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
3120    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
3121    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
3122    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
3123    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
3124    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
3125    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
3126    self.assertEqual(
3127      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
3128    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
3129    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3130      21)
3131    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
3132    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
3133    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
3134    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
3135    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
3136    self.assertEqual(
3137      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
3138    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
3139    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3140      51)
3141
3142  def testFieldProperties(self):
3143    cls = unittest_pb2.TestAllTypes
3144    self.assertIs(cls.optional_int32.DESCRIPTOR,
3145                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
3146    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
3147                     cls.optional_int32.DESCRIPTOR.number)
3148    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
3149                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
3150    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
3151                     cls.optional_nested_message.DESCRIPTOR.number)
3152    self.assertIs(cls.repeated_int32.DESCRIPTOR,
3153                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
3154    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
3155                     cls.repeated_int32.DESCRIPTOR.number)
3156
3157  def testFieldDataDescriptor(self):
3158    msg = unittest_pb2.TestAllTypes()
3159    msg.optional_int32 = 42
3160    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
3161    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
3162    self.assertEqual(msg.optional_int32, 25)
3163    with self.assertRaises(AttributeError):
3164      del msg.optional_int32
3165    try:
3166      unittest_pb2.ForeignMessage.c.__get__(msg)
3167    except TypeError:
3168      pass  # The cpp implementation cannot mix fields from other messages.
3169      # This test exercises a specific check that avoids a crash.
3170    else:
3171      pass  # The python implementation allows fields from other messages.
3172      # This is useless, but works.
3173
3174  def testInitKwargs(self):
3175    proto = unittest_pb2.TestAllTypes(
3176        optional_int32=1,
3177        optional_string='foo',
3178        optional_bool=True,
3179        optional_bytes=b'bar',
3180        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
3181        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
3182        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
3183        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
3184        repeated_int32=[1, 2, 3])
3185    self.assertTrue(proto.IsInitialized())
3186    self.assertTrue(proto.HasField('optional_int32'))
3187    self.assertTrue(proto.HasField('optional_string'))
3188    self.assertTrue(proto.HasField('optional_bool'))
3189    self.assertTrue(proto.HasField('optional_bytes'))
3190    self.assertTrue(proto.HasField('optional_nested_message'))
3191    self.assertTrue(proto.HasField('optional_foreign_message'))
3192    self.assertTrue(proto.HasField('optional_nested_enum'))
3193    self.assertTrue(proto.HasField('optional_foreign_enum'))
3194    self.assertEqual(1, proto.optional_int32)
3195    self.assertEqual('foo', proto.optional_string)
3196    self.assertEqual(True, proto.optional_bool)
3197    self.assertEqual(b'bar', proto.optional_bytes)
3198    self.assertEqual(1, proto.optional_nested_message.bb)
3199    self.assertEqual(1, proto.optional_foreign_message.c)
3200    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
3201                     proto.optional_nested_enum)
3202    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
3203    self.assertEqual([1, 2, 3], proto.repeated_int32)
3204
3205  def testInitArgsUnknownFieldName(self):
3206    def InitializeEmptyMessageWithExtraKeywordArg():
3207      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
3208    self._CheckRaises(
3209        ValueError,
3210        InitializeEmptyMessageWithExtraKeywordArg,
3211        'Protocol message TestEmptyMessage has no "unknown" field.')
3212
3213  def testInitRequiredKwargs(self):
3214    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
3215    self.assertTrue(proto.IsInitialized())
3216    self.assertTrue(proto.HasField('a'))
3217    self.assertTrue(proto.HasField('b'))
3218    self.assertTrue(proto.HasField('c'))
3219    self.assertFalse(proto.HasField('dummy2'))
3220    self.assertEqual(1, proto.a)
3221    self.assertEqual(1, proto.b)
3222    self.assertEqual(1, proto.c)
3223
3224  def testInitRequiredForeignKwargs(self):
3225    proto = unittest_pb2.TestRequiredForeign(
3226        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
3227    self.assertTrue(proto.IsInitialized())
3228    self.assertTrue(proto.HasField('optional_message'))
3229    self.assertTrue(proto.optional_message.IsInitialized())
3230    self.assertTrue(proto.optional_message.HasField('a'))
3231    self.assertTrue(proto.optional_message.HasField('b'))
3232    self.assertTrue(proto.optional_message.HasField('c'))
3233    self.assertFalse(proto.optional_message.HasField('dummy2'))
3234    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
3235                     proto.optional_message)
3236    self.assertEqual(1, proto.optional_message.a)
3237    self.assertEqual(1, proto.optional_message.b)
3238    self.assertEqual(1, proto.optional_message.c)
3239
3240  def testInitRepeatedKwargs(self):
3241    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
3242    self.assertTrue(proto.IsInitialized())
3243    self.assertEqual(1, proto.repeated_int32[0])
3244    self.assertEqual(2, proto.repeated_int32[1])
3245    self.assertEqual(3, proto.repeated_int32[2])
3246
3247
@testing_refleaks.TestCase
class OptionsTest(unittest.TestCase):
  """Tests for message and field options surfaced through descriptors."""

  def testMessageOptions(self):
    """message_set_wire_format is set only on MessageSet-style messages."""
    cases = (
        (message_set_extensions_pb2.TestMessageSet, True),
        (unittest_pb2.TestAllTypes, False),
    )
    for message_class, expected in cases:
      proto = message_class()
      self.assertEqual(expected,
                       proto.DESCRIPTOR.GetOptions().message_set_wire_format)

  def testPackedOptions(self):
    """is_packed is reported only for packed repeated fields."""
    plain = unittest_pb2.TestAllTypes()
    plain.optional_int32 = 1
    plain.optional_double = 3.0
    plain_fields = [fd for fd, _ in plain.ListFields()]
    for fd in plain_fields:
      self.assertEqual(False, fd.is_packed)

    packed = unittest_pb2.TestPackedTypes()
    packed.packed_int32.append(1)
    packed.packed_double.append(3.0)
    packed_fields = [fd for fd, _ in packed.ListFields()]
    for fd in packed_fields:
      self.assertEqual(True, fd.is_packed)
      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, fd.label)
3273
3274
@testing_refleaks.TestCase
class ClassAPITest(unittest.TestCase):
  """Tests for message classes built from dynamically created descriptors."""

  @unittest.skipIf(
      api_implementation.Type() != 'python',
      'C++ implementation requires a call to MakeDescriptor()')
  @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
  def testMakeClassWithNestedDescriptor(self):
    """GetMessageClass accepts hand-built Descriptors with nested types."""
    leaf_desc = descriptor.Descriptor(
        'leaf', 'package.parent.child.leaf', '',
        containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    child_desc = descriptor.Descriptor(
        'child', 'package.parent.child', '',
        containing_type=None, fields=[],
        nested_types=[leaf_desc], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    sibling_desc = descriptor.Descriptor(
        'sibling', 'package.parent.sibling',
        '', containing_type=None, fields=[],
        nested_types=[], enum_types=[],
        extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    parent_desc = descriptor.Descriptor(
        'parent', 'package.parent', '',
        containing_type=None, fields=[],
        nested_types=[child_desc, sibling_desc],
        enum_types=[], extensions=[],
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key)
    message_factory.GetMessageClass(parent_desc)

  def _GetSerializedFileDescriptor(self, name):
    """Get a serialized representation of a test FileDescriptorProto.

    The file declares one message, `name`, with a repeated uint32 field
    "flat" and a field "bar" of nested type Bar. Bar has a field "baz" of
    nested type Baz, which declares the enum "deep_enum" and a uint32
    field "deep".

    Args:
      name: All calls to this must use a unique message name, to avoid
          collisions in the cpp descriptor pool.
    Returns:
      The serialized (bytes) form of the test FileDescriptorProto.
    """
    file_descriptor_str = (
        'message_type {'
        '  name: "' + name + '"'
        '  field {'
        '    name: "flat"'
        '    number: 1'
        '    label: LABEL_REPEATED'
        '    type: TYPE_UINT32'
        '  }'
        '  field {'
        '    name: "bar"'
        '    number: 2'
        '    label: LABEL_OPTIONAL'
        '    type: TYPE_MESSAGE'
        '    type_name: "Bar"'
        '  }'
        '  nested_type {'
        '    name: "Bar"'
        '    field {'
        '      name: "baz"'
        '      number: 3'
        '      label: LABEL_OPTIONAL'
        '      type: TYPE_MESSAGE'
        '      type_name: "Baz"'
        '    }'
        '    nested_type {'
        '      name: "Baz"'
        '      enum_type {'
        '        name: "deep_enum"'
        '        value {'
        '          name: "VALUE_A"'
        '          number: 0'
        '        }'
        '      }'
        '      field {'
        '        name: "deep"'
        '        number: 4'
        '        label: LABEL_OPTIONAL'
        '        type: TYPE_UINT32'
        '      }'
        '    }'
        '  }'
        '}')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    text_format.Merge(file_descriptor_str, file_descriptor)
    return file_descriptor.SerializeToString()

  def _MakeTestDescriptor(self, name):
    """Builds the Descriptor of the test message called `name`.

    Parses the serialized test FileDescriptorProto back into a proto and
    passes its single message type to descriptor.MakeDescriptor.

    Args:
      name: Unique message name; see _GetSerializedFileDescriptor.
    Returns:
      The descriptor returned by descriptor.MakeDescriptor.
    """
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor(name))
    return descriptor.MakeDescriptor(file_descriptor.message_type[0])

  # Declaring the class directly with the GeneratedProtocolMessageType
  # metaclass historically failed under the C++ implementation (back when the
  # class was created through six.with_metaclass()), so the test is skipped
  # there instead of silently passing.
  # TODO: Check whether non-python implementations now support this.
  # This test can only run once; the second time, it raises errors about
  # conflicting message descriptors.
  @unittest.skipIf(
      api_implementation.Type() != 'python',
      'Explicit class declaration with GeneratedProtocolMessageType is only '
      'tested with the python implementation')
  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClassWithExplicitClassDeclaration(self):
    """Test that an explicitly declared class can parse a flat message."""
    msg_descriptor = self._MakeTestDescriptor('A')

    class MessageClass(
        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
      DESCRIPTOR = msg_descriptor
    msg = MessageClass()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingFlatClass(self):
    """Test that the generated class can parse a flat message."""
    msg_class = message_factory.GetMessageClass(self._MakeTestDescriptor('B'))
    msg = msg_class()
    msg_str = (
        'flat: 0 '
        'flat: 1 '
        'flat: 2 ')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.flat, [0, 1, 2])

  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
  def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    msg_class = message_factory.GetMessageClass(self._MakeTestDescriptor('C'))
    msg = msg_class()
    msg_str = (
        'bar {'
        '  baz {'
        '    deep: 4'
        '  }'
        '}')
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
3430
if __name__ == '__main__':
  # Run every TestCase in this module when executed directly as a script.
  unittest.main()
3433