• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`protobuf`."""
16
17import collections
18import collections.abc
19import copy
20import inspect
21
22from google.protobuf import field_mask_pb2
23from google.protobuf import message
24from google.protobuf import wrappers_pb2
25
26
# Private marker used as the default of ``get()`` so that "no default was
# supplied" can be told apart from an explicit ``default=None``.
_SENTINEL = object()

# The well-known wrapper message classes from ``wrappers_pb2``.
# ``_is_wrapper()`` matches against this tuple (exact type only), and
# ``_field_mask_helper()`` reports a changed wrapper field by its own path
# instead of descending into its ``value`` sub-field.
_WRAPPER_TYPES = (
    wrappers_pb2.BoolValue,
    wrappers_pb2.BytesValue,
    wrappers_pb2.DoubleValue,
    wrappers_pb2.FloatValue,
    wrappers_pb2.Int32Value,
    wrappers_pb2.Int64Value,
    wrappers_pb2.StringValue,
    wrappers_pb2.UInt32Value,
    wrappers_pb2.UInt64Value,
)
39
40
def from_any_pb(pb_type, any_pb):
    """Unpack an ``Any`` protobuf into a freshly constructed ``pb_type``.

    ``pb_type`` may be a raw protobuf message class or a proto-plus wrapper
    class. A wrapper is recognised by a callable ``pb`` attribute on the
    class. ``pb_type.pb(instance)`` is then used to reach the underlying
    protobuf message that ``Unpack`` writes into.

    Args:
        pb_type (type): Message class of the payload stored in ``any_pb``.
        any_pb (google.protobuf.any_pb2.Any): The ``Any`` to unpack.

    Returns:
        pb_type: The new ``pb_type`` instance. When ``pb_type`` is a wrapper
        class, the wrapper itself is returned, not the unwrapped protobuf.

    Raises:
        TypeError: If ``any_pb.Unpack`` reports failure (it returns a falsy
            value, e.g. when the packed type does not match ``pb_type``).
    """
    result = pb_type()

    # proto-plus classes keep the real protobuf message behind ``pb()``;
    # Unpack must populate that inner message, not the wrapper.
    if callable(getattr(pb_type, "pb", None)):
        target = pb_type.pb(result)
    else:
        target = result

    if any_pb.Unpack(target):
        return result

    raise TypeError(
        f"Could not convert `{any_pb.TypeName()}` with underlying type "
        f"`google.protobuf.any_pb2.Any` to `{target.DESCRIPTOR.full_name}`"
    )
71
72
def check_oneof(**kwargs):
    """Ensure at most one of the given keyword arguments is not ``None``.

    Calling with no keyword arguments, or with at most one non-``None``
    value, is accepted silently. Falsy values such as ``0`` or ``""`` still
    count as "set"; only ``None`` counts as unset.

    Args:
        kwargs (dict): The keyword arguments to inspect.

    Raises:
        ValueError: If two or more values in ``kwargs`` are not ``None``. The
            message lists every keyword name, sorted alphabetically.
    """
    populated = sum(1 for candidate in kwargs.values() if candidate is not None)
    if populated <= 1:
        return

    field_names = ", ".join(sorted(kwargs))
    raise ValueError("Only one of {fields} should be set.".format(fields=field_names))
93
94
def get_messages(module):
    """Collect every protobuf Message subclass exposed by ``module``.

    Each name returned by :func:`dir` is looked up on ``module``. Classes that
    subclass :class:`google.protobuf.message.Message` are kept; everything
    else is ignored.

    Args:
        module (module): A Python module to scan.

    Returns:
        collections.OrderedDict[str, type]: Attribute names mapped to the
        matching Message subclasses, in the order ``dir(module)`` lists them.
    """
    found = collections.OrderedDict()
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        # isclass() must come first: issubclass() raises on non-class objects.
        is_message_class = inspect.isclass(attr) and issubclass(attr, message.Message)
        if is_message_class:
            found[attr_name] = attr
    return found
113
114
115def _resolve_subkeys(key, separator="."):
116    """Resolve a potentially nested key.
117
118    If the key contains the ``separator`` (e.g. ``.``) then the key will be
119    split on the first instance of the subkey::
120
121       >>> _resolve_subkeys('a.b.c')
122       ('a', 'b.c')
123       >>> _resolve_subkeys('d|e|f', separator='|')
124       ('d', 'e|f')
125
126    If not, the subkey will be :data:`None`::
127
128        >>> _resolve_subkeys('foo')
129        ('foo', None)
130
131    Args:
132        key (str): A string that may or may not contain the separator.
133        separator (str): The namespace separator. Defaults to `.`.
134
135    Returns:
136        Tuple[str, str]: The key and subkey(s).
137    """
138    parts = key.split(separator, 1)
139
140    if len(parts) > 1:
141        return parts
142    else:
143        return parts[0], None
144
145
def get(msg_or_dict, key, default=_SENTINEL):
    """Retrieve a key's value from a protobuf Message or mapping.

    ``key`` may be a dotted path such as ``"a.b.c"``. Each segment is looked
    up in turn, and the next segment is applied to the value found. If a
    lookup along the way yields ``default`` (compared by identity), that value
    is returned immediately without descending further.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object to read from.
        key (str): The (possibly dotted) key to retrieve.
        default (Any): Value returned when the key is absent. A
            type-appropriate falsy default is generally recommended. Protobuf
            messages report default values for unset fields, so a falsy value
            and an unset one cannot always be told apart. When no default is
            given, a missing key raises :class:`KeyError`.

    Returns:
        Any: The value read from the Message attribute or mapping entry.

    Raises:
        KeyError: If a segment is missing and no default was supplied. The
            exception carries only the missing segment, not the full path.
            For unset values, messages and mappings may behave differently.
        TypeError: If ``msg_or_dict`` (or an intermediate value on the path)
            is neither a Message nor a Mapping.
    """
    head, rest = _resolve_subkeys(key)

    if isinstance(msg_or_dict, message.Message):
        value = getattr(msg_or_dict, head, default)
    elif isinstance(msg_or_dict, collections.abc.Mapping):
        value = msg_or_dict.get(head, default)
    else:
        raise TypeError(
            "get() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    # Getting the private sentinel back means the caller gave no default and
    # the segment was not found.
    if value is _SENTINEL:
        raise KeyError(head)

    # Stop when the path is exhausted or we already fell back to ``default``.
    if rest is None or value is default:
        return value
    return get(value, rest, default=default)
195
196
def _set_field_on_message(msg, key, value):
    """Write ``value`` into field ``key`` of the protobuf message ``msg``.

    The strategy depends on the type of ``value``:

    * list-like (``MutableSequence`` or ``tuple``): the existing repeated field
      is emptied, then each element is appended. Mapping elements go through
      ``add(**element)``; anything else goes through ``extend([element])``.
    * ``Mapping``: each entry is written into the field's current value with
      :func:`set`, one entry at a time.
    * protobuf ``Message``: copied into the field with ``CopyFrom``.
    * anything else: assigned with :func:`setattr`.
    """
    if isinstance(value, (collections.abc.MutableSequence, tuple)):
        repeated = getattr(msg, key)
        # Clear the repeated field in place, popping from the end until empty.
        while repeated:
            repeated.pop()
        for element in value:
            if isinstance(element, collections.abc.Mapping):
                repeated.add(**element)
            else:
                # protobuf's RepeatedCompositeContainer doesn't support append.
                repeated.extend([element])
        return

    if isinstance(value, collections.abc.Mapping):
        # The field is read once per entry, so an empty mapping never
        # touches the field at all.
        for entry_key, entry_value in value.items():
            set(getattr(msg, key), entry_key, entry_value)
        return

    if isinstance(value, message.Message):
        getattr(msg, key).CopyFrom(value)
        return

    setattr(msg, key, value)
223
224
def set(msg_or_dict, key, value):
    """Set a key's value on a protobuf Message or mutable mapping, in place.

    ``key`` may be a dotted path such as ``"a.b.c"``. For a mapping, a missing
    intermediate key is first created as an empty ``dict``. The segment's
    current value is then fetched with :func:`get` and the assignment recurses
    into it. On messages, the final assignment is delegated to
    ``_set_field_on_message``.

    NOTE: this module-level name shadows the built-in ``set`` inside this
    module.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, MutableMapping]):
            the object to modify.
        key (str): The (possibly dotted) key to set.
        value (Any): The value to set.

    Raises:
        TypeError: If ``msg_or_dict`` (or an intermediate value on the path)
            is not a Message or mutable mapping.
        KeyError: If an intermediate segment cannot be read from a message
            (raised by :func:`get`).
    """
    is_mapping = isinstance(msg_or_dict, collections.abc.MutableMapping)
    if not (is_mapping or isinstance(msg_or_dict, message.Message)):
        raise TypeError(
            "set() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    basekey, subkey = _resolve_subkeys(key)

    if subkey is None:
        if is_mapping:
            msg_or_dict[key] = value
        else:
            _set_field_on_message(msg_or_dict, key, value)
        return

    # Nested path: make sure the intermediate container exists (mappings
    # only), then recurse into it with the remaining segments.
    if is_mapping:
        msg_or_dict.setdefault(basekey, {})
    set(get(msg_or_dict, basekey), subkey, value)
260
261
def setdefault(msg_or_dict, key, value):
    """Assign ``value`` to ``key`` unless its current value is truthy.

    Protobuf messages deliberately do not separate unset fields from fields
    holding a falsy value. This helper therefore treats any falsy current
    value (``0``, ``""``, an empty list, a missing mapping key, ...) as
    replaceable, for both Messages and mappings.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object to update.
        key (str): The (possibly dotted) key to inspect and possibly set.
        value (Any): The value to write when the current one is falsy.

    Raises:
        TypeError: If ``msg_or_dict`` is not a Message or mapping.
    """
    current = get(msg_or_dict, key, default=None)
    if current:
        return
    set(msg_or_dict, key, value)
282
283
def field_mask(original, modified):
    """Build a FieldMask listing the fields that differ between two messages.

    A ``None`` argument stands for an empty message of the other argument's
    type: the other message is deep-copied and then cleared. If both are
    ``None``, an empty mask is returned.

    Args:
        original (~google.protobuf.message.Message): the original message,
            or ``None``.
        modified (~google.protobuf.message.Message): the modified message,
            or ``None``.

    Returns:
        google.protobuf.field_mask_pb2.FieldMask: a mask whose ``paths`` name
        every field whose value differs between the two messages. The mask is
        empty when the messages are equivalent.

    Raises:
        ValueError: If ``original`` is not an instance of ``type(modified)``.
    """
    if original is None and modified is None:
        return field_mask_pb2.FieldMask()

    # Substitute an empty message of the same type for the missing side.
    if original is None:
        original = copy.deepcopy(modified)
        original.Clear()
    elif modified is None:
        modified = copy.deepcopy(original)
        modified.Clear()

    if not isinstance(original, type(modified)):
        raise ValueError(
            "expected that both original and modified should be of the "
            'same type, received "{!r}" and "{!r}".'.format(
                type(original), type(modified)
            )
        )

    paths = _field_mask_helper(original, modified)
    return field_mask_pb2.FieldMask(paths=paths)
323
324
def _field_mask_helper(original, modified, current=""):
    """Return the dotted field paths at which two messages differ.

    Every field declared in ``original.DESCRIPTOR`` is compared with ``!=``.
    For a differing field:

    * a value that is not a protobuf Message is reported by its own path;
    * a wrapper type (see ``_is_wrapper``) is reported by its own path,
      without a trailing ``.value``;
    * a submessage whose ``modified`` side has no populated fields
      (``ListFields()`` is empty) is reported by its own path;
    * any other submessage is descended into recursively.

    Args:
        original: the original protobuf message.
        modified: the modified protobuf message, assumed to share the
            descriptor of ``original``.
        current (str): dotted path prefix of these messages ("" at top level).

    Returns:
        list[str]: the differing paths, in descriptor field order.
    """
    paths = []
    for field_name in original.DESCRIPTOR.fields_by_name:
        path = _get_path(current, field_name)
        before = getattr(original, field_name)
        after = getattr(modified, field_name)

        if before != after:
            is_leaf = not (_is_message(before) or _is_message(after))
            # The short-circuit order matters: ListFields() is only reached
            # when at least one side is a Message and neither is a wrapper.
            if (
                is_leaf
                or _is_wrapper(before)
                or _is_wrapper(after)
                or not after.ListFields()
            ):
                paths.append(path)
            else:
                paths.extend(_field_mask_helper(before, after, path))

    return paths
351
352
353def _get_path(current, name):
354    # gapic-generator-python appends underscores to field names
355    # that collide with python keywords.
356    # `_` is stripped away as it is not possible to
357    # natively define a field with a trailing underscore in protobuf.
358    # APIs will reject field masks if fields have trailing underscores.
359    # See https://github.com/googleapis/python-api-core/issues/227
360    name = name.rstrip("_")
361    if not current:
362        return name
363    return "%s.%s" % (current, name)
364
365
def _is_message(value):
    """Return whether ``value`` is a protobuf ``Message`` (subclasses included)."""
    return isinstance(value, message.Message)
368
369
def _is_wrapper(value):
    """Return whether ``value`` is exactly one of the ``_WRAPPER_TYPES``.

    The check compares ``type(value)``, so subclasses of a wrapper type do
    not match.
    """
    value_type = type(value)
    return value_type in _WRAPPER_TYPES
372