# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Tests for google.protobuf.internal.field_mask (FieldMask helpers)."""

import unittest

from google.protobuf import field_mask_pb2
from google.protobuf.internal import field_mask
from google.protobuf.internal import test_util
from google.protobuf import descriptor
from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_pb2


class FieldMaskTest(unittest.TestCase):
  """Exercises the helper methods exposed on field_mask_pb2.FieldMask."""

  def testStringFormat(self):
    """Round-trips FieldMask paths through the JSON string form."""
    mask = field_mask_pb2.FieldMask()
    self.assertEqual('', mask.ToJsonString())
    mask.paths.append('foo')
    self.assertEqual('foo', mask.ToJsonString())
    mask.paths.append('bar')
    self.assertEqual('foo,bar', mask.ToJsonString())

    mask.FromJsonString('')
    self.assertEqual('', mask.ToJsonString())
    mask.FromJsonString('foo')
    self.assertEqual(['foo'], mask.paths)
    mask.FromJsonString('foo,bar')
    self.assertEqual(['foo', 'bar'], mask.paths)

    # Test camel case
    mask.Clear()
    mask.paths.append('foo_bar')
    self.assertEqual('fooBar', mask.ToJsonString())
    mask.paths.append('bar_quz')
    self.assertEqual('fooBar,barQuz', mask.ToJsonString())

    mask.FromJsonString('')
    self.assertEqual('', mask.ToJsonString())
    self.assertEqual([], mask.paths)
    mask.FromJsonString('fooBar')
    self.assertEqual(['foo_bar'], mask.paths)
    mask.FromJsonString('fooBar,barQuz')
    self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)

  def testDescriptorToFieldMask(self):
    """AllFieldsFromDescriptor yields one valid path per descriptor field."""
    mask = field_mask_pb2.FieldMask()
    msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    mask.AllFieldsFromDescriptor(msg_descriptor)
    # Derive the expected count from the descriptor rather than hard-coding
    # it, so the test does not break whenever TestAllTypes changes shape.
    self.assertEqual(len(msg_descriptor.fields), len(mask.paths))
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    for field in msg_descriptor.fields:
      self.assertIn(field.name, mask.paths)

  def testIsValidForDescriptor(self):
    """Checks which paths IsValidForDescriptor accepts and rejects."""
    msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    # Empty mask
    mask = field_mask_pb2.FieldMask()
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    # All fields from descriptor
    mask.AllFieldsFromDescriptor(msg_descriptor)
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    # Child under optional message
    mask.paths.append('optional_nested_message.bb')
    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
    # Repeated field is only allowed in the last position of path
    mask.paths.append('repeated_nested_message.bb')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
    # Invalid top level field
    mask = field_mask_pb2.FieldMask()
    mask.paths.append('xxx')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
    # Invalid field in root
    mask = field_mask_pb2.FieldMask()
    mask.paths.append('xxx.zzz')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
    # Invalid field in internal node
    mask = field_mask_pb2.FieldMask()
    mask.paths.append('optional_nested_message.xxx.zzz')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
    # Invalid field in leaf
    mask = field_mask_pb2.FieldMask()
    mask.paths.append('optional_nested_message.xxx')
    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))

  def testCanonicalFrom(self):
    """CanonicalFormFromMask sorts, de-duplicates and drops covered paths."""
    mask = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    # Paths will be sorted.
    mask.FromJsonString('baz.quz,bar,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
    # Duplicated paths will be removed.
    mask.FromJsonString('foo,bar,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,foo', out_mask.ToJsonString())
    # Sub-paths of other paths will be removed.
    mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())

    # Test more deeply nested cases.
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
                     out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
                     out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo.bar', out_mask.ToJsonString())
    mask.FromJsonString(
        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
    out_mask.CanonicalFormFromMask(mask)
    self.assertEqual('foo', out_mask.ToJsonString())

  def testUnion(self):
    """Union merges two masks into canonical form; rejects non-masks."""
    mask1 = field_mask_pb2.FieldMask()
    mask2 = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    mask1.FromJsonString('foo,baz')
    mask2.FromJsonString('bar,quz')
    out_mask.Union(mask1, mask2)
    self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
    # Overlap with duplicated paths.
    mask1.FromJsonString('foo,baz.bb')
    mask2.FromJsonString('baz.bb,quz')
    out_mask.Union(mask1, mask2)
    self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
    # Overlap with paths covering some other paths.
    mask1.FromJsonString('foo.bar.baz,quz')
    mask2.FromJsonString('foo.bar,bar')
    out_mask.Union(mask1, mask2)
    self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
    # A non-FieldMask argument is rejected.
    src = unittest_pb2.TestAllTypes()
    with self.assertRaises(ValueError):
      out_mask.Union(src, mask2)

  def testIntersect(self):
    """Intersect keeps only paths covered by both masks."""
    mask1 = field_mask_pb2.FieldMask()
    mask2 = field_mask_pb2.FieldMask()
    out_mask = field_mask_pb2.FieldMask()
    # Test cases without overlapping.
    mask1.FromJsonString('foo,baz')
    mask2.FromJsonString('bar,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('', out_mask.ToJsonString())
    self.assertEqual(0, len(out_mask.paths))
    self.assertEqual([], out_mask.paths)
    # Overlap with duplicated paths.
    mask1.FromJsonString('foo,baz.bb')
    mask2.FromJsonString('baz.bb,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('baz.bb', out_mask.ToJsonString())
    # Overlap with paths covering some other paths.
    mask1.FromJsonString('foo.bar.baz,quz')
    mask2.FromJsonString('foo.bar,bar')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
    mask1.FromJsonString('foo.bar,bar')
    mask2.FromJsonString('foo.bar.baz,quz')
    out_mask.Intersect(mask1, mask2)
    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
    # Intersect '' with ''
    mask1.Clear()
    mask2.Clear()
    mask1.paths.append('')
    mask2.paths.append('')
    self.assertEqual([''], mask1.paths)
    self.assertEqual('', mask1.ToJsonString())
    out_mask.Intersect(mask1, mask2)
    self.assertEqual([], out_mask.paths)

  def testMergeMessageWithoutMapFields(self):
    """MergeMessage copies only masked fields, honoring the replace flags."""
    # Test merge one field.
    src = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(src)
    for field in src.DESCRIPTOR.fields:
      if field.containing_oneof:
        continue
      field_name = field.name
      dst = unittest_pb2.TestAllTypes()
      # Only set one path to mask.
      mask = field_mask_pb2.FieldMask()
      mask.paths.append(field_name)
      mask.MergeMessage(src, dst)
      # The expected result message.
      msg = unittest_pb2.TestAllTypes()
      if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
        repeated_src = getattr(src, field_name)
        repeated_msg = getattr(msg, field_name)
        if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
          for item in repeated_src:
            repeated_msg.add().CopyFrom(item)
        else:
          repeated_msg.extend(repeated_src)
      elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
        getattr(msg, field_name).CopyFrom(getattr(src, field_name))
      else:
        setattr(msg, field_name, getattr(src, field_name))
      # Only field specified in mask is merged.
      self.assertEqual(msg, dst)

    # Test merge nested fields.
    nested_src = unittest_pb2.NestedTestAllTypes()
    nested_dst = unittest_pb2.NestedTestAllTypes()
    nested_src.child.payload.optional_int32 = 1234
    nested_src.child.child.payload.optional_int32 = 5678
    mask = field_mask_pb2.FieldMask()
    mask.FromJsonString('child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(0, nested_dst.child.child.payload.optional_int32)

    mask.FromJsonString('child.child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    nested_dst.Clear()
    mask.FromJsonString('child.child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(0, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    nested_dst.Clear()
    mask.FromJsonString('child')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)

    # Test MergeOptions.
    nested_dst.Clear()
    nested_dst.child.payload.optional_int64 = 4321
    # Message fields will be merged by default.
    mask.FromJsonString('child.payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(4321, nested_dst.child.payload.optional_int64)
    # Change the behavior to replace message fields.
    mask.FromJsonString('child.payload')
    # Positional flags: replace message fields, keep repeated-field append.
    mask.MergeMessage(nested_src, nested_dst, True, False)
    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
    self.assertEqual(0, nested_dst.child.payload.optional_int64)

    # By default, fields missing in source are not cleared in destination.
    nested_dst.payload.optional_int32 = 1234
    self.assertTrue(nested_dst.HasField('payload'))
    mask.FromJsonString('payload')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertTrue(nested_dst.HasField('payload'))
    # But they are cleared when replacing message fields.
    nested_dst.Clear()
    nested_dst.payload.optional_int32 = 1234
    mask.FromJsonString('payload')
    mask.MergeMessage(nested_src, nested_dst, True, False)
    self.assertFalse(nested_dst.HasField('payload'))

    nested_src.payload.repeated_int32.append(1234)
    nested_dst.payload.repeated_int32.append(5678)
    # Repeated fields will be appended by default.
    mask.FromJsonString('payload.repeatedInt32')
    mask.MergeMessage(nested_src, nested_dst)
    self.assertEqual(2, len(nested_dst.payload.repeated_int32))
    self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
    self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
    # Change the behavior to replace repeated fields.
    mask.FromJsonString('payload.repeatedInt32')
    # Positional flags: keep message-field merge, replace repeated fields.
    mask.MergeMessage(nested_src, nested_dst, False, True)
    self.assertEqual(1, len(nested_dst.payload.repeated_int32))
    self.assertEqual(1234, nested_dst.payload.repeated_int32[0])

    # Test Merge oneof field.
    new_msg = unittest_pb2.TestOneof2()
    dst = unittest_pb2.TestOneof2()
    dst.foo_message.moo_int = 1
    mask = field_mask_pb2.FieldMask()
    mask.FromJsonString('fooMessage,fooLazyMessage.mooInt')
    mask.MergeMessage(new_msg, dst)
    self.assertTrue(dst.HasField('foo_message'))
    self.assertFalse(dst.HasField('foo_lazy_message'))

  def testMergeMessageWithMapField(self):
    """MergeMessage on a map field overwrites matching keys, keeps others."""
    empty_map = map_unittest_pb2.TestRecursiveMapMessage()
    src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
    src_level_2.a['src level 2'].CopyFrom(empty_map)
    src = map_unittest_pb2.TestRecursiveMapMessage()
    src.a['common key'].CopyFrom(src_level_2)
    src.a['src level 1'].CopyFrom(src_level_2)

    dst_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
    dst_level_2.a['dst level 2'].CopyFrom(empty_map)
    dst = map_unittest_pb2.TestRecursiveMapMessage()
    dst.a['common key'].CopyFrom(dst_level_2)
    dst.a['dst level 1'].CopyFrom(empty_map)

    mask = field_mask_pb2.FieldMask()
    mask.FromJsonString('a')
    mask.MergeMessage(src, dst)

    # Entries present in src overwrite the dst entry with the same key (the
    # values are replaced, not merged), while dst-only entries are kept.
    self.assertEqual(dst.a['common key'], src_level_2)
    self.assertEqual(dst.a['src level 1'], src_level_2)
    self.assertEqual(dst.a['dst level 1'], empty_map)

  def testMergeErrors(self):
    """A sub-path under a non-message field makes MergeMessage raise."""
    src = unittest_pb2.TestAllTypes()
    dst = unittest_pb2.TestAllTypes()
    mask = field_mask_pb2.FieldMask()
    test_util.SetAllFields(src)
    mask.FromJsonString('optionalInt32.field')
    with self.assertRaises(ValueError) as e:
      mask.MergeMessage(src, dst)
    # Checked outside the `with` block: any statement after the raising call
    # inside the block would never execute.
    self.assertEqual('Error: Field optional_int32 in message '
                     'protobuf_unittest.TestAllTypes is not a singular '
                     'message field and cannot have sub-fields.',
                     str(e.exception))

  def testSnakeCaseToCamelCase(self):
    """Checks snake_case -> camelCase conversion and its error messages."""
    self.assertEqual('fooBar',
                     field_mask._SnakeCaseToCamelCase('foo_bar'))
    self.assertEqual('FooBar',
                     field_mask._SnakeCaseToCamelCase('_foo_bar'))
    self.assertEqual('foo3Bar',
                     field_mask._SnakeCaseToCamelCase('foo3_bar'))

    # No uppercase letter is allowed.
    self.assertRaisesRegex(
        ValueError,
        'Fail to print FieldMask to Json string: Path name Foo must '
        'not contain uppercase letters.',
        field_mask._SnakeCaseToCamelCase, 'Foo')
    # Any character after a "_" must be a lowercase letter.
    #   1. "_" cannot be followed by another "_".
    #   2. "_" cannot be followed by a digit.
    #   3. "_" cannot appear as the last character.
    self.assertRaisesRegex(
        ValueError,
        'Fail to print FieldMask to Json string: The character after a '
        '"_" must be a lowercase letter in path name foo__bar.',
        field_mask._SnakeCaseToCamelCase, 'foo__bar')
    self.assertRaisesRegex(
        ValueError,
        'Fail to print FieldMask to Json string: The character after a '
        '"_" must be a lowercase letter in path name foo_3bar.',
        field_mask._SnakeCaseToCamelCase, 'foo_3bar')
    self.assertRaisesRegex(
        ValueError,
        'Fail to print FieldMask to Json string: Trailing "_" in path '
        'name foo_bar_.', field_mask._SnakeCaseToCamelCase, 'foo_bar_')

  def testCamelCaseToSnakeCase(self):
    """Checks camelCase -> snake_case conversion and its error message."""
    self.assertEqual('foo_bar',
                     field_mask._CamelCaseToSnakeCase('fooBar'))
    self.assertEqual('_foo_bar',
                     field_mask._CamelCaseToSnakeCase('FooBar'))
    self.assertEqual('foo3_bar',
                     field_mask._CamelCaseToSnakeCase('foo3Bar'))
    self.assertRaisesRegex(
        ValueError,
        'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
        field_mask._CamelCaseToSnakeCase, 'foo_bar')


if __name__ == '__main__':
  unittest.main()