• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import sys
16
17import pytest
18
19from google.api import http_pb2
20from google.api_core import protobuf_helpers
21from google.longrunning import operations_pb2
22from google.protobuf import any_pb2
23from google.protobuf import message
24from google.protobuf import source_context_pb2
25from google.protobuf import struct_pb2
26from google.protobuf import timestamp_pb2
27from google.protobuf import type_pb2
28from google.protobuf import wrappers_pb2
29from google.type import color_pb2
30from google.type import date_pb2
31from google.type import timeofday_pb2
32
33
def test_from_any_pb_success():
    """Round-trip a Date through an Any and unpack it back into a Date."""
    expected = date_pb2.Date(year=1990)
    envelope = any_pb2.Any()
    envelope.Pack(expected)

    unpacked = protobuf_helpers.from_any_pb(date_pb2.Date, envelope)

    assert expected == unpacked
41
42
def test_from_any_pb_wrapped_success():
    """from_any_pb also works when given a wrapper class with a ``pb`` hook.

    Same scenario as ``test_from_any_pb_success``, but the target type is a
    small wrapper that holds the raw Date in ``_pb`` and exposes it through
    the ``pb`` classmethod.
    """

    class WrappedDate(object):
        # Minimal wrapped-message stand-in: the real protobuf lives in ``_pb``.
        def __init__(self, **kwargs):
            self._pb = date_pb2.Date(**kwargs)

        # Equality delegates to the wrapped protobuf.
        def __eq__(self, other):
            return self._pb == other

        # Hands back the underlying protobuf of a wrapper instance.
        @classmethod
        def pb(cls, msg):
            return msg._pb

    expected = date_pb2.Date(year=1990)
    envelope = any_pb2.Any()
    envelope.Pack(expected)

    result = protobuf_helpers.from_any_pb(WrappedDate, envelope)

    # ``result`` stays on the left so WrappedDate.__eq__ performs the check.
    assert result == expected
64
65
def test_from_any_pb_failure():
    """Unpacking an Any into a mismatched message type raises TypeError."""
    envelope = any_pb2.Any()
    envelope.Pack(date_pb2.Date(year=1990))

    # The payload is a Date, so requesting a TimeOfDay must be rejected.
    with pytest.raises(TypeError):
        protobuf_helpers.from_any_pb(timeofday_pb2.TimeOfDay, envelope)
72
73
def test_check_protobuf_helpers_ok():
    """check_oneof returns None when at most one argument is non-None."""
    accepted = (
        {},
        {"foo": "bar"},
        {"foo": "bar", "baz": None},
        {"foo": None, "baz": "bacon"},
        {"foo": "bar", "spam": None, "eggs": None},
    )
    for kwargs in accepted:
        assert protobuf_helpers.check_oneof(**kwargs) is None
80
81
def test_check_protobuf_helpers_failures():
    """check_oneof raises ValueError when more than one argument is non-None.

    A falsy value such as ``0`` still counts as set; only ``None`` is unset.
    """
    rejected = (
        {"foo": "bar", "spam": "eggs"},
        {"foo": "bar", "baz": "bacon", "spam": "eggs"},
        {"foo": "bar", "spam": 0, "eggs": None},
    )
    for kwargs in rejected:
        with pytest.raises(ValueError):
            protobuf_helpers.check_oneof(**kwargs)
89
90
def test_get_messages():
    """get_messages maps names to the Message classes found in a module."""
    exported = protobuf_helpers.get_messages(date_pb2)

    # Date must be exported under its own name, as the very same class.
    assert exported["Date"] is date_pb2.Date

    # Nothing but protobuf Message subclasses may be exported.
    assert all(issubclass(cls, message.Message) for cls in exported.values())
100
101
def test_get_dict_absent():
    """get() raises KeyError for a missing key when no default is given."""
    # The call itself is expected to raise; wrapping it in ``assert`` was
    # dead code that suggested a truthiness check was being made.
    with pytest.raises(KeyError):
        protobuf_helpers.get({}, "foo")
105
106
def test_get_dict_present():
    """get() returns the stored value for a key present in a dict."""
    value = protobuf_helpers.get({"foo": "bar"}, "foo")
    assert value == "bar"
109
110
def test_get_dict_default():
    """get() falls back to ``default`` when the key is missing."""
    value = protobuf_helpers.get({}, "foo", default="bar")
    assert value == "bar"
113
114
def test_get_dict_nested():
    """A dotted key walks into nested dicts."""
    nested = {"foo": {"bar": "baz"}}
    assert protobuf_helpers.get(nested, "foo.bar") == "baz"
117
118
def test_get_dict_nested_default():
    """The default is used when any segment of a dotted key is missing."""
    # Missing at the first level, then missing only at the second level.
    for mapping in ({}, {"foo": {}}):
        assert protobuf_helpers.get(mapping, "foo.baz", default="bacon") == "bacon"
122
123
def test_get_msg_sentinel():
    """get() on a message raises KeyError for an unknown name without default."""
    msg = timestamp_pb2.Timestamp()
    # ``foo`` is not a Timestamp field; the call must raise, so no ``assert``
    # wrapper is needed (it could never be evaluated).
    with pytest.raises(KeyError):
        protobuf_helpers.get(msg, "foo")
128
129
def test_get_msg_present():
    """get() reads a populated scalar field from a message by name."""
    stamp = timestamp_pb2.Timestamp(seconds=42)
    seconds = protobuf_helpers.get(stamp, "seconds")
    assert seconds == 42
133
134
def test_get_msg_default():
    """get() returns ``default`` for a name that is not a field of the message."""
    stamp = timestamp_pb2.Timestamp()
    fallback = protobuf_helpers.get(stamp, "foo", default="bar")
    assert fallback == "bar"
138
139
def test_invalid_object():
    """get() raises TypeError when the target is a plain ``object()``."""
    target = object()
    with pytest.raises(TypeError):
        protobuf_helpers.get(target, "foo", "bar")
143
144
def test_set_dict():
    """set() stores a top-level key into a dict in place."""
    target = {}
    protobuf_helpers.set(target, "foo", "bar")
    assert target == {"foo": "bar"}
149
150
def test_set_msg():
    """set() assigns a scalar field on a protobuf message."""
    stamp = timestamp_pb2.Timestamp()
    protobuf_helpers.set(stamp, "seconds", 42)
    assert stamp.seconds == 42
155
156
def test_set_dict_nested():
    """A dotted key creates the intermediate dict when it is absent."""
    target = {}
    protobuf_helpers.set(target, "foo.bar", "baz")
    assert target == {"foo": {"bar": "baz"}}
161
162
def test_set_invalid_object():
    """set() raises TypeError when the target is a plain ``object()``."""
    target = object()
    with pytest.raises(TypeError):
        protobuf_helpers.set(target, "foo", "bar")
166
167
def test_set_list():
    """A repeated message field accepts a mix of dicts and message instances."""
    response = operations_pb2.ListOperationsResponse()
    items = [{"name": "foo"}, operations_pb2.Operation(name="bar")]

    protobuf_helpers.set(response, "operations", items)

    stored = response.operations
    assert len(stored) == 2
    # Both input forms end up stored as Operation messages, in order.
    assert all(isinstance(op, operations_pb2.Operation) for op in stored)
    assert [op.name for op in stored] == ["foo", "bar"]
184
185
def test_set_list_clear_existing():
    """Setting a repeated field replaces its previous contents, not appends."""
    response = operations_pb2.ListOperationsResponse(operations=[{"name": "baz"}])
    replacement = [{"name": "foo"}, operations_pb2.Operation(name="bar")]

    protobuf_helpers.set(response, "operations", replacement)

    stored = response.operations
    # The pre-existing "baz" entry must be gone, leaving only the new two.
    assert len(stored) == 2
    for op in stored:
        assert isinstance(op, operations_pb2.Operation)
    assert [op.name for op in stored] == ["foo", "bar"]
202
203
def test_set_msg_with_msg_field():
    """A message instance can be assigned to a message-typed field."""
    rule = http_pb2.HttpRule()
    protobuf_helpers.set(
        rule, "custom", http_pb2.CustomHttpPattern(kind="foo", path="bar")
    )
    assert (rule.custom.kind, rule.custom.path) == ("foo", "bar")
212
213
def test_set_msg_with_dict_field():
    """A plain dict can be assigned to a message-typed field."""
    rule = http_pb2.HttpRule()
    protobuf_helpers.set(rule, "custom", {"kind": "foo", "path": "bar"})
    assert (rule.custom.kind, rule.custom.path) == ("foo", "bar")
222
223
def test_set_msg_nested_key():
    """A dotted key updates one nested field and leaves its siblings intact."""
    existing = http_pb2.CustomHttpPattern(kind="foo", path="bar")
    rule = http_pb2.HttpRule(custom=existing)

    protobuf_helpers.set(rule, "custom.kind", "baz")

    assert (rule.custom.kind, rule.custom.path) == ("baz", "bar")
231
232
def test_setdefault_dict_unset():
    """setdefault() fills in a key that is missing from the dict."""
    target = {}
    protobuf_helpers.setdefault(target, "foo", "bar")
    assert target == {"foo": "bar"}
237
238
def test_setdefault_dict_falsy():
    """setdefault() overwrites a key whose current value is falsy (None)."""
    target = {"foo": None}
    protobuf_helpers.setdefault(target, "foo", "bar")
    assert target == {"foo": "bar"}
243
244
def test_setdefault_dict_truthy():
    """setdefault() leaves a key with a truthy value untouched."""
    target = {"foo": "bar"}
    protobuf_helpers.setdefault(target, "foo", "baz")
    assert target == {"foo": "bar"}
249
250
def test_setdefault_pb2_falsy():
    """setdefault() populates a message field that is still empty."""
    op = operations_pb2.Operation()
    protobuf_helpers.setdefault(op, "name", "foo")
    assert op.name == "foo"
255
256
def test_setdefault_pb2_truthy():
    """setdefault() keeps a message field that already has a value."""
    op = operations_pb2.Operation(name="bar")
    protobuf_helpers.setdefault(op, "name", "foo")
    assert op.name == "bar"
261
262
def test_field_mask_invalid_args():
    """field_mask() raises ValueError for non-messages or mismatched types."""
    bad_pairs = (
        ("foo", any_pb2.Any()),
        (any_pb2.Any(), "bar"),
        (any_pb2.Any(), operations_pb2.Operation()),
    )
    for original, modified in bad_pairs:
        with pytest.raises(ValueError):
            protobuf_helpers.field_mask(original, modified)
270
271
def test_field_mask_equal_values():
    """Equal inputs, including two Nones, produce an empty field mask."""
    assert protobuf_helpers.field_mask(None, None).paths == []

    # Each factory is called twice, giving equal but distinct instances that
    # cover oneofs, wrapper types, repeated fields and maps.
    factories = (
        lambda: struct_pb2.Value(number_value=1.0),
        lambda: color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0)),
        lambda: struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)]),
        lambda: struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)}),
    )
    for make in factories:
        assert protobuf_helpers.field_mask(make(), make()).paths == []
290
291
def test_field_mask_zero_values():
    """A message holding only zero/empty values diffs as empty against None.

    Singular fields, repeated fields, maps and oneofs are each checked in
    both directions, with a fresh instance per comparison.
    """
    factories = (
        lambda: color_pb2.Color(red=0.0),  # singular value
        lambda: struct_pb2.ListValue(values=[]),  # repeated value
        lambda: struct_pb2.Struct(fields={}),  # map
        lambda: struct_pb2.Value(number_value=0.0),  # oneof
    )
    for make in factories:
        assert protobuf_helpers.field_mask(make(), None).paths == []
        assert protobuf_helpers.field_mask(None, make()).paths == []
328
329
def test_field_mask_singular_field_diffs():
    """A changed scalar field is reported by name in every direction."""
    # Field cleared.
    original = type_pb2.Type(name="name")
    modified = type_pb2.Type()
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    # Field newly set. This case previously repeated the "cleared" case
    # verbatim, so the set direction between two messages went untested.
    original = type_pb2.Type()
    modified = type_pb2.Type(name="name")
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    # Field set, starting from no message at all.
    original = None
    modified = type_pb2.Type(name="name")
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    # Message removed entirely.
    original = type_pb2.Type(name="name")
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]
346
347
def test_field_mask_message_diffs():
    """Diffs inside a nested message field.

    Setting or changing the nested message reports the leaf path
    (``source_context.file_name``). Removing the nested message reports the
    field itself (``source_context``).
    """

    def with_context(file_name):
        # A Type whose source_context carries the given file name.
        return type_pb2.Type(
            source_context=source_context_pb2.SourceContext(file_name=file_name)
        )

    def paths(original, modified):
        return protobuf_helpers.field_mask(original, modified).paths

    leaf = ["source_context.file_name"]
    whole = ["source_context"]

    assert paths(type_pb2.Type(), with_context("name")) == leaf
    assert paths(with_context("name"), type_pb2.Type()) == whole
    assert paths(with_context("name"), with_context("other_name")) == leaf
    assert paths(None, with_context("name")) == leaf
    assert paths(with_context("name"), None) == whole
386
387
def test_field_mask_wrapper_type_diffs():
    """Wrapper-typed fields are reported as ``alpha``, not ``alpha.value``."""

    def with_alpha(value):
        # A Color whose alpha wrapper holds the given float.
        return color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=value))

    def paths(original, modified):
        return protobuf_helpers.field_mask(original, modified).paths

    expected = ["alpha"]
    assert paths(color_pb2.Color(), with_alpha(1.0)) == expected
    assert paths(with_alpha(1.0), color_pb2.Color()) == expected
    assert paths(with_alpha(1.0), with_alpha(2.0)) == expected
    assert paths(None, with_alpha(2.0)) == expected
    assert paths(with_alpha(1.0), None) == expected
408
409
def test_field_mask_repeated_diffs():
    """Any change to a repeated field, including reordering, reports the field."""

    def list_of(*numbers):
        # A ListValue of number_value entries, in the given order.
        return struct_pb2.ListValue(
            values=[struct_pb2.Value(number_value=n) for n in numbers]
        )

    def paths(original, modified):
        return protobuf_helpers.field_mask(original, modified).paths

    expected = ["values"]
    assert paths(struct_pb2.ListValue(), list_of(1.0, 2.0)) == expected
    assert paths(list_of(1.0, 2.0), struct_pb2.ListValue()) == expected
    assert paths(None, list_of(1.0, 2.0)) == expected
    assert paths(list_of(1.0, 2.0), None) == expected
    # Same elements in a different order still count as a change.
    assert paths(list_of(1.0, 2.0), list_of(2.0, 1.0)) == expected
442
443
def test_field_mask_map_diffs():
    """Adding, removing, updating or re-keying map entries reports the field."""

    def struct_of(**numbers):
        # A Struct whose values are all number_value entries.
        return struct_pb2.Struct(
            fields={k: struct_pb2.Value(number_value=v) for k, v in numbers.items()}
        )

    def paths(original, modified):
        return protobuf_helpers.field_mask(original, modified).paths

    expected = ["fields"]
    assert paths(struct_pb2.Struct(), struct_of(foo=1.0)) == expected
    assert paths(struct_of(foo=1.0), struct_pb2.Struct()) == expected
    assert paths(None, struct_of(foo=1.0)) == expected
    assert paths(struct_of(foo=1.0), None) == expected
    assert paths(struct_of(foo=1.0), struct_of(foo=2.0)) == expected
    assert paths(struct_of(foo=1.0), struct_of(bar=1.0)) == expected
468
469
def test_field_mask_different_level_diffs():
    """Changes to several fields are all reported; their order is not checked."""
    before = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    after = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0)

    mask_paths = protobuf_helpers.field_mask(before, after).paths

    assert sorted(mask_paths) == ["alpha", "red"]
477
478
@pytest.mark.skipif(
    sys.version_info.major == 2,
    # The trailing space keeps the implicitly concatenated literals from
    # running together ("...created through...", not "...createdthrough...").
    reason="Field names with trailing underscores can only be created "
    "through proto-plus, which is Python 3 only.",
)
def test_field_mask_ignore_trailing_underscore():
    """A proto-plus field declared as ``type_`` appears in the mask as ``type``."""
    import proto

    class Foo(proto.Message):
        type_ = proto.Field(proto.STRING, number=1)
        input_config = proto.Field(proto.STRING, number=2)

    modified = Foo(type_="bar", input_config="baz")

    # Names without a trailing underscore (input_config) pass through as-is.
    assert sorted(protobuf_helpers.field_mask(None, Foo.pb(modified)).paths) == [
        "input_config",
        "type",
    ]
497
498
@pytest.mark.skipif(
    sys.version_info.major == 2,
    # The trailing space keeps the implicitly concatenated literals from
    # running together ("...created through...", not "...createdthrough...").
    reason="Field names with trailing underscores can only be created "
    "through proto-plus, which is Python 3 only.",
)
def test_field_mask_ignore_trailing_underscore_with_nesting():
    """A trailing underscore is dropped from a message-typed parent field too.

    The nested leaf path is reported as ``type.input_config``, not
    ``type_.input_config``.
    """
    import proto

    class Bar(proto.Message):
        class Baz(proto.Message):
            input_config = proto.Field(proto.STRING, number=1)

        type_ = proto.Field(Baz, number=1)

    modified = Bar()
    modified.type_.input_config = "foo"

    assert sorted(protobuf_helpers.field_mask(None, Bar.pb(modified)).paths) == [
        "type.input_config",
    ]
519