• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4#
5# Use of this source code is governed by a BSD-style
6# license that can be found in the LICENSE file or at
7# https://developers.google.com/open-source/licenses/bsd
8
9"""Test for preservation of unknown fields in the pure Python implementation."""
10
11__author__ = 'bohdank@google.com (Bohdan Koval)'
12
13import sys
14import unittest
15
16from google.protobuf.internal import api_implementation
17from google.protobuf.internal import encoder
18from google.protobuf.internal import message_set_extensions_pb2
19from google.protobuf.internal import missing_enum_values_pb2
20from google.protobuf.internal import test_util
21from google.protobuf.internal import testing_refleaks
22from google.protobuf.internal import type_checkers
23from google.protobuf.internal import wire_format
24from google.protobuf import descriptor
25from google.protobuf import unknown_fields
26from google.protobuf import map_unittest_pb2
27from google.protobuf import unittest_mset_pb2
28from google.protobuf import unittest_pb2
29from google.protobuf import unittest_proto3_arena_pb2
30try:
31  import tracemalloc  # pylint: disable=g-import-not-at-top
32except ImportError:
33  # Requires python 3.4+
34  pass
35
36
@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Checks that unknown fields are kept, compared and discarded correctly."""

  def setUp(self):
    """Serializes a fully populated message and parses it as an empty one."""
    populated = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(populated)

    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = populated
    self.all_fields_data = populated.SerializeToString()

    # The same payload parsed through TestEmptyMessage is the fixture every
    # test below inspects.
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    """Re-serializing the empty message reproduces the original bytes."""
    reserialized = self.empty_message.SerializeToString()

    # A plain boolean assertion is used on purpose: assertEqual would dump
    # the raw binary payload to stdout on failure.
    self.assertTrue(reserialized == self.all_fields_data)

  def testSerializeProto3(self):
    """The proto3 empty message round-trips the payload as well."""
    proto3_message = unittest_proto3_arena_pb2.TestEmptyMessage()
    proto3_message.ParseFromString(self.all_fields_data)
    round_tripped = proto3_message.SerializeToString()
    self.assertEqual(self.all_fields_data, round_tripped)

  def testByteSize(self):
    """ByteSize() of the empty message matches the populated message."""
    expected_size = self.all_fields.ByteSize()
    self.assertEqual(expected_size, self.empty_message.ByteSize())

  def testListFields(self):
    """ListFields() must not report unknown fields."""
    listed = self.empty_message.ListFields()
    self.assertEqual(0, len(listed))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    """An unknown message-set extension is exposed and re-emitted unchanged."""
    # Payload of the extension that the parsing side will not recognise.
    extension_payload = message_set_extensions_pb2.TestMessageSetExtension1()
    extension_payload.i = 12345

    # Build the message-set wire format by hand through RawMessageSet.
    raw_set = unittest_mset_pb2.RawMessageSet()
    raw_item = raw_set.item.add()
    raw_item.type_id = 98218603
    raw_item.message = extension_payload.SerializeToString()
    wire_bytes = raw_set.SerializeToString()

    # Parse the bytes with the message-set aware message type.
    parsed_set = message_set_extensions_pb2.TestMessageSet()
    parsed_set.MergeFromString(wire_bytes)

    unknowns = unknown_fields.UnknownFieldSet(parsed_set)
    self.assertEqual(len(unknowns), 1)

    # The single unknown entry carries wire data that parses back into the
    # original extension message.
    only_unknown = unknowns[0]
    self.assertEqual(only_unknown.field_number, raw_item.type_id)
    self.assertEqual(only_unknown.wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    decoded = message_set_extensions_pb2.TestMessageSetExtension1()
    decoded.ParseFromString(only_unknown.data)
    self.assertEqual(extension_payload, decoded)

    # Serializing again must leave the unknown extension untouched.
    reparsed_raw = unittest_mset_pb2.RawMessageSet()
    reparsed_raw.MergeFromString(parsed_set.SerializeToString())
    self.assertEqual(raw_set, reparsed_raw)

  def testEquals(self):
    """Message equality takes the unknown fields into account."""
    reparsed = unittest_pb2.TestEmptyMessage()
    reparsed.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, reparsed)

    # Dropping one field from the payload must make the messages differ.
    self.all_fields.ClearField('optional_string')
    trimmed_data = self.all_fields.SerializeToString()
    reparsed.ParseFromString(trimmed_data)
    self.assertNotEqual(self.empty_message, reparsed)

  def testDiscardUnknownFields(self):
    """DiscardUnknownFields() also clears sub-messages and map values."""
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())

    # A TestAllTypes payload is parsed into nested-message slots below and
    # must serialize to non-empty bytes before the discard.
    donor = unittest_pb2.TestAllTypes()
    donor.optional_string = 'discard'
    donor_data = donor.SerializeToString()

    # Singular and repeated message fields.
    holder = unittest_pb2.TestAllTypes()
    holder.optional_nested_message.ParseFromString(donor_data)
    holder.repeated_nested_message.add().ParseFromString(donor_data)
    self.assertNotEqual(
        b'', holder.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', holder.repeated_nested_message[0].SerializeToString())

    holder.DiscardUnknownFields()

    self.assertEqual(b'', holder.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', holder.repeated_nested_message[0].SerializeToString())

    # Message-valued map entries.
    map_holder = map_unittest_pb2.TestMap()
    map_holder.map_int32_all_types[1].optional_nested_message.ParseFromString(
        donor_data)
    map_holder.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        map_holder.map_int32_all_types[1].optional_nested_message
        .SerializeToString())

    map_holder.DiscardUnknownFields()

    self.assertEqual(
        b'',
        map_holder.map_int32_all_types[1].optional_nested_message
        .SerializeToString())
144
145
@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests reading unknown fields through unknown_fields.UnknownFieldSet."""

  def setUp(self):
    """Parses a fully populated TestAllTypes payload into TestEmptyMessage."""
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def CheckUnknownField(self, name, unknown_field_set, expected_value):
    """Asserts that the unknown entries for field `name` hold `expected_value`.

    Fails when no entry with the field's number exists, so a missing field
    cannot make the check pass vacuously.

    Args:
      name: Name of a TestAllTypes field. Its descriptor supplies the field
        number and, through FIELD_TYPE_TO_WIRE_TYPE, the expected wire type.
      unknown_field_set: Iterable of unknown fields to search.
      expected_value: For a group field, a (field_number, wire_type, data)
        tuple describing the first entry nested in the group. For a repeated
        field, a container that must contain each matching entry's data.
        Otherwise the exact data of the matching entry.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    matches = 0
    for unknown_field in unknown_field_set:
      if unknown_field.field_number != field_descriptor.number:
        continue
      matches += 1
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # A group's data is itself a collection of unknown fields; compare
        # its first entry against the expected triple.
        group_entry = unknown_field.data[0]
        self.assertEqual(expected_value[0], group_entry.field_number)
        self.assertEqual(expected_value[1], group_entry.wire_type)
        self.assertEqual(expected_value[2], group_entry.data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIsInstance(unknown_field.data, (str, bytes))
      if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    self.assertGreater(
        matches, 0, 'no unknown field found for %r' % name)

  def testCheckUnknownFieldValue(self):
    """Checks one unknown field per wire type plus the total entry count."""
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_field_set,
                           self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_field_set,
                           self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_field_set,
                           self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_field_set,
                           self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_field_set,
                           self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_field_set,
                           self.all_fields.optional_string.encode('utf-8'))

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_field_set,
                           (17, 0, 117))

    self.assertEqual(99, len(unknown_field_set))

  def testCopyFrom(self):
    """CopyFrom() carries the unknown fields over to the target message."""
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    """MergeFrom() merges unknown fields; existing sets are not updated."""
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(0, len(unknown_field_set))
    destination.ParseFromString(message.SerializeToString())
    # The set built before parsing still reports zero entries.
    self.assertEqual(0, len(unknown_field_set))
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(2, len(unknown_field_set))
    destination.MergeFrom(source)
    # Likewise, the set built before the merge keeps its two entries.
    self.assertEqual(2, len(unknown_field_set))
    # Check that the fields were correctly merged, even though they are
    # stored in the unknown fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    """Clear() empties the message; an existing set keeps its entries."""
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    self.assertEqual(len(unknown_field_set), 99)

  @unittest.skipIf(sys.version_info < (3, 4),
                   'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    """Building UnknownFieldSet repeatedly must not leak memory."""
    # The iteration count doubles as the "marker" searched for in the
    # allocation statistics below.
    nb_leaks = 1234

    def leaking_function():
      for _ in range(nb_leaks):
        unknown_fields.UnknownFieldSet(self.empty_message)

    tracemalloc.start()
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      # Always switch tracing off again, even if the code above raises, so
      # a failure here does not leave tracing enabled for later tests.
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    """Nested unknown-field data stays readable after the parent is cleared."""
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = unknown_fields.UnknownFieldSet(destination)[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    # Clearing the message must not affect the data obtained beforehand.
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(unknown_field_set[0].data, 456)
    nested_message.ClearField('payload')
    # The existing set is still usable after its owning field was cleared.
    self.assertEqual(unknown_field_set[0].data, 456)
    # A set built after the clear reports no entries.
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(0, len(unknown_field_set))

  def testUnknownField(self):
    """A single unknown field stays readable after its message is cleared."""
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = unknown_fields.UnknownFieldSet(destination)[0]
    destination.Clear()
    self.assertEqual(unknown_field.data, 123)

  def testUnknownExtensions(self):
    """A message type with extensions also keeps the payload as unknown."""
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(unknown_fields.UnknownFieldSet(message)), 99)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)
314
315
@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Checks enum values that the parsing message type does not declare."""

  def setUp(self):
    """Serializes TestEnumValues and parses it as TestMissingEnumValues."""
    enum_values_cls = missing_enum_values_pb2.TestEnumValues
    self.descriptor = enum_values_cls.DESCRIPTOR

    self.message = enum_values_cls()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = enum_values_cls.ZERO
    zero_and_one = [enum_values_cls.ZERO, enum_values_cls.ONE]
    self.message.repeated_nested_enum.extend(zero_and_one)
    self.message.packed_nested_enum.extend(zero_and_one)

    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  def CheckUnknownField(self, name, expected_value):
    """Asserts how field `name` appears among missing_message's unknown fields.

    This detail check is kept for historical reasons; it is not strictly
    necessary because the serialized string is verified as well.

    Args:
      name: Field name looked up in the TestEnumValues descriptor.
      expected_value: Container of accepted values for a repeated field (one
        unknown entry is expected per element); otherwise the exact value of
        the single expected entry.
    """
    target = self.descriptor.fields_by_name[name]
    unknowns = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertIsInstance(unknowns, unknown_fields.UnknownFieldSet)

    is_repeated = (
        target.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    matching = [
        entry for entry in unknowns if entry.field_number == target.number
    ]
    for entry in matching:
      if is_repeated:
        self.assertIn(entry.data, expected_value)
      else:
        self.assertEqual(expected_value, entry.data)

    expected_count = len(expected_value) if is_repeated else 1
    self.assertEqual(len(matching), expected_count)

  def testUnknownParseMismatchEnumValue(self):
    """Reading an enum field after a mismatched parse yields the default."""
    string_only = missing_enum_values_pb2.JustString()
    string_only.dummy = 'blah'
    string_bytes = string_only.SerializeToString()

    # The parse is invalid for the enum field, so the string payload is
    # stored in the set of unknown fields.
    target = missing_enum_values_pb2.TestEnumValues()
    target.ParseFromString(string_bytes)

    # Fetching the enum field must not crash; it returns the default value.
    self.assertEqual(target.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    """An unknown optional enum value is not set, and ClearField is a no-op."""
    parsed = self.missing_message
    self.assertFalse(parsed.HasField('optional_nested_enum'))
    self.assertEqual(parsed.optional_nested_enum, 2)

    # Clearing the field must leave the serialized form unchanged.
    before_clear = parsed.SerializeToString()
    parsed.ClearField('optional_nested_enum')
    self.assertEqual(parsed.SerializeToString(), before_clear)

  def testUnknownRepeatedEnumValue(self):
    """The repeated enum field is empty after parsing."""
    repeated_values = self.missing_message.repeated_nested_enum
    self.assertEqual([], repeated_values)

  def testUnknownPackedEnumValue(self):
    """The packed enum field is empty after parsing."""
    packed_values = self.missing_message.packed_nested_enum
    self.assertEqual([], packed_values)

  def testCheckUnknownFieldValueForEnum(self):
    """All five enum values are present in the unknown field set."""
    unknowns = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertEqual(len(unknowns), 5)

    enum_field_names = (
        'optional_nested_enum',
        'repeated_nested_enum',
        'packed_nested_enum',
    )
    for field_name in enum_field_names:
      self.CheckUnknownField(field_name, getattr(self.message, field_name))

  def testRoundTrip(self):
    """Serializing missing_message restores the original TestEnumValues."""
    round_trip_bytes = self.missing_message.SerializeToString()
    restored = missing_enum_values_pb2.TestEnumValues()
    restored.ParseFromString(round_trip_bytes)
    self.assertEqual(self.message, restored)
402
403
# Run every test case defined in this module when executed as a script.
if __name__ == '__main__':
  unittest.main()
406