# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Test for preservation of unknown fields in the pure Python implementation."""

__author__ = 'bohdank@google.com (Bohdan Koval)'

import sys
import unittest

from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor
from google.protobuf import unknown_fields
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
try:
  import tracemalloc  # pylint: disable=g-import-not-at-top
except ImportError:
  # Requires python 3.4+
  pass


@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Serialization, equality and discard behavior of unknown fields."""

  def setUp(self):
    # Serialize a fully populated TestAllTypes, then parse those bytes into
    # TestEmptyMessage. testListFields below confirms that none of the parsed
    # data lands in known fields, so it is all kept as unknown fields.
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, message.SerializeToString())

  def testByteSize(self):
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_field_set = unknown_fields.UnknownFieldSet(proto)
    self.assertEqual(len(unknown_field_set), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_field_set[0].field_number, item.type_id)
    self.assertEqual(unknown_field_set[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_field_set[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    # Messages whose unknown fields differ must compare unequal.
    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    # Parsing a TestAllTypes payload into NestedMessage leaves the string as
    # an unknown field of the nested message.
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    # DiscardUnknownFields on the parent must recurse into sub-messages.
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # The same recursion must reach message values stored inside maps.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())


@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests for reading unknown fields through unknown_fields.UnknownFieldSet."""

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def CheckUnknownField(self, name, unknown_field_set, expected_value):
    """Asserts that the unknown fields numbered like TestAllTypes.`name` match.

    Args:
      name: Name of a TestAllTypes field. Its number and type are looked up in
        self.descriptor.
      unknown_field_set: The UnknownFieldSet to search.
      expected_value: For a group field, a (field_number, wire_type, data)
        tuple describing the group's first sub-field. For a repeated field, a
        container that each matching entry's data must belong to. Otherwise,
        the exact expected data.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    found = False
    for unknown_field in unknown_field_set:
      if unknown_field.field_number != field_descriptor.number:
        continue
      found = True
      self.assertEqual(expected_type, unknown_field.wire_type)
      if expected_type == wire_format.WIRETYPE_START_GROUP:
        # Check group: its data is a nested collection of unknown fields.
        self.assertEqual(expected_value[0],
                         unknown_field.data[0].field_number)
        self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
        self.assertEqual(expected_value[2], unknown_field.data[0].data)
        continue
      if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
        self.assertIn(type(unknown_field.data), (str, bytes))
      if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
        self.assertIn(unknown_field.data, expected_value)
      else:
        self.assertEqual(expected_value, unknown_field.data)
    # Without this check, a field missing from the set would pass vacuously.
    self.assertTrue(found, 'No unknown field found for %s' % name)

  def testCheckUnknownFieldValue(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_field_set,
                           self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_field_set,
                           self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_field_set,
                           self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_field_set,
                           self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_field_set,
                           self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_field_set,
                           self.all_fields.optional_string.encode('utf-8'))

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_field_set,
                           (17, 0, 117))

    self.assertEqual(99, len(unknown_field_set))

  def testCopyFrom(self):
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(0, len(unknown_field_set))
    destination.ParseFromString(message.SerializeToString())
    # An existing UnknownFieldSet is a snapshot: it does not see later parses.
    self.assertEqual(0, len(unknown_field_set))
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(2, len(unknown_field_set))
    destination.MergeFrom(source)
    self.assertEqual(2, len(unknown_field_set))
    # Check that the fields were correctly merged, even when stored in the
    # unknown fields set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    # The previously taken snapshot keeps its entries after Clear().
    self.assertEqual(len(unknown_field_set), 99)

  @unittest.skipIf(sys.version_info < (3, 4),
                   'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    # Call to UnknownFields must not leak memory
    nb_leaks = 1234

    def leaking_function():
      for _ in range(nb_leaks):
        unknown_fields.UnknownFieldSet(self.empty_message)

    tracemalloc.start()
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      # Always stop tracing so a failure here does not leave tracemalloc
      # enabled (and slowing down) for the rest of the test run.
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = unknown_fields.UnknownFieldSet(destination)[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    # The nested set must remain readable after its owner is cleared.
    destination.Clear()
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(unknown_field_set[0].data, 456)
    # Clearing an ancestor must not invalidate an already-taken snapshot...
    nested_message.ClearField('payload')
    self.assertEqual(unknown_field_set[0].data, 456)
    # ...while a fresh snapshot reflects the cleared state.
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(0, len(unknown_field_set))

  def testUnknownField(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = unknown_fields.UnknownFieldSet(destination)[0]
    # A single unknown field must stay readable after its owner is cleared.
    destination.Clear()
    self.assertEqual(unknown_field.data, 123)

  def testUnknownExtensions(self):
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(unknown_fields.UnknownFieldSet(message)), 99)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Enum values missing from the parsing schema are kept as unknown fields."""

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
    ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  # CheckUnknownField() inspects the unknown fields directly through the
  # public unknown_fields.UnknownFieldSet API. It is kept for historical
  # reasons; it is not strictly necessary, because the serialized string is
  # also checked (see testRoundTrip).

  def CheckUnknownField(self, name, expected_value):
    """Asserts the unknown entries for TestEnumValues.`name` match exactly.

    Args:
      name: Name of a TestEnumValues field. Its number and label are looked up
        in self.descriptor.
      expected_value: For a repeated field, the sequence of expected values;
        the number of matching entries must equal its length. Otherwise, the
        single expected value; exactly one entry must match.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertIsInstance(unknown_field_set, unknown_fields.UnknownFieldSet)
    count = 0
    for field in unknown_field_set:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    # The unrecognized value is not reported as set; the accessor returns 2
    # instead of the parsed value (presumably the field's default in
    # TestMissingEnumValues -- confirm against the .proto).
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # Clear does not do anything.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertEqual(len(unknown_field_set), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()