#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for preservation of unknown fields in the pure Python implementation."""

__author__ = 'bohdank@google.com (Bohdan Koval)'

try:
  import unittest2 as unittest  #PY26
except ImportError:
  import unittest
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor


@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Round-trip, equality and discard behavior of unknown fields."""

  def setUp(self):
    # Parse a fully-populated TestAllTypes into TestEmptyMessage so that every
    # field ends up stored as an unknown field.
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, message.SerializeToString())

  def testByteSize(self):
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_fields = proto.UnknownFields()
    self.assertEqual(len(unknown_fields), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_fields[0].field_number, item.type_id)
    self.assertEqual(unknown_fields[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_fields[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field. A serialized TestAllTypes
    # is parsed into nested messages; the assertions below rely on all of that
    # data being held only as unknown fields of the nested messages.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # DiscardUnknownFields must also recurse into map values.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())


@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests for the UnknownFields() accessor API and its lifetime rules."""

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++
  # implementation because some protected members (_unknown_fields,
  # _decoders_by_tag) are accessed.
  # The test is added for historical reasons. It is not necessary as
  # serialized string is checked.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    """Decodes the raw unknown bytes for field `name` and compares the value.

    No-op under the C++ implementation.

    Args:
      name: name of a field of TestAllTypes.
      expected_value: the value the decoded unknown bytes must equal.
    """
    if api_implementation.Type() == 'cpp':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_fields, expected_value):
    """Asserts the unknown-field entries for field `name` match expected_value.

    Also asserts that the field is actually present, so a missing field can
    no longer pass vacuously.

    Args:
      name: name of a field of TestAllTypes.
      unknown_fields: the result of UnknownFields() on the parsed message.
      expected_value: for a repeated field, a container every entry must be
        in (and whose length must equal the number of entries); for a group,
        a (field_number, wire_type, data) triple describing the first nested
        unknown field; otherwise the exact expected data.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    count = 0
    for unknown_field in unknown_fields:
      if unknown_field.field_number != field_descriptor.number:
        continue
      count += 1
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # Check group: its data is a nested set of unknown fields.
        self.assertEqual(expected_value[0],
                         unknown_field.data[0].field_number)
        self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
        self.assertEqual(expected_value[2], unknown_field.data[0].data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIn(type(unknown_field.data), (str, bytes))
      if is_repeated:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    # Same presence check as UnknownEnumValuesTest.CheckUnknownField.
    if is_repeated:
      self.assertEqual(len(expected_value), count)
    else:
      self.assertEqual(1, count)

  def testCheckUnknownFieldValue(self):
    unknown_fields = self.empty_message.UnknownFields()
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_fields,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_fields,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_fields,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_fields,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_fields,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_fields,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group. The triple is (field_number, wire_type, data) of the first
    # unknown field nested inside the group.
    self.CheckUnknownField('optionalgroup',
                           unknown_fields,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    self.assertEqual(97, len(unknown_fields))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_fields = destination.UnknownFields()
    self.assertEqual(0, len(unknown_fields))
    destination.ParseFromString(message.SerializeToString())
    # ParseFromString clears the message thus unknown fields is invalid.
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    unknown_fields = destination.UnknownFields()
    self.assertEqual(2, len(unknown_fields))
    destination.MergeFrom(source)
    self.assertEqual(4, len(unknown_fields))
    # Check that the fields were correctly merged, even stored in the unknown
    # fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_fields = self.empty_message.UnknownFields()
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = destination.UnknownFields()[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      len(sub_unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    with self.assertRaises(ValueError) as context:
      # pylint: disable=pointless-statement
      sub_unknown_fields[0]
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(unknown_fields[0].data, 456)
    # ClearField detaches the sub-message; the previously obtained unknown
    # fields stay readable.
    nested_message.ClearField('payload')
    self.assertEqual(unknown_fields[0].data, 456)
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(0, len(unknown_fields))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = destination.UnknownFields()[0]
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      unknown_field.data  # pylint: disable=pointless-statement
    self.assertIn('The parent message might be cleared.',
                  str(context.exception))

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(message.UnknownFields()), 97)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Enum values unknown to the parsing schema are kept as unknown fields."""

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  # CheckUnknownField() uses only the public UnknownFields() API, so it runs
  # under every implementation. It verifies that each value of the named field
  # was preserved as an unknown field of missing_message, and that the number
  # of preserved entries matches expected_value.
  def CheckUnknownField(self, name, expected_value):
    """Asserts field `name` was stored in missing_message's unknown fields.

    Args:
      name: name of a field of TestEnumValues.
      expected_value: a container of expected values for a repeated field,
        otherwise the single expected value.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_fields = self.missing_message.UnknownFields()
    count = 0
    for field in unknown_fields:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # ClearField does not remove the value kept in the unknown fields, so
    # serialization is unchanged.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_fields = self.missing_message.UnknownFields()
    self.assertEqual(len(unknown_fields), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()