• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Test for google3.net.proto2.python.internal.well_known_types."""
9
10import unittest
11
12from google.protobuf import field_mask_pb2
13from google.protobuf.internal import field_mask
14from google.protobuf.internal import test_util
15from google.protobuf import descriptor
16from google.protobuf import map_unittest_pb2
17from google.protobuf import unittest_pb2
18
19
20class FieldMaskTest(unittest.TestCase):
21
22  def testStringFormat(self):
23    mask = field_mask_pb2.FieldMask()
24    self.assertEqual('', mask.ToJsonString())
25    mask.paths.append('foo')
26    self.assertEqual('foo', mask.ToJsonString())
27    mask.paths.append('bar')
28    self.assertEqual('foo,bar', mask.ToJsonString())
29
30    mask.FromJsonString('')
31    self.assertEqual('', mask.ToJsonString())
32    mask.FromJsonString('foo')
33    self.assertEqual(['foo'], mask.paths)
34    mask.FromJsonString('foo,bar')
35    self.assertEqual(['foo', 'bar'], mask.paths)
36
37    # Test camel case
38    mask.Clear()
39    mask.paths.append('foo_bar')
40    self.assertEqual('fooBar', mask.ToJsonString())
41    mask.paths.append('bar_quz')
42    self.assertEqual('fooBar,barQuz', mask.ToJsonString())
43
44    mask.FromJsonString('')
45    self.assertEqual('', mask.ToJsonString())
46    self.assertEqual([], mask.paths)
47    mask.FromJsonString('fooBar')
48    self.assertEqual(['foo_bar'], mask.paths)
49    mask.FromJsonString('fooBar,barQuz')
50    self.assertEqual(['foo_bar', 'bar_quz'], mask.paths)
51
52  def testDescriptorToFieldMask(self):
53    mask = field_mask_pb2.FieldMask()
54    msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
55    mask.AllFieldsFromDescriptor(msg_descriptor)
56    self.assertEqual(80, len(mask.paths))
57    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
58    for field in msg_descriptor.fields:
59      self.assertTrue(field.name in mask.paths)
60
61  def testIsValidForDescriptor(self):
62    msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
63    # Empty mask
64    mask = field_mask_pb2.FieldMask()
65    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
66    # All fields from descriptor
67    mask.AllFieldsFromDescriptor(msg_descriptor)
68    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
69    # Child under optional message
70    mask.paths.append('optional_nested_message.bb')
71    self.assertTrue(mask.IsValidForDescriptor(msg_descriptor))
72    # Repeated field is only allowed in the last position of path
73    mask.paths.append('repeated_nested_message.bb')
74    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
75    # Invalid top level field
76    mask = field_mask_pb2.FieldMask()
77    mask.paths.append('xxx')
78    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
79    # Invalid field in root
80    mask = field_mask_pb2.FieldMask()
81    mask.paths.append('xxx.zzz')
82    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
83    # Invalid field in internal node
84    mask = field_mask_pb2.FieldMask()
85    mask.paths.append('optional_nested_message.xxx.zzz')
86    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
87    # Invalid field in leaf
88    mask = field_mask_pb2.FieldMask()
89    mask.paths.append('optional_nested_message.xxx')
90    self.assertFalse(mask.IsValidForDescriptor(msg_descriptor))
91
92  def testCanonicalFrom(self):
93    mask = field_mask_pb2.FieldMask()
94    out_mask = field_mask_pb2.FieldMask()
95    # Paths will be sorted.
96    mask.FromJsonString('baz.quz,bar,foo')
97    out_mask.CanonicalFormFromMask(mask)
98    self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString())
99    # Duplicated paths will be removed.
100    mask.FromJsonString('foo,bar,foo')
101    out_mask.CanonicalFormFromMask(mask)
102    self.assertEqual('bar,foo', out_mask.ToJsonString())
103    # Sub-paths of other paths will be removed.
104    mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar')
105    out_mask.CanonicalFormFromMask(mask)
106    self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
107
108    # Test more deeply nested cases.
109    mask.FromJsonString(
110        'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
111    out_mask.CanonicalFormFromMask(mask)
112    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
113                     out_mask.ToJsonString())
114    mask.FromJsonString(
115        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
116    out_mask.CanonicalFormFromMask(mask)
117    self.assertEqual('foo.bar.baz1,foo.bar.baz2',
118                     out_mask.ToJsonString())
119    mask.FromJsonString(
120        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
121    out_mask.CanonicalFormFromMask(mask)
122    self.assertEqual('foo.bar', out_mask.ToJsonString())
123    mask.FromJsonString(
124        'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
125    out_mask.CanonicalFormFromMask(mask)
126    self.assertEqual('foo', out_mask.ToJsonString())
127
128  def testUnion(self):
129    mask1 = field_mask_pb2.FieldMask()
130    mask2 = field_mask_pb2.FieldMask()
131    out_mask = field_mask_pb2.FieldMask()
132    mask1.FromJsonString('foo,baz')
133    mask2.FromJsonString('bar,quz')
134    out_mask.Union(mask1, mask2)
135    self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString())
136    # Overlap with duplicated paths.
137    mask1.FromJsonString('foo,baz.bb')
138    mask2.FromJsonString('baz.bb,quz')
139    out_mask.Union(mask1, mask2)
140    self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString())
141    # Overlap with paths covering some other paths.
142    mask1.FromJsonString('foo.bar.baz,quz')
143    mask2.FromJsonString('foo.bar,bar')
144    out_mask.Union(mask1, mask2)
145    self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString())
146    src = unittest_pb2.TestAllTypes()
147    with self.assertRaises(ValueError):
148      out_mask.Union(src, mask2)
149
150  def testIntersect(self):
151    mask1 = field_mask_pb2.FieldMask()
152    mask2 = field_mask_pb2.FieldMask()
153    out_mask = field_mask_pb2.FieldMask()
154    # Test cases without overlapping.
155    mask1.FromJsonString('foo,baz')
156    mask2.FromJsonString('bar,quz')
157    out_mask.Intersect(mask1, mask2)
158    self.assertEqual('', out_mask.ToJsonString())
159    self.assertEqual(len(out_mask.paths), 0)
160    self.assertEqual(out_mask.paths, [])
161    # Overlap with duplicated paths.
162    mask1.FromJsonString('foo,baz.bb')
163    mask2.FromJsonString('baz.bb,quz')
164    out_mask.Intersect(mask1, mask2)
165    self.assertEqual('baz.bb', out_mask.ToJsonString())
166    # Overlap with paths covering some other paths.
167    mask1.FromJsonString('foo.bar.baz,quz')
168    mask2.FromJsonString('foo.bar,bar')
169    out_mask.Intersect(mask1, mask2)
170    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
171    mask1.FromJsonString('foo.bar,bar')
172    mask2.FromJsonString('foo.bar.baz,quz')
173    out_mask.Intersect(mask1, mask2)
174    self.assertEqual('foo.bar.baz', out_mask.ToJsonString())
175    # Intersect '' with ''
176    mask1.Clear()
177    mask2.Clear()
178    mask1.paths.append('')
179    mask2.paths.append('')
180    self.assertEqual(mask1.paths, [''])
181    self.assertEqual('', mask1.ToJsonString())
182    out_mask.Intersect(mask1, mask2)
183    self.assertEqual(out_mask.paths, [])
184
185  def testMergeMessageWithoutMapFields(self):
186    # Test merge one field.
187    src = unittest_pb2.TestAllTypes()
188    test_util.SetAllFields(src)
189    for field in src.DESCRIPTOR.fields:
190      if field.containing_oneof:
191        continue
192      field_name = field.name
193      dst = unittest_pb2.TestAllTypes()
194      # Only set one path to mask.
195      mask = field_mask_pb2.FieldMask()
196      mask.paths.append(field_name)
197      mask.MergeMessage(src, dst)
198      # The expected result message.
199      msg = unittest_pb2.TestAllTypes()
200      if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
201        repeated_src = getattr(src, field_name)
202        repeated_msg = getattr(msg, field_name)
203        if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
204          for item in repeated_src:
205            repeated_msg.add().CopyFrom(item)
206        else:
207          repeated_msg.extend(repeated_src)
208      elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
209        getattr(msg, field_name).CopyFrom(getattr(src, field_name))
210      else:
211        setattr(msg, field_name, getattr(src, field_name))
212      # Only field specified in mask is merged.
213      self.assertEqual(msg, dst)
214
215    # Test merge nested fields.
216    nested_src = unittest_pb2.NestedTestAllTypes()
217    nested_dst = unittest_pb2.NestedTestAllTypes()
218    nested_src.child.payload.optional_int32 = 1234
219    nested_src.child.child.payload.optional_int32 = 5678
220    mask = field_mask_pb2.FieldMask()
221    mask.FromJsonString('child.payload')
222    mask.MergeMessage(nested_src, nested_dst)
223    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
224    self.assertEqual(0, nested_dst.child.child.payload.optional_int32)
225
226    mask.FromJsonString('child.child.payload')
227    mask.MergeMessage(nested_src, nested_dst)
228    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
229    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
230
231    nested_dst.Clear()
232    mask.FromJsonString('child.child.payload')
233    mask.MergeMessage(nested_src, nested_dst)
234    self.assertEqual(0, nested_dst.child.payload.optional_int32)
235    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
236
237    nested_dst.Clear()
238    mask.FromJsonString('child')
239    mask.MergeMessage(nested_src, nested_dst)
240    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
241    self.assertEqual(5678, nested_dst.child.child.payload.optional_int32)
242
243    # Test MergeOptions.
244    nested_dst.Clear()
245    nested_dst.child.payload.optional_int64 = 4321
246    # Message fields will be merged by default.
247    mask.FromJsonString('child.payload')
248    mask.MergeMessage(nested_src, nested_dst)
249    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
250    self.assertEqual(4321, nested_dst.child.payload.optional_int64)
251    # Change the behavior to replace message fields.
252    mask.FromJsonString('child.payload')
253    mask.MergeMessage(nested_src, nested_dst, True, False)
254    self.assertEqual(1234, nested_dst.child.payload.optional_int32)
255    self.assertEqual(0, nested_dst.child.payload.optional_int64)
256
257    # By default, fields missing in source are not cleared in destination.
258    nested_dst.payload.optional_int32 = 1234
259    self.assertTrue(nested_dst.HasField('payload'))
260    mask.FromJsonString('payload')
261    mask.MergeMessage(nested_src, nested_dst)
262    self.assertTrue(nested_dst.HasField('payload'))
263    # But they are cleared when replacing message fields.
264    nested_dst.Clear()
265    nested_dst.payload.optional_int32 = 1234
266    mask.FromJsonString('payload')
267    mask.MergeMessage(nested_src, nested_dst, True, False)
268    self.assertFalse(nested_dst.HasField('payload'))
269
270    nested_src.payload.repeated_int32.append(1234)
271    nested_dst.payload.repeated_int32.append(5678)
272    # Repeated fields will be appended by default.
273    mask.FromJsonString('payload.repeatedInt32')
274    mask.MergeMessage(nested_src, nested_dst)
275    self.assertEqual(2, len(nested_dst.payload.repeated_int32))
276    self.assertEqual(5678, nested_dst.payload.repeated_int32[0])
277    self.assertEqual(1234, nested_dst.payload.repeated_int32[1])
278    # Change the behavior to replace repeated fields.
279    mask.FromJsonString('payload.repeatedInt32')
280    mask.MergeMessage(nested_src, nested_dst, False, True)
281    self.assertEqual(1, len(nested_dst.payload.repeated_int32))
282    self.assertEqual(1234, nested_dst.payload.repeated_int32[0])
283
284    # Test Merge oneof field.
285    new_msg = unittest_pb2.TestOneof2()
286    dst = unittest_pb2.TestOneof2()
287    dst.foo_message.moo_int = 1
288    mask = field_mask_pb2.FieldMask()
289    mask.FromJsonString('fooMessage,fooLazyMessage.mooInt')
290    mask.MergeMessage(new_msg, dst)
291    self.assertTrue(dst.HasField('foo_message'))
292    self.assertFalse(dst.HasField('foo_lazy_message'))
293
294  def testMergeMessageWithMapField(self):
295    empty_map = map_unittest_pb2.TestRecursiveMapMessage()
296    src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
297    src_level_2.a['src level 2'].CopyFrom(empty_map)
298    src = map_unittest_pb2.TestRecursiveMapMessage()
299    src.a['common key'].CopyFrom(src_level_2)
300    src.a['src level 1'].CopyFrom(src_level_2)
301
302    dst_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
303    dst_level_2.a['dst level 2'].CopyFrom(empty_map)
304    dst = map_unittest_pb2.TestRecursiveMapMessage()
305    dst.a['common key'].CopyFrom(dst_level_2)
306    dst.a['dst level 1'].CopyFrom(empty_map)
307
308    mask = field_mask_pb2.FieldMask()
309    mask.FromJsonString('a')
310    mask.MergeMessage(src, dst)
311
312    # map from dst is replaced with map from src.
313    self.assertEqual(dst.a['common key'], src_level_2)
314    self.assertEqual(dst.a['src level 1'], src_level_2)
315    self.assertEqual(dst.a['dst level 1'], empty_map)
316
317  def testMergeErrors(self):
318    src = unittest_pb2.TestAllTypes()
319    dst = unittest_pb2.TestAllTypes()
320    mask = field_mask_pb2.FieldMask()
321    test_util.SetAllFields(src)
322    mask.FromJsonString('optionalInt32.field')
323    with self.assertRaises(ValueError) as e:
324      mask.MergeMessage(src, dst)
325    self.assertEqual('Error: Field optional_int32 in message '
326                     'protobuf_unittest.TestAllTypes is not a singular '
327                     'message field and cannot have sub-fields.',
328                     str(e.exception))
329
330  def testSnakeCaseToCamelCase(self):
331    self.assertEqual('fooBar',
332                     field_mask._SnakeCaseToCamelCase('foo_bar'))
333    self.assertEqual('FooBar',
334                     field_mask._SnakeCaseToCamelCase('_foo_bar'))
335    self.assertEqual('foo3Bar',
336                     field_mask._SnakeCaseToCamelCase('foo3_bar'))
337
338    # No uppercase letter is allowed.
339    self.assertRaisesRegex(
340        ValueError,
341        'Fail to print FieldMask to Json string: Path name Foo must '
342        'not contain uppercase letters.',
343        field_mask._SnakeCaseToCamelCase, 'Foo')
344    # Any character after a "_" must be a lowercase letter.
345    #   1. "_" cannot be followed by another "_".
346    #   2. "_" cannot be followed by a digit.
347    #   3. "_" cannot appear as the last character.
348    self.assertRaisesRegex(
349        ValueError,
350        'Fail to print FieldMask to Json string: The character after a '
351        '"_" must be a lowercase letter in path name foo__bar.',
352        field_mask._SnakeCaseToCamelCase, 'foo__bar')
353    self.assertRaisesRegex(
354        ValueError,
355        'Fail to print FieldMask to Json string: The character after a '
356        '"_" must be a lowercase letter in path name foo_3bar.',
357        field_mask._SnakeCaseToCamelCase, 'foo_3bar')
358    self.assertRaisesRegex(
359        ValueError,
360        'Fail to print FieldMask to Json string: Trailing "_" in path '
361        'name foo_bar_.', field_mask._SnakeCaseToCamelCase, 'foo_bar_')
362
363  def testCamelCaseToSnakeCase(self):
364    self.assertEqual('foo_bar',
365                     field_mask._CamelCaseToSnakeCase('fooBar'))
366    self.assertEqual('_foo_bar',
367                     field_mask._CamelCaseToSnakeCase('FooBar'))
368    self.assertEqual('foo3_bar',
369                     field_mask._CamelCaseToSnakeCase('foo3Bar'))
370    self.assertRaisesRegex(
371        ValueError,
372        'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
373        field_mask._CamelCaseToSnakeCase, 'foo_bar')
374
375
if __name__ == '__main__':
  # Entry point when the module is executed as a script: hand control to the
  # unittest runner, which discovers the TestCase classes defined above.
  unittest.main()
378