#! /usr/bin/env python
# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for preservation of unknown fields in the pure Python implementation."""

__author__ = 'bohdank@google.com (Bohdan Koval)'

try:
  import unittest2 as unittest  #PY26
except ImportError:
  import unittest
import sys
try:
  import tracemalloc
except ImportError:
  # Requires python 3.4+
  pass
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor


@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Checks serialization, equality and discarding of unknown fields.

  Every test starts from a TestEmptyMessage that has parsed the serialized
  form of a fully populated TestAllTypes message (see setUp).
  """

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, message.SerializeToString())

  def testByteSize(self):
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_fields = proto.UnknownFields()
    self.assertEqual(len(unknown_fields), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_fields[0].field_number, item.type_id)
    self.assertEqual(unknown_fields[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_fields[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    # Dropping one field from the source must make the two messages differ,
    # even though that field is only held as an unknown field.
    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # Same check for a message stored as a map value.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())


@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Checks the UnknownFields() accessor and the lifetime of its results.

  Uses the same fixture as UnknownFieldsTest: a TestEmptyMessage parsed from
  the serialized form of a fully populated TestAllTypes message.
  """

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++
  # implementation because some protected members are called.
  # The test is added for historical reasons. It is not necessary as
  # serialized string is checked.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    """Decodes the raw unknown-field bytes for `name` and compares the value.

    Does nothing when the active API implementation is 'cpp'.

    Args:
      name: Name of a TestAllTypes field.
      expected_value: Value the decoded unknown field must equal.
    """
    if api_implementation.Type() == 'cpp':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    # A missing entry raises KeyError here, which also fails the test.
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_fields, expected_value):
    """Checks every entry of `unknown_fields` that belongs to field `name`.

    Args:
      name: Name of a TestAllTypes field.
      unknown_fields: Iterable of unknown-field entries exposing
        field_number, wire_type and data.
      expected_value: For a group field, a (field_number, wire_type, data)
        tuple describing the group's first sub-field; for a repeated field, a
        container that must include each entry's data; otherwise the exact
        expected data.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    for unknown_field in unknown_fields:
      if unknown_field.field_number == field_descriptor.number:
        self.assertEqual(expected_type, unknown_field.wire_type)
        if expected_type == wire_format.WIRETYPE_START_GROUP:
          # Check group
          self.assertEqual(expected_value[0],
                           unknown_field.data[0].field_number)
          self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
          self.assertEqual(expected_value[2], unknown_field.data[0].data)
          continue
        if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
          self.assertIn(type(unknown_field.data), (str, bytes))
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(unknown_field.data, expected_value)
        else:
          self.assertEqual(expected_value, unknown_field.data)

  def testCheckUnknownFieldValue(self):
    unknown_fields = self.empty_message.UnknownFields()
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_fields,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_fields,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_fields,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_fields,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_fields,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_fields,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_fields,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    self.assertEqual(97, len(unknown_fields))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_fields = destination.UnknownFields()
    self.assertEqual(0, len(unknown_fields))
    destination.ParseFromString(message.SerializeToString())
    # ParseFromString clears the message thus unknown fields is invalid.
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    unknown_fields = destination.UnknownFields()
    self.assertEqual(2, len(unknown_fields))
    destination.MergeFrom(source)
    self.assertEqual(4, len(unknown_fields))
    # Check that the fields were correctly merged, even stored in the unknown
    # fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_fields = self.empty_message.UnknownFields()
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    with self.assertRaises(ValueError) as context:
      len(unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))

  @unittest.skipIf((sys.version_info.major, sys.version_info.minor) < (3, 4),
                   'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    # Call to UnknownFields must not leak memory
    nb_leaks = 1234

    def leaking_function():
      for _ in range(nb_leaks):
        self.empty_message.UnknownFields()

    tracemalloc.start()
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      # Stop tracing even if a snapshot call raises, so tracing is not left
      # enabled for the remaining tests.
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = destination.UnknownFields()[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      len(sub_unknown_fields)
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    with self.assertRaises(ValueError) as context:
      # pylint: disable=pointless-statement
      sub_unknown_fields[0]
    self.assertIn('UnknownFields does not exist.',
                  str(context.exception))
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(unknown_fields[0].data, 456)
    nested_message.ClearField('payload')
    # The previously obtained accessor is expected to stay readable after the
    # parent field is cleared; a freshly obtained one is empty.
    self.assertEqual(unknown_fields[0].data, 456)
    unknown_fields = (
        nested_message.payload.optional_nested_message.UnknownFields())
    self.assertEqual(0, len(unknown_fields))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = destination.UnknownFields()[0]
    destination.Clear()
    with self.assertRaises(ValueError) as context:
      unknown_field.data  # pylint: disable=pointless-statement
    self.assertIn('The parent message might be cleared.',
                  str(context.exception))

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(message.UnknownFields()), 97)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Checks handling of enum values the parsing message type does not define.

  setUp serializes a TestEnumValues message and parses the bytes into a
  TestMissingEnumValues message.
  """

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  # CheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++
  # implementation because some protected members are called.
  # The test is added for historical reasons. It is not necessary as
  # serialized string is checked.
  # NOTE(review): the helper below only calls UnknownFields(), so the
  # statement about protected members may be outdated -- confirm.

  def CheckUnknownField(self, name, expected_value):
    """Checks the unknown-field entries of missing_message for field `name`.

    Args:
      name: Name of a TestEnumValues field.
      expected_value: For a repeated field, a container that must include each
        entry's data and whose length must equal the number of matching
        entries; otherwise the exact data of the single expected entry.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_fields = self.missing_message.UnknownFields()
    count = 0
    for field in unknown_fields:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # Clear does not do anything.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_fields = self.missing_message.UnknownFields()
    self.assertEqual(len(unknown_fields), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()