• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Contains FieldMask class."""
9
10from google.protobuf.descriptor import FieldDescriptor
11
12
class FieldMask(object):
  """Helper methods for the FieldMask message type.

  ``self.paths`` and ``self.Clear()`` are not defined in this class; they are
  expected to come from the message class this one is combined with.
  ``__slots__ = ()`` keeps this class from adding instance attributes itself.
  """

  __slots__ = ()

  def ToJsonString(self):
    """Returns the proto3 JSON form: camelCase paths joined by commas.

    Raises:
      ValueError: if a path cannot be converted to camelCase.
    """
    return ','.join(_SnakeCaseToCamelCase(p) for p in self.paths)

  def FromJsonString(self, value):
    """Replaces this mask's paths with those parsed from a proto3 JSON string.

    Args:
      value: comma-separated camelCase paths; the empty string means no paths.

    Raises:
      ValueError: if ``value`` is not a str or a path contains '_'.
    """
    if not isinstance(value, str):
      raise ValueError('FieldMask JSON value not a string: {!r}'.format(value))
    self.Clear()
    if not value:
      return
    # Append one path at a time so paths parsed before a failure are kept.
    for json_path in value.split(','):
      self.paths.append(_CamelCaseToSnakeCase(json_path))

  def IsValidForDescriptor(self, message_descriptor):
    """Returns True if every path names a field of ``message_descriptor``."""
    return all(_IsValidPath(message_descriptor, p) for p in self.paths)

  def AllFieldsFromDescriptor(self, message_descriptor):
    """Resets this mask to list every direct field of ``message_descriptor``."""
    self.Clear()
    for field_descriptor in message_descriptor.fields:
      self.paths.append(field_descriptor.name)

  def CanonicalFormFromMask(self, mask):
    """Stores the canonical form of ``mask`` in this FieldMask.

    The canonical form drops every path already covered by a shorter one
    (e.g. "foo.bar" is dropped when "foo" is present) and lists the remaining
    paths in sorted order.

    Args:
      mask: the FieldMask to canonicalize.
    """
    _FieldMaskTree(mask).ToFieldMask(self)

  def Union(self, mask1, mask2):
    """Stores the union of ``mask1`` and ``mask2`` in this FieldMask.

    Raises:
      ValueError: if either argument is not a FieldMask message.
    """
    for candidate in (mask1, mask2):
      _CheckFieldMaskMessage(candidate)
    merged = _FieldMaskTree(mask1)
    merged.MergeFromFieldMask(mask2)
    merged.ToFieldMask(self)

  def Intersect(self, mask1, mask2):
    """Stores the intersection of ``mask1`` and ``mask2`` in this FieldMask.

    Raises:
      ValueError: if either argument is not a FieldMask message.
    """
    for candidate in (mask1, mask2):
      _CheckFieldMaskMessage(candidate)
    source_tree = _FieldMaskTree(mask1)
    result = _FieldMaskTree()
    for path in mask2.paths:
      source_tree.IntersectPath(path, result)
    result.ToFieldMask(self)

  def MergeMessage(
      self, source, destination,
      replace_message_field=False, replace_repeated_field=False):
    """Copies the fields named by this mask from ``source`` into ``destination``.

    Args:
      source: message to read from.
      destination: message to merge into.
      replace_message_field: when True, a message field is cleared in
          ``destination`` before being merged; otherwise it is merged.
      replace_repeated_field: when True, a repeated field is cleared in
          ``destination`` first; otherwise source elements are appended.
    """
    _FieldMaskTree(self).MergeMessage(
        source, destination, replace_message_field, replace_repeated_field)
94
95
def _IsValidPath(message_descriptor, path):
  """Checks whether a dotted field path is valid for a message descriptor.

  Every component except the last must be a singular, message-valued field;
  the last component only has to exist in the message reached so far.

  Args:
    message_descriptor: descriptor of the message the path starts from.
    path: dotted field path, e.g. "foo.bar".

  Returns:
    True if the path is valid, False otherwise.
  """
  parts = path.split('.')
  last = parts.pop()
  for name in parts:
    field = message_descriptor.fields_by_name.get(name)
    # Compare cpp_type rather than type so proto2 groups (TYPE_GROUP, which
    # are message-valued) can be traversed like TYPE_MESSAGE fields. This
    # matches the sub-path check _MergeMessage performs.
    if (field is None or
        field.label == FieldDescriptor.LABEL_REPEATED or
        field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
      return False
    message_descriptor = field.message_type
  return last in message_descriptor.fields_by_name
108
109
def _CheckFieldMaskMessage(message):
  """Validates that ``message`` is a google.protobuf.FieldMask.

  The check compares the descriptor's short name and the name of the .proto
  file it was defined in.

  Raises:
    ValueError: if ``message`` is any other message type.
  """
  descriptor = message.DESCRIPTOR
  # Short-circuits: the file name is only consulted when the name matches.
  is_field_mask = (
      descriptor.name == 'FieldMask' and
      descriptor.file.name == 'google/protobuf/field_mask.proto')
  if not is_field_mask:
    raise ValueError('Message {0} is not a FieldMask.'.format(
        descriptor.full_name))
117
118
def _SnakeCaseToCamelCase(path_name):
  """Translates a snake_case path name into camelCase.

  Each '_' is dropped and the lowercase letter following it is uppercased.

  Args:
    path_name: the snake_case path.

  Returns:
    The camelCase form of ``path_name``.

  Raises:
    ValueError: if the path contains an uppercase letter, a '_' not followed
      by a lowercase letter, or a trailing '_'.
  """
  pieces = []
  capitalize_next = False
  for ch in path_name:
    # Uppercase input is rejected before anything else is considered.
    if ch.isupper():
      raise ValueError(
          'Fail to print FieldMask to Json string: Path name '
          '{0} must not contain uppercase letters.'.format(path_name))
    if ch == '_' and not capitalize_next:
      capitalize_next = True
      continue
    if not capitalize_next:
      pieces.append(ch)
      continue
    # The previous character was '_': this one must be a lowercase letter.
    if not ch.islower():
      raise ValueError(
          'Fail to print FieldMask to Json string: The '
          'character after a "_" must be a lowercase letter '
          'in path name {0}.'.format(path_name))
    pieces.append(ch.upper())
    capitalize_next = False

  if capitalize_next:
    raise ValueError('Fail to print FieldMask to Json string: Trailing "_" '
                     'in path name {0}.'.format(path_name))
  return ''.join(pieces)
146
147
def _CamelCaseToSnakeCase(path_name):
  """Translates a camelCase path name into snake_case.

  Every uppercase letter becomes '_' followed by its lowercase form.

  Args:
    path_name: the camelCase path.

  Returns:
    The snake_case form of ``path_name``.

  Raises:
    ValueError: if ``path_name`` already contains a '_'.
  """
  pieces = []
  for ch in path_name:
    if ch == '_':
      raise ValueError('Fail to parse FieldMask: Path name '
                       '{0} must not contain "_"s.'.format(path_name))
    pieces.append('_' + ch.lower() if ch.isupper() else ch)
  return ''.join(pieces)
161
162
class _FieldMaskTree(object):
  """Tree view of a FieldMask, one nested dict level per path component.

  Each node is a dict from field name to child node. An empty child dict is
  a leaf and stands for a complete field path. For the mask
  "foo.bar,foo.baz,bar.baz" the tree looks like:
      [_root] -+- foo -+- bar
               |       |
               |       +- baz
               |
               +- bar --- baz
  """

  __slots__ = ('_root',)

  def __init__(self, field_mask=None):
    """Creates the tree, seeding it from ``field_mask`` when it is truthy."""
    self._root = {}
    if field_mask:
      self.MergeFromFieldMask(field_mask)

  def MergeFromFieldMask(self, field_mask):
    """Adds every path listed in ``field_mask.paths`` to the tree."""
    for field_path in field_mask.paths:
      self.AddPath(field_path)

  def AddPath(self, path):
    """Inserts a dotted field path into the tree.

    Adding a path below an existing leaf is a no-op, since the leaf already
    covers it. Adding a path that ends on an existing inner node turns that
    node into a leaf by discarding its children, since the new path covers
    all of them. Any other path is inserted as new nodes.

    Args:
      path: the dotted field path to insert.
    """
    current = self._root
    for component in path.split('.'):
      child = current.get(component)
      if child is None:
        child = current[component] = {}
      elif not child:
        # An existing leaf already covers this whole path.
        return
      current = child
    # The new path covers everything below it.
    current.clear()

  def ToFieldMask(self, field_mask):
    """Clears ``field_mask`` and writes this tree's paths into it, sorted."""
    field_mask.Clear()
    _AddFieldPaths(self._root, '', field_mask)

  def IntersectPath(self, path, intersection):
    """Records in ``intersection`` the part of ``path`` this tree covers.

    Args:
      path: the dotted field path to intersect with this tree.
      intersection: the _FieldMaskTree that receives the result.
    """
    current = self._root
    for component in path.split('.'):
      if component not in current:
        return
      current = current[component]
      if not current:
        # A leaf of this tree covers the whole of ``path``.
        intersection.AddPath(path)
        return
    # ``path`` ends on an inner node: only the leaves below it are shared.
    intersection.AddLeafNodes(path, current)

  def AddLeafNodes(self, prefix, node):
    """Adds to this tree every leaf path of ``node``, prefixed by ``prefix``."""
    if not node:
      self.AddPath(prefix)
      return
    for child_name, child in node.items():
      self.AddLeafNodes(prefix + '.' + child_name, child)

  def MergeMessage(
      self, source, destination,
      replace_message, replace_repeated):
    """Copies the fields selected by this tree from source to destination."""
    _MergeMessage(self._root, source, destination,
                  replace_message, replace_repeated)
250
251
def _StrConvert(value):
  """Returns ``value`` as a ``str``.

  Field names passed to ClearField() go through this helper because ClearField
  needs a text (``str``) field name. ``str`` input is returned unchanged.
  ``bytes`` input is decoded as UTF-8. Anything else goes through ``str()``.
  (Encoding to bytes, as a Python 2 helper would, yields the wrong type on
  Python 3.)

  Args:
    value: the field name to normalize.

  Returns:
    The field name as a ``str``.
  """
  if isinstance(value, str):
    return value
  if isinstance(value, bytes):
    return value.decode('utf-8')
  return str(value)
260
261
def _MergeMessage(
    node, source, destination, replace_message, replace_repeated):
  """Merges the fields selected by a FieldMask sub-tree from source to destination.

  Args:
    node: dict sub-tree of a _FieldMaskTree. Keys are field names of
      ``source``; an empty dict value selects the whole field, a non-empty
      one selects only the listed sub-fields.
    source: message to read from.
    destination: message to write into. Fields are looked up by the same
      names on both messages.
    replace_message: if True, a selected singular message field is cleared
      in ``destination`` before ``source``'s value is merged into it.
    replace_repeated: if True, a selected repeated field is cleared in
      ``destination`` before ``source``'s elements are merged in.

  Raises:
    ValueError: if a name in ``node`` is not a field of ``source``, or a
      field with sub-paths is not a singular message field.
  """
  source_descriptor = source.DESCRIPTOR
  for name, child in node.items():
    # Use .get() so an unknown name reaches the ValueError below; indexing
    # would raise KeyError and leave the None check unreachable.
    field = source_descriptor.fields_by_name.get(name)
    if field is None:
      raise ValueError('Error: Can\'t find field {0} in message {1}.'.format(
          name, source_descriptor.full_name))
    if child:
      # Sub-paths are only allowed for singular message fields.
      if (field.label == FieldDescriptor.LABEL_REPEATED or
          field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE):
        raise ValueError('Error: Field {0} in message {1} is not a singular '
                         'message field and cannot have sub-fields.'.format(
                             name, source_descriptor.full_name))
      if source.HasField(name):
        _MergeMessage(
            child, getattr(source, name), getattr(destination, name),
            replace_message, replace_repeated)
      continue
    if field.label == FieldDescriptor.LABEL_REPEATED:
      if replace_repeated:
        destination.ClearField(_StrConvert(name))
      repeated_source = getattr(source, name)
      repeated_destination = getattr(destination, name)
      repeated_destination.MergeFrom(repeated_source)
    elif field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE:
      if replace_message:
        destination.ClearField(_StrConvert(name))
      if source.HasField(name):
        getattr(destination, name).MergeFrom(getattr(source, name))
    else:
      # Singular scalar: plain assignment overwrites the destination value.
      setattr(destination, name, getattr(source, name))
298
299
def _AddFieldPaths(node, prefix, field_mask):
  """Appends to ``field_mask.paths`` every leaf path found under ``node``.

  Children are visited in sorted name order at every level.

  Args:
    node: dict sub-tree of a _FieldMaskTree.
    prefix: dotted path leading to ``node``; '' for the root.
    field_mask: object whose ``paths`` list receives the leaf paths.
  """
  if prefix and not node:
    # A non-root empty node is a leaf, so the prefix is a complete path.
    field_mask.paths.append(prefix)
    return
  for child_name in sorted(node):
    child_prefix = '.'.join((prefix, child_name)) if prefix else child_name
    _AddFieldPaths(node[child_name], child_prefix, field_mask)
311