• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""Helpers for :mod:`protobuf`."""
16
17import collections
18import collections.abc
19import copy
20import inspect
21
22from google.protobuf import field_mask_pb2
23from google.protobuf import message
24from google.protobuf import wrappers_pb2
25
26
# Private marker object meaning "no default was supplied". It is compared by
# identity, so no caller-provided value can ever be mistaken for it.
_SENTINEL = object()

# The well-known protobuf wrapper message types. When a field of one of these
# types changes, field masks name the field itself rather than descending into
# its ``value`` sub-field.
_WRAPPER_TYPES = (
    wrappers_pb2.BoolValue,  # bool
    wrappers_pb2.BytesValue,  # bytes
    wrappers_pb2.DoubleValue,  # double
    wrappers_pb2.FloatValue,  # float
    wrappers_pb2.Int32Value,  # int32
    wrappers_pb2.Int64Value,  # int64
    wrappers_pb2.StringValue,  # string
    wrappers_pb2.UInt32Value,  # uint32
    wrappers_pb2.UInt64Value,  # uint64
)
39
40
def from_any_pb(pb_type, any_pb):
    """Unpack an ``Any`` protobuf into a new instance of ``pb_type``.

    ``pb_type`` may be a raw protobuf message class or a proto-plus wrapper
    class. A wrapper is recognised by having a callable ``pb`` attribute,
    which is used to reach the underlying protobuf message to unpack into.

    Args:
        pb_type (type): the type of the message that any_pb stores an instance
            of.
        any_pb (google.protobuf.any_pb2.Any): the object to be converted.

    Returns:
        pb_type: A freshly constructed ``pb_type`` populated from ``any_pb``.

    Raises:
        TypeError: if ``any_pb.Unpack`` reports that it could not unpack into
            the target message.
    """
    result = pb_type()

    # For proto-plus wrappers, unpack into the raw protobuf they wrap; the
    # wrapper instance itself is still what gets returned.
    unwrap = getattr(pb_type, "pb", None)
    target = pb_type.pb(result) if callable(unwrap) else result

    if any_pb.Unpack(target):
        return result

    raise TypeError(
        "Could not convert {} to {}".format(
            any_pb.__class__.__name__, pb_type.__name__
        )
    )
73
74
def check_oneof(**kwargs):
    """Ensure at most one of the given keyword arguments is not ``None``.

    Calling with no keyword arguments, or with every argument ``None``, or
    with exactly one non-``None`` argument, is accepted silently.

    Args:
        kwargs (dict): The keyword arguments sent to the function.

    Raises:
        ValueError: If two or more entries in ``kwargs`` are not ``None``. The
            message lists every keyword name (set or not) in sorted order.
    """
    populated = sum(1 for candidate in kwargs.values() if candidate is not None)
    if populated <= 1:
        return

    raise ValueError(
        "Only one of {fields} should be set.".format(
            fields=", ".join(sorted(kwargs))
        )
    )
95
96
def get_messages(module):
    """Collect every protobuf ``Message`` subclass exposed by a module.

    Args:
        module (module): A Python module; :func:`dir` will be run against this
            module to find Message subclasses.

    Returns:
        dict[str, google.protobuf.message.Message]: An ``OrderedDict`` mapping
            each attribute name to its Message subclass, in :func:`dir` order.
    """
    # Pair each attribute name with its value, lazily, in dir() order.
    attributes = ((attr_name, getattr(module, attr_name)) for attr_name in dir(module))
    return collections.OrderedDict(
        (attr_name, attr_value)
        for attr_name, attr_value in attributes
        if inspect.isclass(attr_value) and issubclass(attr_value, message.Message)
    )
115
116
117def _resolve_subkeys(key, separator="."):
118    """Resolve a potentially nested key.
119
120    If the key contains the ``separator`` (e.g. ``.``) then the key will be
121    split on the first instance of the subkey::
122
123       >>> _resolve_subkeys('a.b.c')
124       ('a', 'b.c')
125       >>> _resolve_subkeys('d|e|f', separator='|')
126       ('d', 'e|f')
127
128    If not, the subkey will be :data:`None`::
129
130        >>> _resolve_subkeys('foo')
131        ('foo', None)
132
133    Args:
134        key (str): A string that may or may not contain the separator.
135        separator (str): The namespace separator. Defaults to `.`.
136
137    Returns:
138        Tuple[str, str]: The key and subkey(s).
139    """
140    parts = key.split(separator, 1)
141
142    if len(parts) > 1:
143        return parts
144    else:
145        return parts[0], None
146
147
def get(msg_or_dict, key, default=_SENTINEL):
    """Retrieve a key's value from a protobuf Message or dictionary.

    Dotted keys such as ``"a.b"`` are resolved one segment at a time: the
    value for ``"a"`` is fetched first, and the lookup then continues on that
    value with ``"b"``.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object.
        key (str): The key to retrieve from the object.
        default (Any): If the key is not present on the object, and a default
            is set, returns that default instead. A type-appropriate falsy
            default is generally recommended, as protobuf messages almost
            always have default values for unset values and it is not always
            possible to tell the difference between a falsy value and an
            unset one. If no default is set then :class:`KeyError` will be
            raised if the key is not present in the object.

    Returns:
        Any: The return value from the underlying Message or dict. If an
        intermediate segment of a dotted key resolves to ``default``, then
        ``default`` is returned without descending further.

    Raises:
        KeyError: If the key is not found and no default was given. Note
            that, for unset values, messages and dictionaries may not have
            consistent behavior.
        TypeError: If ``msg_or_dict`` (or an intermediate value reached while
            resolving a dotted key) is not a Message or Mapping.
    """
    # We may need to get a nested key. Resolve this.
    key, subkey = _resolve_subkeys(key)

    # Attempt to get the value from the two types of objects we know about.
    # If we get something else, complain.
    if isinstance(msg_or_dict, message.Message):
        answer = getattr(msg_or_dict, key, default)
    elif isinstance(msg_or_dict, collections.abc.Mapping):
        answer = msg_or_dict.get(key, default)
    else:
        raise TypeError(
            "get() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    # Getting our sentinel back means no default was supplied and the key
    # is absent: a "not found" case.
    if answer is _SENTINEL:
        raise KeyError(key)

    # If a subkey exists and we found a real value (not the caller's
    # default), continue resolving the remainder of the path on it.
    if subkey is not None and answer is not default:
        return get(answer, subkey, default=default)

    return answer
197
198
def _set_field_on_message(msg, key, value):
    """Write ``value`` into field ``key`` of the protobuf message ``msg``.

    The write strategy is chosen by the type of ``value``, checked in this
    order:

    * ``MutableSequence`` or ``tuple``: the field is treated as repeated. It
      is emptied in place, then refilled. Mapping items are appended via
      ``add(**item)`` and any other item via ``extend([item])``.
    * ``Mapping``: each entry is applied to the field's current value with
      this module's ``set``.
    * ``Message``: the field is overwritten with ``CopyFrom``.
    * anything else: plain ``setattr``.
    """
    if isinstance(value, (collections.abc.MutableSequence, tuple)):
        # Drain the existing repeated field, removing from the end.
        while getattr(msg, key):
            getattr(msg, key).pop()

        for element in value:
            if not isinstance(element, collections.abc.Mapping):
                # Repeated composite containers have no ``append``, so a
                # one-element ``extend`` is used for non-mapping items.
                getattr(msg, key).extend([element])
                continue
            getattr(msg, key).add(**element)
        return

    if isinstance(value, collections.abc.Mapping):
        for entry_key, entry_value in value.items():
            set(getattr(msg, key), entry_key, entry_value)
        return

    if isinstance(value, message.Message):
        getattr(msg, key).CopyFrom(value)
        return

    setattr(msg, key, value)
225
226
def set(msg_or_dict, key, value):
    """Set a key's value on a protobuf Message or dictionary.

    Dotted keys (``"a.b.c"``) are applied to nested objects. When the target
    is a mutable mapping, a missing intermediate level is first created as an
    empty ``dict``.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object.
        key (str): The key to set.
        value (Any): The value to set.

    Raises:
        TypeError: If ``msg_or_dict`` is not a Message or dictionary.
    """
    is_mapping = isinstance(msg_or_dict, collections.abc.MutableMapping)
    if not (is_mapping or isinstance(msg_or_dict, message.Message)):
        raise TypeError(
            "set() expected a dict or protobuf message, got {!r}.".format(
                type(msg_or_dict)
            )
        )

    basekey, subkey = _resolve_subkeys(key)

    # Leaf case: the key names a field/entry directly on this object.
    if subkey is None:
        if is_mapping:
            msg_or_dict[key] = value
        else:
            _set_field_on_message(msg_or_dict, key, value)
        return

    # Nested case: make sure the intermediate level exists (mappings only),
    # then recurse into it with the rest of the path.
    if is_mapping:
        msg_or_dict.setdefault(basekey, {})
    set(get(msg_or_dict, basekey), subkey, value)
262
263
def setdefault(msg_or_dict, key, value):
    """Assign ``value`` to ``key`` unless the current value is truthy.

    Protobuf Messages do not reliably distinguish unset fields from fields
    holding a falsy value (by design). This helper therefore treats any falsy
    current value (e.g. 0, empty list, missing key) as a target to be
    overwritten, for both Messages and dictionaries.

    Args:
        msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the
            object.
        key (str): The key on the object in question.
        value (Any): The value to set.

    Raises:
        TypeError: If ``msg_or_dict`` is not a Message or dictionary.
    """
    current = get(msg_or_dict, key, default=None)
    if current:
        return
    set(msg_or_dict, key, value)
284
285
def field_mask(original, modified):
    """Create a field mask by comparing two messages.

    Args:
        original (~google.protobuf.message.Message): the original message.
            If set to None, this field will be interpreted as an empty
            message.
        modified (~google.protobuf.message.Message): the modified message.
            If set to None, this field will be interpreted as an empty
            message.

    Returns:
        google.protobuf.field_mask_pb2.FieldMask: field mask that contains
        the list of field names that have different values between the two
        messages. If the messages are equivalent, then the field mask is empty.

    Raises:
        ValueError: If the ``original`` or ``modified`` are not the same type.
    """
    if original is None and modified is None:
        return field_mask_pb2.FieldMask()

    # A missing side is treated as an empty message of the other side's type.
    # Constructing the type directly yields the same empty message as
    # deep-copying the other side and clearing it, without copying a
    # potentially large message first.
    if original is None:
        original = type(modified)()
    elif modified is None:
        modified = type(original)()

    if type(original) is not type(modified):
        raise ValueError(
            "expected that both original and modified should be of the "
            'same type, received "{!r}" and "{!r}".'.format(
                type(original), type(modified)
            )
        )

    return field_mask_pb2.FieldMask(paths=_field_mask_helper(original, modified))
325
326
def _field_mask_helper(original, modified, current=""):
    """Collect the field paths that differ between two messages of one type.

    Args:
        original (~google.protobuf.message.Message): the baseline message.
        modified (~google.protobuf.message.Message): the message compared
            against ``original``; expected to be of the same type.
        current (str): dotted path prefix for the fields of these messages;
            empty at the top level.

    Returns:
        List[str]: paths of changed fields, in descriptor field order. A
        changed field is reported as a single path when it is not a message,
        is a wrapper type, or has no set fields in ``modified``. Any other
        changed message field is descended into and reported per sub-field.
    """
    paths = []

    for field_name in original.DESCRIPTOR.fields_by_name:
        path = _get_path(current, field_name)
        old_value = getattr(original, field_name)
        new_value = getattr(modified, field_name)

        changed = old_value != new_value
        if not changed:
            continue

        is_composite = _is_message(old_value) or _is_message(new_value)
        if not is_composite:
            paths.append(path)
        elif _is_wrapper(old_value) or _is_wrapper(new_value):
            # Wrapper types are reported without their ``.value`` sub-path.
            paths.append(path)
        elif not new_value.ListFields():
            # Cleared in ``modified``: the whole sub-message changes.
            paths.append(path)
        else:
            paths.extend(_field_mask_helper(old_value, new_value, path))

    return paths
353
354
355def _get_path(current, name):
356    # gapic-generator-python appends underscores to field names
357    # that collide with python keywords.
358    # `_` is stripped away as it is not possible to
359    # natively define a field with a trailing underscore in protobuf.
360    # APIs will reject field masks if fields have trailing underscores.
361    # See https://github.com/googleapis/python-api-core/issues/227
362    name = name.rstrip("_")
363    if not current:
364        return name
365    return "%s.%s" % (current, name)
366
367
def _is_message(value):
    """Return True if ``value`` is an instance of a protobuf ``Message`` class."""
    message_base = message.Message
    return isinstance(value, message_base)
370
371
def _is_wrapper(value):
    """Return True if ``value``'s exact type is one of ``_WRAPPER_TYPES``.

    The check uses ``type()`` rather than ``isinstance``, so instances of
    subclasses of the wrapper types are not matched.
    """
    value_type = type(value)
    return value_type in _WRAPPER_TYPES
374