1# -*- coding: utf-8 -*- 2# Protocol Buffers - Google's data interchange format 3# Copyright 2008 Google Inc. All rights reserved. 4# https://developers.google.com/protocol-buffers/ 5# 6# Redistribution and use in source and binary forms, with or without 7# modification, are permitted provided that the following conditions are 8# met: 9# 10# * Redistributions of source code must retain the above copyright 11# notice, this list of conditions and the following disclaimer. 12# * Redistributions in binary form must reproduce the above 13# copyright notice, this list of conditions and the following disclaimer 14# in the documentation and/or other materials provided with the 15# distribution. 16# * Neither the name of Google Inc. nor the names of its 17# contributors may be used to endorse or promote products derived from 18# this software without specific prior written permission. 19# 20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""Test for preservation of unknown fields in the pure Python implementation."""

__author__ = 'bohdank@google.com (Bohdan Koval)'

import sys
import unittest

from google.protobuf import descriptor
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format

try:
  import tracemalloc  # pylint: disable=g-import-not-at-top
except ImportError:
  # Requires python 3.4+. The only test that uses tracemalloc is skipped on
  # older interpreters, so the missing name is never touched there.
  pass


@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Round-trip behaviour of unknown fields (serialize, equality, discard)."""

  def setUp(self):
    # Serialize a fully populated TestAllTypes and parse the bytes into a
    # TestEmptyMessage; the tests below inspect the resulting empty_message.
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, message.SerializeToString())

  def testByteSize(self):
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_fields = proto.UnknownFields()
    self.assertEqual(len(unknown_fields), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_fields[0].field_number, item.type_id)
    self.assertEqual(unknown_fields[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_fields[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    # Dropping one field from the source must make the parsed messages differ.
    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # Same check for a message reached through a map value.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())


@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests for the UnknownFields() accessor and the objects it returns."""

  def setUp(self):
    # Same fixture as UnknownFieldsTest: every TestAllTypes field ends up as
    # an unknown field of empty_message.
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++
  # implementation because some protected members are called.
  # The test is added for historical reasons. It is not necessary as
  # serialized string is checked.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    """Decodes the raw unknown-field entry for `name` and compares it.

    Does nothing when the active API implementation is 'cpp'.

    Args:
      name: Name of a TestAllTypes field.
      expected_value: Value the decoded unknown field must equal.
    """
    if api_implementation.Type() == 'cpp':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_fields, expected_value):
    """Checks every entry of `unknown_fields` that belongs to field `name`.

    Args:
      name: Name of a TestAllTypes field; its number selects the entries.
      unknown_fields: Iterable of unknown-field entries exposing
        field_number, wire_type and data.
      expected_value: For group fields, a (field_number, wire_type, data)
        triple describing the first entry inside the group. For repeated
        fields, a container that must include each entry's data. Otherwise
        the exact expected data.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    matches = 0
    for unknown_field in unknown_fields:
      if unknown_field.field_number != field_descriptor.number:
        continue
      matches += 1
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # Check group: its data is itself a sequence of unknown fields.
        self.assertEqual(expected_value[0],
                         unknown_field.data[0].field_number)
        self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
        self.assertEqual(expected_value[2], unknown_field.data[0].data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIn(type(unknown_field.data), (str, bytes))
      if is_repeated:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    # Without this the helper would pass vacuously when nothing matched.
    self.assertGreater(
        matches, 0, 'No unknown field found for field %r' % name)

  def testCheckUnknownFieldValue(self):
    unknown_fields = self.empty_message.UnknownFields()
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_fields,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_fields,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_fields,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_fields,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_fields,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_fields,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group. The triple is (field_number, wire_type, data) of the first
    # entry inside the group, as consumed by CheckUnknownField.
    self.CheckUnknownField('optionalgroup',
                           unknown_fields,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    self.assertEqual(97, len(unknown_fields))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_fields = destination.UnknownFields()
    self.assertEqual(0, len(unknown_fields))
    destination.ParseFromString(message.SerializeToString())
    # ParseFromString clears the message thus unknown fields is invalid.
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    unknown_fields = destination.UnknownFields()
    self.assertEqual(2, len(unknown_fields))
    destination.MergeFrom(source)
    self.assertEqual(4, len(unknown_fields))
    # Check that the fields were correctly merged, even stored in the unknown
    # fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_fields = self.empty_message.UnknownFields()
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))

  @unittest.skipIf(sys.version_info < (3, 4),
                   'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    # Call to UnknownFields must not leak memory
    nb_leaks = 1234

    def leaking_function():
      for _ in range(nb_leaks):
        self.empty_message.UnknownFields()

    tracemalloc.start()
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      # Always stop tracing so a failure here cannot leave tracemalloc
      # enabled for the remaining tests.
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = destination.UnknownFields()[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    # Clearing the parent must invalidate the nested unknown-field set.
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      len(sub_unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    with self.assertRaises(ValueError) as context:
      # pylint: disable=pointless-statement
      sub_unknown_fields[0]
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(unknown_fields[0].data, 456)
    # The previously fetched unknown fields stay readable after ClearField on
    # the enclosing field; a fresh fetch then sees an empty set.
    nested_message.ClearField('payload')
    self.assertEqual(unknown_fields[0].data, 456)
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(0, len(unknown_fields))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = destination.UnknownFields()[0]
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      unknown_field.data  # pylint: disable=pointless-statement
    self.assertIn('The parent message might be cleared.',
                  str(context.exception))

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(message.UnknownFields()), 97)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Tests for enum values parsed by a message that lacks those values."""

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  # NOTE(review): the comment that used to sit here described this helper as a
  # pure-Python-only check relying on protected members. The code below only
  # uses the public UnknownFields() accessor, so that text looked stale.
  def CheckUnknownField(self, name, expected_value):
    """Checks the unknown fields of missing_message that belong to `name`.

    Args:
      name: Name of a TestEnumValues field; its number selects the entries.
      expected_value: For repeated fields, a container that must include each
        entry's data and whose length must equal the number of entries.
        Otherwise the exact data of the single expected entry.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_fields = self.missing_message.UnknownFields()
    is_repeated = (
        field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED)
    count = 0
    for field in unknown_fields:
      if field.field_number != field_descriptor.number:
        continue
      count += 1
      if is_repeated:
        self.assertIn(field.data, expected_value)
      else:
        self.assertEqual(expected_value, field.data)
    if is_repeated:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # Clear does not do anything.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_fields = self.missing_message.UnknownFields()
    self.assertEqual(len(unknown_fields), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()