# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for ``google.api_core.protobuf_helpers``."""

import sys

import pytest

from google.api import http_pb2
from google.api_core import protobuf_helpers
from google.longrunning import operations_pb2
from google.protobuf import any_pb2
from google.protobuf import message
from google.protobuf import source_context_pb2
from google.protobuf import struct_pb2
from google.protobuf import timestamp_pb2
from google.protobuf import type_pb2
from google.protobuf import wrappers_pb2
from google.type import color_pb2
from google.type import date_pb2
from google.type import timeofday_pb2


def test_from_any_pb_success():
    """An ``Any`` packed with a ``Date`` unpacks to an equal ``Date``."""
    in_message = date_pb2.Date(year=1990)
    in_message_any = any_pb2.Any()
    in_message_any.Pack(in_message)
    out_message = protobuf_helpers.from_any_pb(date_pb2.Date, in_message_any)

    assert in_message == out_message


def test_from_any_pb_wrapped_success():
    """``from_any_pb`` accepts a wrapper class that exposes ``pb()``."""

    # Declare a message class conforming to wrapped messages.
    class WrappedDate(object):
        def __init__(self, **kwargs):
            self._pb = date_pb2.Date(**kwargs)

        def __eq__(self, other):
            return self._pb == other

        @classmethod
        def pb(cls, msg):
            return msg._pb

    # Run the same test as `test_from_any_pb_success`, but using the
    # wrapped class.
    in_message = date_pb2.Date(year=1990)
    in_message_any = any_pb2.Any()
    in_message_any.Pack(in_message)
    out_message = protobuf_helpers.from_any_pb(WrappedDate, in_message_any)

    assert out_message == in_message


def test_from_any_pb_failure():
    """Unpacking an ``Any`` into a mismatched message type raises ``TypeError``."""
    in_message = any_pb2.Any()
    in_message.Pack(date_pb2.Date(year=1990))

    with pytest.raises(TypeError):
        protobuf_helpers.from_any_pb(timeofday_pb2.TimeOfDay, in_message)


def test_check_protobuf_helpers_ok():
    """``check_oneof`` returns ``None`` when at most one keyword is not ``None``."""
    assert protobuf_helpers.check_oneof() is None
    assert protobuf_helpers.check_oneof(foo="bar") is None
    assert protobuf_helpers.check_oneof(foo="bar", baz=None) is None
    assert protobuf_helpers.check_oneof(foo=None, baz="bacon") is None
    assert protobuf_helpers.check_oneof(foo="bar", spam=None, eggs=None) is None


def test_check_protobuf_helpers_failures():
    """``check_oneof`` raises ``ValueError`` when several keywords are set."""
    with pytest.raises(ValueError):
        protobuf_helpers.check_oneof(foo="bar", spam="eggs")
    with pytest.raises(ValueError):
        protobuf_helpers.check_oneof(foo="bar", baz="bacon", spam="eggs")
    # A falsy value that is not None (here 0) still counts as "set".
    with pytest.raises(ValueError):
        protobuf_helpers.check_oneof(foo="bar", spam=0, eggs=None)


def test_get_messages():
    """``get_messages`` maps message names to message classes only."""
    answer = protobuf_helpers.get_messages(date_pb2)

    # Ensure that Date was exported properly.
    assert answer["Date"] is date_pb2.Date

    # Ensure that no non-Message objects were exported.
    for value in answer.values():
        assert issubclass(value, message.Message)


def test_get_dict_absent():
    """``get`` on a dict raises ``KeyError`` for a missing key with no default."""
    with pytest.raises(KeyError):
        protobuf_helpers.get({}, "foo")


def test_get_dict_present():
    """``get`` returns the value stored under an existing dict key."""
    assert protobuf_helpers.get({"foo": "bar"}, "foo") == "bar"


def test_get_dict_default():
    """``get`` returns the supplied default for a missing dict key."""
    assert protobuf_helpers.get({}, "foo", default="bar") == "bar"


def test_get_dict_nested():
    """``get`` follows a dotted key into nested dicts."""
    assert protobuf_helpers.get({"foo": {"bar": "baz"}}, "foo.bar") == "baz"


def test_get_dict_nested_default():
    """``get`` returns the default when any segment of a dotted key is missing."""
    assert protobuf_helpers.get({}, "foo.baz", default="bacon") == "bacon"
    assert protobuf_helpers.get({"foo": {}}, "foo.baz", default="bacon") == "bacon"


def test_get_msg_sentinel():
    """``get`` on a message raises ``KeyError`` for an unknown field, no default."""
    msg = timestamp_pb2.Timestamp()
    with pytest.raises(KeyError):
        protobuf_helpers.get(msg, "foo")


def test_get_msg_present():
    """``get`` returns the value of a populated message field."""
    msg = timestamp_pb2.Timestamp(seconds=42)
    assert protobuf_helpers.get(msg, "seconds") == 42


def test_get_msg_default():
    """``get`` returns the supplied default for an unknown message field."""
    msg = timestamp_pb2.Timestamp()
    assert protobuf_helpers.get(msg, "foo", default="bar") == "bar"


def test_invalid_object():
    """``get`` raises ``TypeError`` for objects that are neither dict nor message."""
    with pytest.raises(TypeError):
        protobuf_helpers.get(object(), "foo", "bar")


def test_set_dict():
    """``set`` stores a value under a plain dict key."""
    mapping = {}
    protobuf_helpers.set(mapping, "foo", "bar")
    assert mapping == {"foo": "bar"}


def test_set_msg():
    """``set`` assigns a scalar field on a message."""
    msg = timestamp_pb2.Timestamp()
    protobuf_helpers.set(msg, "seconds", 42)
    assert msg.seconds == 42


def test_set_dict_nested():
    """``set`` creates intermediate dicts for a dotted key."""
    mapping = {}
    protobuf_helpers.set(mapping, "foo.bar", "baz")
    assert mapping == {"foo": {"bar": "baz"}}


def test_set_invalid_object():
    """``set`` raises ``TypeError`` for objects that are neither dict nor message."""
    with pytest.raises(TypeError):
        protobuf_helpers.set(object(), "foo", "bar")


def test_set_list():
    """``set`` fills a repeated message field from a mix of dicts and messages."""
    list_ops_response = operations_pb2.ListOperationsResponse()

    protobuf_helpers.set(
        list_ops_response,
        "operations",
        [{"name": "foo"}, operations_pb2.Operation(name="bar")],
    )

    assert len(list_ops_response.operations) == 2

    for operation in list_ops_response.operations:
        assert isinstance(operation, operations_pb2.Operation)

    assert list_ops_response.operations[0].name == "foo"
    assert list_ops_response.operations[1].name == "bar"


def test_set_list_clear_existing():
    """``set`` on a repeated field replaces the entries that were already there."""
    list_ops_response = operations_pb2.ListOperationsResponse(
        operations=[{"name": "baz"}]
    )

    protobuf_helpers.set(
        list_ops_response,
        "operations",
        [{"name": "foo"}, operations_pb2.Operation(name="bar")],
    )

    assert len(list_ops_response.operations) == 2
    for operation in list_ops_response.operations:
        assert isinstance(operation, operations_pb2.Operation)
    assert list_ops_response.operations[0].name == "foo"
    assert list_ops_response.operations[1].name == "bar"


def test_set_msg_with_msg_field():
    """``set`` populates a message-typed field from a message instance."""
    rule = http_pb2.HttpRule()
    pattern = http_pb2.CustomHttpPattern(kind="foo", path="bar")

    protobuf_helpers.set(rule, "custom", pattern)

    assert rule.custom.kind == "foo"
    assert rule.custom.path == "bar"


def test_set_msg_with_dict_field():
    """``set`` populates a message-typed field from a dict."""
    rule = http_pb2.HttpRule()
    pattern = {"kind": "foo", "path": "bar"}

    protobuf_helpers.set(rule, "custom", pattern)

    assert rule.custom.kind == "foo"
    assert rule.custom.path == "bar"


def test_set_msg_nested_key():
    """``set`` with a dotted key updates only the addressed nested field."""
    rule = http_pb2.HttpRule(custom=http_pb2.CustomHttpPattern(kind="foo", path="bar"))

    protobuf_helpers.set(rule, "custom.kind", "baz")

    assert rule.custom.kind == "baz"
    assert rule.custom.path == "bar"


def test_setdefault_dict_unset():
    """``setdefault`` stores the value when the dict key is missing."""
    mapping = {}
    protobuf_helpers.setdefault(mapping, "foo", "bar")
    assert mapping == {"foo": "bar"}


def test_setdefault_dict_falsy():
    """``setdefault`` overwrites an existing ``None`` dict value."""
    mapping = {"foo": None}
    protobuf_helpers.setdefault(mapping, "foo", "bar")
    assert mapping == {"foo": "bar"}


def test_setdefault_dict_truthy():
    """``setdefault`` keeps an existing truthy dict value."""
    mapping = {"foo": "bar"}
    protobuf_helpers.setdefault(mapping, "foo", "baz")
    assert mapping == {"foo": "bar"}


def test_setdefault_pb2_falsy():
    """``setdefault`` assigns a message field that is still empty."""
    operation = operations_pb2.Operation()
    protobuf_helpers.setdefault(operation, "name", "foo")
    assert operation.name == "foo"


def test_setdefault_pb2_truthy():
    """``setdefault`` keeps a message field that is already populated."""
    operation = operations_pb2.Operation(name="bar")
    protobuf_helpers.setdefault(operation, "name", "foo")
    assert operation.name == "bar"


def test_field_mask_invalid_args():
    """``field_mask`` raises ``ValueError`` for non-message or mismatched args."""
    with pytest.raises(ValueError):
        protobuf_helpers.field_mask("foo", any_pb2.Any())
    with pytest.raises(ValueError):
        protobuf_helpers.field_mask(any_pb2.Any(), "bar")
    with pytest.raises(ValueError):
        protobuf_helpers.field_mask(any_pb2.Any(), operations_pb2.Operation())


def test_field_mask_equal_values():
    """``field_mask`` yields no paths when both sides are equal (or both None)."""
    assert protobuf_helpers.field_mask(None, None).paths == []

    original = struct_pb2.Value(number_value=1.0)
    modified = struct_pb2.Value(number_value=1.0)
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)])
    modified = struct_pb2.ListValue(values=[struct_pb2.Value(number_value=1.0)])
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
    modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
    assert protobuf_helpers.field_mask(original, modified).paths == []


def test_field_mask_zero_values():
    """Zero-valued or empty fields compared against ``None`` yield no paths."""
    # Singular Values
    original = color_pb2.Color(red=0.0)
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = None
    modified = color_pb2.Color(red=0.0)
    assert protobuf_helpers.field_mask(original, modified).paths == []

    # Repeated Values
    original = struct_pb2.ListValue(values=[])
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = None
    modified = struct_pb2.ListValue(values=[])
    assert protobuf_helpers.field_mask(original, modified).paths == []

    # Maps
    original = struct_pb2.Struct(fields={})
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = None
    modified = struct_pb2.Struct(fields={})
    assert protobuf_helpers.field_mask(original, modified).paths == []

    # Oneofs
    original = struct_pb2.Value(number_value=0.0)
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == []

    original = None
    modified = struct_pb2.Value(number_value=0.0)
    assert protobuf_helpers.field_mask(original, modified).paths == []


def test_field_mask_singular_field_diffs():
    """A scalar field that is set, cleared, or compared to ``None`` is reported."""
    # Field newly set on the modified message.
    original = type_pb2.Type()
    modified = type_pb2.Type(name="name")
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    # Field cleared on the modified message.
    original = type_pb2.Type(name="name")
    modified = type_pb2.Type()
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    original = None
    modified = type_pb2.Type(name="name")
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]

    original = type_pb2.Type(name="name")
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["name"]


def test_field_mask_message_diffs():
    """Nested message diffs use dotted paths; a cleared message uses its own name."""
    original = type_pb2.Type()
    modified = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="name")
    )
    assert protobuf_helpers.field_mask(original, modified).paths == [
        "source_context.file_name"
    ]

    original = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="name")
    )
    modified = type_pb2.Type()
    assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"]

    original = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="name")
    )
    modified = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="other_name")
    )
    assert protobuf_helpers.field_mask(original, modified).paths == [
        "source_context.file_name"
    ]

    original = None
    modified = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="name")
    )
    assert protobuf_helpers.field_mask(original, modified).paths == [
        "source_context.file_name"
    ]

    original = type_pb2.Type(
        source_context=source_context_pb2.SourceContext(file_name="name")
    )
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["source_context"]


def test_field_mask_wrapper_type_diffs():
    """Wrapper-typed fields are reported by field name, not by ``.value``."""
    original = color_pb2.Color()
    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]

    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    modified = color_pb2.Color()
    assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]

    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
    assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]

    original = None
    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0))
    assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]

    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["alpha"]


def test_field_mask_repeated_diffs():
    """Any change to a repeated field, including reordering, reports the field."""
    original = struct_pb2.ListValue()
    modified = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
    )
    assert protobuf_helpers.field_mask(original, modified).paths == ["values"]

    original = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
    )
    modified = struct_pb2.ListValue()
    assert protobuf_helpers.field_mask(original, modified).paths == ["values"]

    original = None
    modified = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
    )
    assert protobuf_helpers.field_mask(original, modified).paths == ["values"]

    original = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
    )
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["values"]

    # Same elements in a different order still count as a change.
    original = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=1.0), struct_pb2.Value(number_value=2.0)]
    )
    modified = struct_pb2.ListValue(
        values=[struct_pb2.Value(number_value=2.0), struct_pb2.Value(number_value=1.0)]
    )
    assert protobuf_helpers.field_mask(original, modified).paths == ["values"]


def test_field_mask_map_diffs():
    """Any change to a map field reports the map field itself, not its keys."""
    original = struct_pb2.Struct()
    modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]

    original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    modified = struct_pb2.Struct()
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]

    original = None
    modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]

    original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    modified = None
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]

    # Same key, different value.
    original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    modified = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=2.0)})
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]

    # Different key, same value.
    original = struct_pb2.Struct(fields={"foo": struct_pb2.Value(number_value=1.0)})
    modified = struct_pb2.Struct(fields={"bar": struct_pb2.Value(number_value=1.0)})
    assert protobuf_helpers.field_mask(original, modified).paths == ["fields"]


def test_field_mask_different_level_diffs():
    """Changes to a wrapper field and a scalar field are both reported."""
    original = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=1.0))
    modified = color_pb2.Color(alpha=wrappers_pb2.FloatValue(value=2.0), red=1.0)
    assert sorted(protobuf_helpers.field_mask(original, modified).paths) == [
        "alpha",
        "red",
    ]


@pytest.mark.skipif(
    sys.version_info.major == 2,
    reason="Field names with trailing underscores can only be created "
    "through proto-plus, which is Python 3 only.",
)
def test_field_mask_ignore_trailing_underscore():
    """A proto-plus field named ``type_`` appears in the mask as ``type``."""
    import proto

    class Foo(proto.Message):
        type_ = proto.Field(proto.STRING, number=1)
        input_config = proto.Field(proto.STRING, number=2)

    modified = Foo(type_="bar", input_config="baz")

    assert sorted(protobuf_helpers.field_mask(None, Foo.pb(modified)).paths) == [
        "input_config",
        "type",
    ]


@pytest.mark.skipif(
    sys.version_info.major == 2,
    reason="Field names with trailing underscores can only be created "
    "through proto-plus, which is Python 3 only.",
)
def test_field_mask_ignore_trailing_underscore_with_nesting():
    """The trailing underscore is also dropped from a parent segment of a path."""
    import proto

    class Bar(proto.Message):
        class Baz(proto.Message):
            input_config = proto.Field(proto.STRING, number=1)

        type_ = proto.Field(Baz, number=1)

    modified = Bar()
    modified.type_.input_config = "foo"

    assert sorted(protobuf_helpers.field_mask(None, Bar.pb(modified)).paths) == [
        "type.input_config",
    ]