• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
"""Tests for preservation of, and access to, unknown fields in protobuf messages.

The tests run against whichever protobuf backend is active; checks that rely on
pure-Python internals skip themselves under the C++ implementation.
"""

__author__ = 'bohdank@google.com (Bohdan Koval)'
37
38try:
39  import unittest2 as unittest  #PY26
40except ImportError:
41  import unittest
42import sys
43try:
44  import tracemalloc
45except ImportError:
46  # Requires python 3.4+
47  pass
48from google.protobuf import map_unittest_pb2
49from google.protobuf import unittest_mset_pb2
50from google.protobuf import unittest_pb2
51from google.protobuf import unittest_proto3_arena_pb2
52from google.protobuf.internal import api_implementation
53from google.protobuf.internal import encoder
54from google.protobuf.internal import message_set_extensions_pb2
55from google.protobuf.internal import missing_enum_values_pb2
56from google.protobuf.internal import test_util
57from google.protobuf.internal import testing_refleaks
58from google.protobuf.internal import type_checkers
59from google.protobuf.internal import wire_format
60from google.protobuf import descriptor
61
62
@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Round-trip, size, equality and discard behavior of unknown fields."""

  def setUp(self):
    # Serialize a fully populated TestAllTypes, then parse the bytes into
    # TestEmptyMessage so that every populated field becomes unknown.
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    populated = unittest_pb2.TestAllTypes()
    self.all_fields = populated
    test_util.SetAllFields(populated)
    self.all_fields_data = populated.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    reserialized = self.empty_message.SerializeToString()

    # assertTrue rather than assertEqual keeps raw binary data out of the
    # failure output.
    self.assertTrue(reserialized == self.all_fields_data)

  def testSerializeProto3(self):
    # A proto3 message must also carry unknown fields through a round trip.
    proto3_message = unittest_proto3_arena_pb2.TestEmptyMessage()
    proto3_message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, proto3_message.SerializeToString())

  def testByteSize(self):
    expected_size = self.all_fields.ByteSize()
    self.assertEqual(expected_size, self.empty_message.ByteSize())

  def testListFields(self):
    # Unknown fields must never be reported by ListFields().
    listed = self.empty_message.ListFields()
    self.assertEqual(0, len(listed))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Build a raw MessageSet holding one item whose type id is an extension
    # that TestMessageSet does not know about.
    raw = unittest_mset_pb2.RawMessageSet()
    item = raw.item.add()
    item.type_id = 98218603
    payload_message = message_set_extensions_pb2.TestMessageSetExtension1()
    payload_message.i = 12345
    item.message = payload_message.SerializeToString()

    wire_bytes = raw.SerializeToString()

    # Parse those bytes through the MessageSet wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(wire_bytes)

    unknown_fields = proto.UnknownFields()
    self.assertEqual(len(unknown_fields), 1)
    # The single unknown field is keyed by the type id and holds the
    # length-delimited payload, which must parse back to the original message.
    only_field = unknown_fields[0]
    self.assertEqual(only_field.field_number, item.type_id)
    self.assertEqual(only_field.wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    decoded = message_set_extensions_pb2.TestMessageSetExtension1()
    decoded.ParseFromString(only_field.data)
    self.assertEqual(payload_message, decoded)

    # Re-serializing must emit the unknown extension unchanged.
    round_tripped = unittest_mset_pb2.RawMessageSet()
    round_tripped.MergeFromString(proto.SerializeToString())
    self.assertEqual(raw, round_tripped)

  def testEquals(self):
    reparsed = unittest_pb2.TestEmptyMessage()
    reparsed.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, reparsed)

    # Dropping one field from the source changes the unknown set, so the two
    # messages must stop comparing equal.
    self.all_fields.ClearField('optional_string')
    reparsed.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, reparsed)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())

    # Singular and repeated sub-messages must be cleaned recursively.
    message = unittest_pb2.TestAllTypes()
    source = unittest_pb2.TestAllTypes()
    source.optional_string = 'discard'
    source_bytes = source.SerializeToString()
    message.optional_nested_message.ParseFromString(source_bytes)
    message.repeated_nested_message.add().ParseFromString(source_bytes)
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # Message-valued map entries must be cleaned as well.
    map_message = map_unittest_pb2.TestMap()
    map_message.map_int32_all_types[1].optional_nested_message.ParseFromString(
        source_bytes)
    map_message.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        map_message.map_int32_all_types[1]
        .optional_nested_message.SerializeToString())
    map_message.DiscardUnknownFields()
    self.assertEqual(
        b'',
        map_message.map_int32_all_types[1]
        .optional_nested_message.SerializeToString())
170
171
@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests for the UnknownFields() accessor and the objects it returns."""

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional pure-Python check which
  # inspects an implementation detail of unknown fields. It cannot be used by
  # the C++ implementation because it accesses protected members.
  # The test is kept for historical reasons; it is not strictly necessary
  # because the serialized string is checked as well.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    """Decodes the raw unknown bytes stored for field `name` and compares them.

    Returns immediately under the C++ implementation.

    Args:
      name: Name of a TestAllTypes field.
      expected_value: Value the decoded field is expected to hold.
    """
    if api_implementation.Type() == 'cpp':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_fields, expected_value):
    """Checks every entry of `unknown_fields` belonging to TestAllTypes.`name`.

    Also asserts how many entries matched. A singular field must appear
    exactly once; a repeated field must appear once per expected element.
    Without this count, a field missing from the unknown set would make the
    check pass vacuously.

    Args:
      name: Name of a TestAllTypes field whose data ended up as unknown.
      unknown_fields: The result of UnknownFields() on the parsed message.
      expected_value: For a group, a (field_number, wire_type, data) tuple
        describing the group's first sub-field; for a repeated field, the
        sequence of expected element values; otherwise the expected value.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    count = 0
    for unknown_field in unknown_fields:
      if unknown_field.field_number != field_descriptor.number:
        continue
      count += 1
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # A group's data is itself a set of unknown fields; check its first
        # entry against the (field_number, wire_type, data) tuple.
        self.assertEqual(expected_value[0],
                         unknown_field.data[0].field_number)
        self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
        self.assertEqual(expected_value[2], unknown_field.data[0].data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIn(type(unknown_field.data), (str, bytes))
      if is_repeated:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    if is_repeated:
      self.assertEqual(len(expected_value), count)
    else:
      self.assertEqual(1, count)

  def testCheckUnknownFieldValue(self):
    unknown_fields = self.empty_message.UnknownFields()
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_fields,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_fields,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_fields,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_fields,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_fields,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_fields,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_fields,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    # Total number of unknown field entries after parsing all_fields_data.
    self.assertEqual(97, len(unknown_fields))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_fields = destination.UnknownFields()
    self.assertEqual(0, len(unknown_fields))
    destination.ParseFromString(message.SerializeToString())
    # ParseFromString clears the message, so the old UnknownFields view is
    # invalidated.
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    unknown_fields = destination.UnknownFields()
    self.assertEqual(2, len(unknown_fields))
    destination.MergeFrom(source)
    self.assertEqual(4, len(unknown_fields))
    # Check that the fields were correctly merged, even though they are stored
    # in the unknown field set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_fields = self.empty_message.UnknownFields()
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))

  @unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4),
                   'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    # Call to UnknownFields must not leak memory
    nb_leaks = 1234
    def leaking_function():
      for _ in range(nb_leaks):
        self.empty_message.UnknownFields()
    tracemalloc.start()
    # Stop tracing even if the body raises, so later tests do not run with
    # allocation tracing still enabled.
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = destination.UnknownFields()[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      len(sub_unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    with self.assertRaises(ValueError) as context:
      # pylint: disable=pointless-statement
      sub_unknown_fields[0]
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(unknown_fields[0].data, 456)
    # An UnknownFields view obtained before ClearField('payload') stays
    # readable afterwards.
    nested_message.ClearField('payload')
    self.assertEqual(unknown_fields[0].data, 456)
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(0, len(unknown_fields))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = destination.UnknownFields()[0]
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      unknown_field.data    # pylint: disable=pointless-statement
    self.assertIn('The parent message might be cleared.',
                  str(context.exception))

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(message.UnknownFields()), 97)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)
388
389
@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Enum values unknown to the parsing message end up in unknown fields."""

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message_data = self.message.SerializeToString()
    # Re-parse with TestMissingEnumValues. The tests below expect all of the
    # values set above to be unrecognized there and kept as unknown fields.
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  def CheckUnknownField(self, name, expected_value):
    """Checks the unknown fields recorded for TestEnumValues field `name`.

    Uses only the public UnknownFields() API and has no implementation
    guard, so it runs under every protobuf backend. Also asserts how many
    entries matched: exactly one for a singular field, one per element for a
    repeated field.

    Args:
      name: Name of a TestEnumValues field.
      expected_value: The expected value for a singular field, or the
        sequence of expected values for a repeated field.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_fields = self.missing_message.UnknownFields()
    count = 0
    for field in unknown_fields:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # ClearField() does not drop the value held in the unknown fields, so the
    # serialized bytes must be unchanged.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_fields = self.missing_message.UnknownFields()
    self.assertEqual(len(unknown_fields), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)
475
476
# Entry point: run this module's test cases when executed as a script.
if '__main__' == __name__:
  unittest.main()
479