• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#! /usr/bin/env python
2# -*- coding: utf-8 -*-
3#
4# Protocol Buffers - Google's data interchange format
5# Copyright 2008 Google Inc.  All rights reserved.
6# https://developers.google.com/protocol-buffers/
7#
8# Redistribution and use in source and binary forms, with or without
9# modification, are permitted provided that the following conditions are
10# met:
11#
12#     * Redistributions of source code must retain the above copyright
13# notice, this list of conditions and the following disclaimer.
14#     * Redistributions in binary form must reproduce the above
15# copyright notice, this list of conditions and the following disclaimer
16# in the documentation and/or other materials provided with the
17# distribution.
18#     * Neither the name of Google Inc. nor the names of its
19# contributors may be used to endorse or promote products derived from
20# this software without specific prior written permission.
21#
22# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34"""Test for preservation of unknown fields in the pure Python implementation."""
35
36__author__ = 'bohdank@google.com (Bohdan Koval)'
37
38try:
39  import unittest2 as unittest  #PY26
40except ImportError:
41  import unittest
42from google.protobuf import map_unittest_pb2
43from google.protobuf import unittest_mset_pb2
44from google.protobuf import unittest_pb2
45from google.protobuf import unittest_proto3_arena_pb2
46from google.protobuf.internal import api_implementation
47from google.protobuf.internal import encoder
48from google.protobuf.internal import message_set_extensions_pb2
49from google.protobuf.internal import missing_enum_values_pb2
50from google.protobuf.internal import test_util
51from google.protobuf.internal import testing_refleaks
52from google.protobuf.internal import type_checkers
53from google.protobuf.internal import wire_format
54from google.protobuf import descriptor
55
56
@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Round-trips fully populated TestAllTypes data through TestEmptyMessage.

  TestEmptyMessage declares none of TestAllTypes' fields, so after parsing
  every field is held as an unknown field. These tests check that such data
  survives serialization, size computation, equality and discarding.
  """

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    # Every field in all_fields_data is unknown to TestEmptyMessage.
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    # As in testSerialize, avoid assertEqual so that a failure does not dump
    # raw binary data to stdout.
    self.assertTrue(self.all_fields_data == message.SerializeToString())

  def testByteSize(self):
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_fields = proto.UnknownFields()
    self.assertEqual(len(unknown_fields), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_fields[0].field_number, item.type_id)
    self.assertEqual(unknown_fields[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_fields[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    # Dropping one field changes the unknown field set, so equality must fail.
    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # Test map fields: unknown fields nested in map values are discarded,
    # while known map entries (including non-message values) are kept.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    # The string map entry is a known field and must survive the discard.
    self.assertEqual('test', msg.map_string_string['1'])
165
@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests the UnknownFields() accessor and the objects it returns."""

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++ implementation
  # because it accesses protected members (message._unknown_fields and
  # TestAllTypes._decoders_by_tag).
  # The test is added for historical reasons. It is not necessary as the
  # serialized string is checked.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    if api_implementation.Type() == 'cpp':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_fields, expected_value):
    """Checks every unknown field whose number matches TestAllTypes.`name`.

    Args:
      name: name of a TestAllTypes field.
      unknown_fields: the unknown field set returned by UnknownFields().
      expected_value: for a repeated field, the collection of expected
        values; for a group, a (field_number, wire_type, data) tuple that
        describes the group's first sub-field; otherwise the expected value.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    count = 0
    for unknown_field in unknown_fields:
      if unknown_field.field_number != field_descriptor.number:
        continue
      count += 1
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # A group's data is itself an indexable set of unknown fields; check
        # its first entry.
        sub_field = unknown_field.data[0]
        self.assertEqual(expected_value[0], sub_field.field_number)
        self.assertEqual(expected_value[1], sub_field.wire_type)
        self.assertEqual(expected_value[2], sub_field.data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIn(type(unknown_field.data), (str, bytes))
      if is_repeated:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    # Without this, a field missing from the unknown set would pass silently.
    if is_repeated:
      self.assertEqual(len(expected_value), count)
    else:
      self.assertEqual(1, count)

  def testCheckUnknownFieldValue(self):
    unknown_fields = self.empty_message.UnknownFields()
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_fields,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_fields,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_fields,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_fields,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_fields,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_fields,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_fields,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    self.assertEqual(97, len(unknown_fields))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_fields = destination.UnknownFields()
    self.assertEqual(0, len(unknown_fields))
    destination.ParseFromString(message.SerializeToString())
    # ParseFromString clears the message thus unknown fields is invalid.
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    unknown_fields = destination.UnknownFields()
    self.assertEqual(2, len(unknown_fields))
    destination.MergeFrom(source)
    self.assertEqual(4, len(unknown_fields))
    # Check that the fields were correctly merged, even when stored in the
    # unknown fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_fields = self.empty_message.UnknownFields()
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = destination.UnknownFields()[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      len(sub_unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    with self.assertRaises(ValueError) as context:
      # pylint: disable=pointless-statement
      sub_unknown_fields[0]
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(unknown_fields[0].data, 456)
    nested_message.ClearField('payload')
    self.assertEqual(unknown_fields[0].data, 456)
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(0, len(unknown_fields))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = destination.UnknownFields()[0]
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      unknown_field.data    # pylint: disable=pointless-statement
    self.assertIn('The parent message might be cleared.',
                  str(context.exception))

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(message.UnknownFields()), 97)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)
358
359  def testUnknownExtensions(self):
360    message = unittest_pb2.TestEmptyMessageWithExtensions()
361    message.ParseFromString(self.all_fields_data)
362    self.assertEqual(len(message.UnknownFields()), 97)
363    self.assertEqual(message.SerializeToString(), self.all_fields_data)
364
365
@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Tests enum values that the parsing message type does not recognize.

  TestEnumValues data is parsed as TestMissingEnumValues. The tests check
  that the unrecognized values end up in the unknown fields and round-trip.
  """

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  # CheckUnknownField() inspects missing_message through the public
  # UnknownFields() API only, so it works with both the Python and C++
  # implementations. It asserts that every value of field `name` was kept as
  # an unknown field with the expected data, and that the number of matching
  # unknown fields is exactly len(expected_value) for repeated fields, or 1
  # otherwise.
  def CheckUnknownField(self, name, expected_value):
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_fields = self.missing_message.UnknownFields()
    count = 0
    for field in unknown_fields:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # ClearField does not change the serialized output: the unrecognized
    # value is held in the unknown fields, not in the field itself.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_fields = self.missing_message.UnknownFields()
    # One optional value plus two repeated and two packed values.
    self.assertEqual(len(unknown_fields), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)
451
452
if __name__ == '__main__':
  # Entry point when this file is run directly: run the module's test cases.
  unittest.main()
455