1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# 4# Use of this source code is governed by a BSD-style 5# license that can be found in the LICENSE file or at 6# https://developers.google.com/open-source/licenses/bsd 7 8# This code is meant to work on Python 2.4 and above only. 9# 10# TODO: Helpers for verbose, common checks like seeing if a 11# descriptor's cpp_type is CPPTYPE_MESSAGE. 12 13"""Contains a metaclass and helper functions used to create 14protocol message classes from Descriptor objects at runtime. 15 16Recall that a metaclass is the "type" of a class. 17(A class is to a metaclass what an instance is to a class.) 18 19In this case, we use the GeneratedProtocolMessageType metaclass 20to inject all the useful functionality into the classes 21output by the protocol compiler at compile-time. 22 23The upshot of all this is that the real implementation 24details for ALL pure-Python protocol buffers are *here in 25this file*. 26""" 27 28__author__ = 'robinson@google.com (Will Robinson)' 29 30import datetime 31from io import BytesIO 32import struct 33import sys 34import warnings 35import weakref 36 37from google.protobuf import descriptor as descriptor_mod 38from google.protobuf import message as message_mod 39from google.protobuf import text_format 40# We use "as" to avoid name collisions with variables. 
from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import extension_dict
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
_ExtensionDict = extension_dict._ExtensionDict


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = message_factory.GetMessageClass(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class; forwarded unchanged to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message).  If the descriptor names a well-known type listed
        in well_known_types.WKTBASES, that mixin is appended to the bases.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  __slots__ and nested-extension attributes are added to it.

    Returns:
      The newly-allocated class, or the concrete class previously built for
      this descriptor if one already exists.

    Raises:
      RuntimeError: If dictionary[_DESCRIPTOR_KEY] is a str, i.e. generated
        code that only works with the python cpp extension is being loaded
        by the pure-Python runtime.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls message_factory.GetMessageClass() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class is not None:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class; forwarded unchanged to type.__init__.
      bases: Base classes of the class we're constructing; forwarded
        unchanged to type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class is not None:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor),
          None,
      )

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    # Extensions already known to the descriptor's pool also get helpers so
    # that they can be recognized by tag.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      extensions = descriptor.file.pool.FindAllExtensions(descriptor)
      for ext in extensions:
        _AttachFieldHelpers(cls, ext)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The property name; currently always identical to proto_field_name.
  """
  # TODO: Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI.  See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If we
  #   rename the properties, then every such user has to also make sure to apply
  #   the same transformation.  Note that currently if you name a field "yield",
  #   you can still access it just fine using getattr/setattr -- it's not even
  #   that cumbersome to do so.
  # TODO: Remove this method entirely if/when everyone agrees with my
  #   position.
  return proto_field_name


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns True if `field` is an optional, message-typed extension of a
  message type whose options enable the MessageSet wire format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns True if `field` is message-typed and its message type is flagged
  as a map entry (i.e. the field is a proto map)."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type._is_map_entry)


def _IsMessageMapField(field):
  """Returns True if the map field's 'value' entry is message-typed.

  Expects `field` to be a map field (its message type has a 'value' field).
  """
  value_type = field.message_type.fields_by_name['value']
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches lookup helpers for one field (or extension) of `cls`.

  Sets `field_descriptor._default_constructor` and registers the field in
  `cls._fields_by_tag`, keyed by its encoded tag bytes, as a
  `(field_descriptor, is_packed)` tuple.

  Args:
    cls: The message class being constructed; must already have a
      `_fields_by_tag` dict.
    field_descriptor: A FieldDescriptor for a regular field or an extension.
  """
  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  def AddFieldByTag(wiretype, is_packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)

  AddFieldByTag(
      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False
  )

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _MaybeAddEncoder(cls, field_descriptor):
  """Attaches `_encoder` and `_sizer` to `field_descriptor` if not present.

  Map fields, MessageSet extensions and all other fields each use a
  different encoder/sizer factory.

  Args:
    cls: Not referenced by this function.
    field_descriptor: The FieldDescriptor to decorate.
  """
  if hasattr(field_descriptor, '_encoder'):
    return
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_map_entry = _IsMapField(field_descriptor)
  is_packed = field_descriptor.is_packed

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._sizer = sizer
  field_descriptor._encoder = field_encoder


def _MaybeAddDecoder(cls, field_descriptor):
  """Attaches `_decoders` to `field_descriptor` if not already present.

  `_decoders` maps `is_packed` (bool) to a decoder.  The unpacked (False)
  decoder is always built; a packed (True) decoder is added as well for
  repeated fields of packable types.  Open (non-closed) enums are decoded
  as plain TYPE_INT32 values.

  Args:
    cls: Not referenced by this function.
    field_descriptor: The FieldDescriptor to decorate.
  """
  if hasattr(field_descriptor, '_decoders'):
    return

  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_map_entry = _IsMapField(field_descriptor)
  helper_decoders = {}

  def AddDecoder(is_packed):
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        not field_descriptor.enum_type.is_closed):
      decode_type = _FieldDescriptor.TYPE_INT32

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          not field_descriptor.has_presence)

    helper_decoders[is_packed] = field_decoder

  AddDecoder(False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(True)

  field_descriptor._decoders = helper_decoders


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside `descriptor` as a class attribute.

  Args:
    descriptor: Descriptor of the message type being constructed.
    dictionary: Class dictionary; must not already define any of the
      extension names (asserted).
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for `field`.

  The factory takes the owning message and returns a containers.MessageMap
  when the map values are messages, otherwise a containers.ScalarMap.

  Args:
    field: FieldDescriptor of a map field.

  Raises:
    ValueError: If the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value, or
      (via _GetInitializeDefaultForMap) a map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Oneof members get a listener that also updates the oneof state.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added.

  Must be called from inside an `except` block: it reads the active
  exception via sys.exc_info().  Only a plain single-argument TypeError is
  rewritten; any other exception is re-raised unchanged.
  """
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  raise exc.with_traceback(sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts field values as keyword arguments.  A value
  of None is treated as if the keyword were absent.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: If `value` is a string that is not a label of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # Defensive only: _GetFieldByName raises ValueError for unknown names
      # rather than returning None.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                field_copy[key].MergeFrom(field_value[key])
            else:
              field_copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            if len(field_value) == 1 and 'fields' in field_value:
              try:
                field_copy.update(field_value)
              except Exception:  # pylint: disable=broad-except
                # Fall back to init normal message field.  Catching Exception
                # (not a bare except) lets KeyboardInterrupt/SystemExit escape.
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          # Some well-known types accept native Python values directly.
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        if new_val is not None:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: If the message has no field named field_name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also adds a <FIELD_NAME>_FIELD_NUMBER class constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + '_FIELD_NUMBER'
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


class _FieldProperty(property):
  """A property that additionally records its FieldDescriptor as DESCRIPTOR."""

  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    property.__init__(self, getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, and empty str/bytes).  Fields without presence
    # do not store their default value.
    if not field.has_presence and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field.  Direct
  assignment is only supported for Timestamp, Duration, Struct and ListValue
  fields; for any other message type it raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter converts native Python values for a few well-known types
  # (calling getter() first so the submessage exists in self._fields), and
  # otherwise raises an AttributeError with a helpful message.
  def setter(self, new_value):
    if field.message_type.full_name == 'google.protobuf.Timestamp':
      getter(self)
      self._fields[field].FromDatetime(new_value)
    elif field.message_type.full_name == 'google.protobuf.Duration':
      getter(self)
      self._fields[field].FromTimedelta(new_value)
    elif field.message_type.full_name == _StructFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].update(new_value)
    elif field.message_type.full_name == _ListValueFullTypeName:
      getter(self)
      self._fields[field].Clear()
      self._fields[field].extend(new_value)
    else:
      raise AttributeError(
          'Assignment not allowed to composite field '
          '"%s" in protocol message object.' % proto_field_name
      )

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <EXTENSION_NAME>_FIELD_NUMBER constant for each nested extension.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO: Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).


def _AddStaticMethods(cls):
  """Adds a FromString(s) static method that parses `s` into a new instance."""
  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  The added ListFields() returns (FieldDescriptor, value) pairs for all
  present fields, sorted by field number.
  """

  def ListFields(self):
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  The added HasField() accepts the name of a non-repeated field with presence
  or the name of a oneof, and raises ValueError for any other name.
  """

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if not field.has_presence:
      continue
    hassable_fields[field.name] = field

  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is "set" iff one of its member fields is currently set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  The added ClearField() accepts a field name or a oneof name (clearing
  whichever member of the oneof is set).
  """
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(), above.
    if field_descriptor in self._fields:
      del self._fields[field_descriptor]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods().

  The added HasExtension() raises KeyError for repeated extensions.
  """
  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(field_descriptor)
      return value is not None and value._is_present_in_parent
    else:
      return field_descriptor in self._fields
  cls.HasExtension = HasExtension


def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
974 # pylint: disable=g-import-not-at-top 975 from google.protobuf import symbol_database 976 factory = symbol_database.Default() 977 978 type_url = msg.type_url 979 980 if not type_url: 981 return None 982 983 # TODO: For now we just strip the hostname. Better logic will be 984 # required. 985 type_name = type_url.split('/')[-1] 986 descriptor = factory.pool.FindMessageTypeByName(type_name) 987 988 if descriptor is None: 989 return None 990 991 # Unable to import message_factory at top because of circular import. 992 # pylint: disable=g-import-not-at-top 993 from google.protobuf import message_factory 994 message_class = message_factory.GetMessageClass(descriptor) 995 message = message_class() 996 997 message.ParseFromString(msg.value) 998 return message 999 1000 1001def _AddEqualsMethod(message_descriptor, cls): 1002 """Helper for _AddMessageMethods().""" 1003 def __eq__(self, other): 1004 if self.DESCRIPTOR.full_name == _ListValueFullTypeName and isinstance( 1005 other, list 1006 ): 1007 return self._internal_compare(other) 1008 if self.DESCRIPTOR.full_name == _StructFullTypeName and isinstance( 1009 other, dict 1010 ): 1011 return self._internal_compare(other) 1012 1013 if (not isinstance(other, message_mod.Message) or 1014 other.DESCRIPTOR != self.DESCRIPTOR): 1015 return NotImplemented 1016 1017 if self is other: 1018 return True 1019 1020 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 1021 any_a = _InternalUnpackAny(self) 1022 any_b = _InternalUnpackAny(other) 1023 if any_a and any_b: 1024 return any_a == any_b 1025 1026 if not self.ListFields() == other.ListFields(): 1027 return False 1028 1029 # TODO: Fix UnknownFieldSet to consider MessageSet extensions, 1030 # then use it for the comparison. 
1031 unknown_fields = list(self._unknown_fields) 1032 unknown_fields.sort() 1033 other_unknown_fields = list(other._unknown_fields) 1034 other_unknown_fields.sort() 1035 return unknown_fields == other_unknown_fields 1036 1037 cls.__eq__ = __eq__ 1038 1039 1040def _AddStrMethod(message_descriptor, cls): 1041 """Helper for _AddMessageMethods().""" 1042 def __str__(self): 1043 return text_format.MessageToString(self) 1044 cls.__str__ = __str__ 1045 1046 1047def _AddReprMethod(message_descriptor, cls): 1048 """Helper for _AddMessageMethods().""" 1049 def __repr__(self): 1050 return text_format.MessageToString(self) 1051 cls.__repr__ = __repr__ 1052 1053 1054def _AddUnicodeMethod(unused_message_descriptor, cls): 1055 """Helper for _AddMessageMethods().""" 1056 1057 def __unicode__(self): 1058 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1059 cls.__unicode__ = __unicode__ 1060 1061 1062def _AddContainsMethod(message_descriptor, cls): 1063 1064 if message_descriptor.full_name == 'google.protobuf.Struct': 1065 def __contains__(self, key): 1066 return key in self.fields 1067 elif message_descriptor.full_name == 'google.protobuf.ListValue': 1068 def __contains__(self, value): 1069 return value in self.items() 1070 else: 1071 def __contains__(self, field): 1072 return self.HasField(field) 1073 1074 cls.__contains__ = __contains__ 1075 1076 1077def _BytesForNonRepeatedElement(value, field_number, field_type): 1078 """Returns the number of bytes needed to serialize a non-repeated element. 1079 The returned byte count includes space for tag information and any 1080 other additional space associated with serializing value. 1081 1082 Args: 1083 value: Value we're serializing. 1084 field_number: Field number of this value. (Since the field number 1085 is stored as part of a varint-encoded tag, this has an impact 1086 on the total bytes required to serialize the value). 1087 field_type: The type of the field. 
One of the TYPE_* constants 1088 within FieldDescriptor. 1089 """ 1090 try: 1091 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1092 return fn(field_number, value) 1093 except KeyError: 1094 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1095 1096 1097def _AddByteSizeMethod(message_descriptor, cls): 1098 """Helper for _AddMessageMethods().""" 1099 1100 def ByteSize(self): 1101 if not self._cached_byte_size_dirty: 1102 return self._cached_byte_size 1103 1104 size = 0 1105 descriptor = self.DESCRIPTOR 1106 if descriptor._is_map_entry: 1107 # Fields of map entry should always be serialized. 1108 key_field = descriptor.fields_by_name['key'] 1109 _MaybeAddEncoder(cls, key_field) 1110 size = key_field._sizer(self.key) 1111 value_field = descriptor.fields_by_name['value'] 1112 _MaybeAddEncoder(cls, value_field) 1113 size += value_field._sizer(self.value) 1114 else: 1115 for field_descriptor, field_value in self.ListFields(): 1116 _MaybeAddEncoder(cls, field_descriptor) 1117 size += field_descriptor._sizer(field_value) 1118 for tag_bytes, value_bytes in self._unknown_fields: 1119 size += len(tag_bytes) + len(value_bytes) 1120 1121 self._cached_byte_size = size 1122 self._cached_byte_size_dirty = False 1123 self._listener_for_children.dirty = False 1124 return size 1125 1126 cls.ByteSize = ByteSize 1127 1128 1129def _AddSerializeToStringMethod(message_descriptor, cls): 1130 """Helper for _AddMessageMethods().""" 1131 1132 def SerializeToString(self, **kwargs): 1133 # Check if the message has all of its required fields set. 
1134 if not self.IsInitialized(): 1135 raise message_mod.EncodeError( 1136 'Message %s is missing required fields: %s' % ( 1137 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1138 return self.SerializePartialToString(**kwargs) 1139 cls.SerializeToString = SerializeToString 1140 1141 1142def _AddSerializePartialToStringMethod(message_descriptor, cls): 1143 """Helper for _AddMessageMethods().""" 1144 1145 def SerializePartialToString(self, **kwargs): 1146 out = BytesIO() 1147 self._InternalSerialize(out.write, **kwargs) 1148 return out.getvalue() 1149 cls.SerializePartialToString = SerializePartialToString 1150 1151 def InternalSerialize(self, write_bytes, deterministic=None): 1152 if deterministic is None: 1153 deterministic = ( 1154 api_implementation.IsPythonDefaultSerializationDeterministic()) 1155 else: 1156 deterministic = bool(deterministic) 1157 1158 descriptor = self.DESCRIPTOR 1159 if descriptor._is_map_entry: 1160 # Fields of map entry should always be serialized. 
1161 key_field = descriptor.fields_by_name['key'] 1162 _MaybeAddEncoder(cls, key_field) 1163 key_field._encoder(write_bytes, self.key, deterministic) 1164 value_field = descriptor.fields_by_name['value'] 1165 _MaybeAddEncoder(cls, value_field) 1166 value_field._encoder(write_bytes, self.value, deterministic) 1167 else: 1168 for field_descriptor, field_value in self.ListFields(): 1169 _MaybeAddEncoder(cls, field_descriptor) 1170 field_descriptor._encoder(write_bytes, field_value, deterministic) 1171 for tag_bytes, value_bytes in self._unknown_fields: 1172 write_bytes(tag_bytes) 1173 write_bytes(value_bytes) 1174 cls._InternalSerialize = InternalSerialize 1175 1176 1177def _AddMergeFromStringMethod(message_descriptor, cls): 1178 """Helper for _AddMessageMethods().""" 1179 def MergeFromString(self, serialized): 1180 serialized = memoryview(serialized) 1181 length = len(serialized) 1182 try: 1183 if self._InternalParse(serialized, 0, length) != length: 1184 # The only reason _InternalParse would return early is if it 1185 # encountered an end-group tag. 1186 raise message_mod.DecodeError('Unexpected end-group tag.') 1187 except (IndexError, TypeError): 1188 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1189 raise message_mod.DecodeError('Truncated message.') 1190 except struct.error as e: 1191 raise message_mod.DecodeError(e) 1192 return length # Return this for legacy reasons. 1193 cls.MergeFromString = MergeFromString 1194 1195 local_ReadTag = decoder.ReadTag 1196 local_SkipField = decoder.SkipField 1197 fields_by_tag = cls._fields_by_tag 1198 message_set_decoders_by_tag = cls._message_set_decoders_by_tag 1199 1200 def InternalParse(self, buffer, pos, end): 1201 """Create a message from serialized bytes. 1202 1203 Args: 1204 self: Message, instance of the proto message object. 1205 buffer: memoryview of the serialized data. 1206 pos: int, position to start in the serialized data. 1207 end: int, end position of the serialized data. 
1208 1209 Returns: 1210 Message object. 1211 """ 1212 # Guard against internal misuse, since this function is called internally 1213 # quite extensively, and its easy to accidentally pass bytes. 1214 assert isinstance(buffer, memoryview) 1215 self._Modified() 1216 field_dict = self._fields 1217 while pos != end: 1218 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1219 field_decoder, field_des = message_set_decoders_by_tag.get( 1220 tag_bytes, (None, None) 1221 ) 1222 if field_decoder: 1223 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1224 continue 1225 field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None)) 1226 if field_des is None: 1227 if not self._unknown_fields: # pylint: disable=protected-access 1228 self._unknown_fields = [] # pylint: disable=protected-access 1229 # pylint: disable=protected-access 1230 (tag, _) = decoder._DecodeVarint(tag_bytes, 0) 1231 field_number, wire_type = wire_format.UnpackTag(tag) 1232 if field_number == 0: 1233 raise message_mod.DecodeError('Field number 0 is illegal.') 1234 # TODO: remove old_pos. 1235 old_pos = new_pos 1236 (data, new_pos) = decoder._DecodeUnknownField( 1237 buffer, new_pos, wire_type) # pylint: disable=protected-access 1238 if new_pos == -1: 1239 return pos 1240 # TODO: remove _unknown_fields. 
1241 new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) 1242 if new_pos == -1: 1243 return pos 1244 self._unknown_fields.append( 1245 (tag_bytes, buffer[old_pos:new_pos].tobytes())) 1246 pos = new_pos 1247 else: 1248 _MaybeAddDecoder(cls, field_des) 1249 field_decoder = field_des._decoders[is_packed] 1250 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1251 if field_des.containing_oneof: 1252 self._UpdateOneofState(field_des) 1253 return pos 1254 cls._InternalParse = InternalParse 1255 1256 1257def _AddIsInitializedMethod(message_descriptor, cls): 1258 """Adds the IsInitialized and FindInitializationError methods to the 1259 protocol message class.""" 1260 1261 required_fields = [field for field in message_descriptor.fields 1262 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1263 1264 def IsInitialized(self, errors=None): 1265 """Checks if all required fields of a message are set. 1266 1267 Args: 1268 errors: A list which, if provided, will be populated with the field 1269 paths of all missing required fields. 1270 1271 Returns: 1272 True iff the specified message has all required fields set. 1273 """ 1274 1275 # Performance is critical so we avoid HasField() and ListFields(). 1276 1277 for field in required_fields: 1278 if (field not in self._fields or 1279 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1280 not self._fields[field]._is_present_in_parent)): 1281 if errors is not None: 1282 errors.extend(self.FindInitializationErrors()) 1283 return False 1284 1285 for field, value in list(self._fields.items()): # dict can change size! 
1286 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1287 if field.label == _FieldDescriptor.LABEL_REPEATED: 1288 if (field.message_type._is_map_entry): 1289 continue 1290 for element in value: 1291 if not element.IsInitialized(): 1292 if errors is not None: 1293 errors.extend(self.FindInitializationErrors()) 1294 return False 1295 elif value._is_present_in_parent and not value.IsInitialized(): 1296 if errors is not None: 1297 errors.extend(self.FindInitializationErrors()) 1298 return False 1299 1300 return True 1301 1302 cls.IsInitialized = IsInitialized 1303 1304 def FindInitializationErrors(self): 1305 """Finds required fields which are not initialized. 1306 1307 Returns: 1308 A list of strings. Each string is a path to an uninitialized field from 1309 the top-level message, e.g. "foo.bar[5].baz". 1310 """ 1311 1312 errors = [] # simplify things 1313 1314 for field in required_fields: 1315 if not self.HasField(field.name): 1316 errors.append(field.name) 1317 1318 for field, value in self.ListFields(): 1319 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1320 if field.is_extension: 1321 name = '(%s)' % field.full_name 1322 else: 1323 name = field.name 1324 1325 if _IsMapField(field): 1326 if _IsMessageMapField(field): 1327 for key in value: 1328 element = value[key] 1329 prefix = '%s[%s].' % (name, key) 1330 sub_errors = element.FindInitializationErrors() 1331 errors += [prefix + error for error in sub_errors] 1332 else: 1333 # ScalarMaps can't have any initialization errors. 1334 pass 1335 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1336 for i in range(len(value)): 1337 element = value[i] 1338 prefix = '%s[%d].' % (name, i) 1339 sub_errors = element.FindInitializationErrors() 1340 errors += [prefix + error for error in sub_errors] 1341 else: 1342 prefix = name + '.' 
1343 sub_errors = value.FindInitializationErrors() 1344 errors += [prefix + error for error in sub_errors] 1345 1346 return errors 1347 1348 cls.FindInitializationErrors = FindInitializationErrors 1349 1350 1351def _FullyQualifiedClassName(klass): 1352 module = klass.__module__ 1353 name = getattr(klass, '__qualname__', klass.__name__) 1354 if module in (None, 'builtins', '__builtin__'): 1355 return name 1356 return module + '.' + name 1357 1358 1359def _AddMergeFromMethod(cls): 1360 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1361 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1362 1363 def MergeFrom(self, msg): 1364 if not isinstance(msg, cls): 1365 raise TypeError( 1366 'Parameter to MergeFrom() must be instance of same class: ' 1367 'expected %s got %s.' % (_FullyQualifiedClassName(cls), 1368 _FullyQualifiedClassName(msg.__class__))) 1369 1370 assert msg is not self 1371 self._Modified() 1372 1373 fields = self._fields 1374 1375 for field, value in msg._fields.items(): 1376 if field.label == LABEL_REPEATED: 1377 field_value = fields.get(field) 1378 if field_value is None: 1379 # Construct a new object to represent this field. 1380 field_value = field._default_constructor(self) 1381 fields[field] = field_value 1382 field_value.MergeFrom(value) 1383 elif field.cpp_type == CPPTYPE_MESSAGE: 1384 if value._is_present_in_parent: 1385 field_value = fields.get(field) 1386 if field_value is None: 1387 # Construct a new object to represent this field. 
1388 field_value = field._default_constructor(self) 1389 fields[field] = field_value 1390 field_value.MergeFrom(value) 1391 else: 1392 self._fields[field] = value 1393 if field.containing_oneof: 1394 self._UpdateOneofState(field) 1395 1396 if msg._unknown_fields: 1397 if not self._unknown_fields: 1398 self._unknown_fields = [] 1399 self._unknown_fields.extend(msg._unknown_fields) 1400 1401 cls.MergeFrom = MergeFrom 1402 1403 1404def _AddWhichOneofMethod(message_descriptor, cls): 1405 def WhichOneof(self, oneof_name): 1406 """Returns the name of the currently set field inside a oneof, or None.""" 1407 try: 1408 field = message_descriptor.oneofs_by_name[oneof_name] 1409 except KeyError: 1410 raise ValueError( 1411 'Protocol message has no oneof "%s" field.' % oneof_name) 1412 1413 nested_field = self._oneofs.get(field, None) 1414 if nested_field is not None and self.HasField(nested_field.name): 1415 return nested_field.name 1416 else: 1417 return None 1418 1419 cls.WhichOneof = WhichOneof 1420 1421 1422def _Clear(self): 1423 # Clear fields. 
1424 self._fields = {} 1425 self._unknown_fields = () 1426 1427 self._oneofs = {} 1428 self._Modified() 1429 1430 1431def _UnknownFields(self): 1432 raise NotImplementedError('Please use the add-on feaure ' 1433 'unknown_fields.UnknownFieldSet(message) in ' 1434 'unknown_fields.py instead.') 1435 1436 1437def _DiscardUnknownFields(self): 1438 self._unknown_fields = [] 1439 for field, value in self.ListFields(): 1440 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1441 if _IsMapField(field): 1442 if _IsMessageMapField(field): 1443 for key in value: 1444 value[key].DiscardUnknownFields() 1445 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1446 for sub_message in value: 1447 sub_message.DiscardUnknownFields() 1448 else: 1449 value.DiscardUnknownFields() 1450 1451 1452def _SetListener(self, listener): 1453 if listener is None: 1454 self._listener = message_listener_mod.NullMessageListener() 1455 else: 1456 self._listener = listener 1457 1458 1459def _AddMessageMethods(message_descriptor, cls): 1460 """Adds implementations of all Message methods to cls.""" 1461 _AddListFieldsMethod(message_descriptor, cls) 1462 _AddHasFieldMethod(message_descriptor, cls) 1463 _AddClearFieldMethod(message_descriptor, cls) 1464 if message_descriptor.is_extendable: 1465 _AddClearExtensionMethod(cls) 1466 _AddHasExtensionMethod(cls) 1467 _AddEqualsMethod(message_descriptor, cls) 1468 _AddStrMethod(message_descriptor, cls) 1469 _AddReprMethod(message_descriptor, cls) 1470 _AddUnicodeMethod(message_descriptor, cls) 1471 _AddContainsMethod(message_descriptor, cls) 1472 _AddByteSizeMethod(message_descriptor, cls) 1473 _AddSerializeToStringMethod(message_descriptor, cls) 1474 _AddSerializePartialToStringMethod(message_descriptor, cls) 1475 _AddMergeFromStringMethod(message_descriptor, cls) 1476 _AddIsInitializedMethod(message_descriptor, cls) 1477 _AddMergeFromMethod(cls) 1478 _AddWhichOneofMethod(message_descriptor, cls) 1479 # Adds methods which do not depend on cls. 
1480 cls.Clear = _Clear 1481 cls.DiscardUnknownFields = _DiscardUnknownFields 1482 cls._SetListener = _SetListener 1483 1484 1485def _AddPrivateHelperMethods(message_descriptor, cls): 1486 """Adds implementation of private helper methods to cls.""" 1487 1488 def Modified(self): 1489 """Sets the _cached_byte_size_dirty bit to true, 1490 and propagates this to our listener iff this was a state change. 1491 """ 1492 1493 # Note: Some callers check _cached_byte_size_dirty before calling 1494 # _Modified() as an extra optimization. So, if this method is ever 1495 # changed such that it does stuff even when _cached_byte_size_dirty is 1496 # already true, the callers need to be updated. 1497 if not self._cached_byte_size_dirty: 1498 self._cached_byte_size_dirty = True 1499 self._listener_for_children.dirty = True 1500 self._is_present_in_parent = True 1501 self._listener.Modified() 1502 1503 def _UpdateOneofState(self, field): 1504 """Sets field as the active field in its containing oneof. 1505 1506 Will also delete currently active field in the oneof, if it is different 1507 from the argument. Does not mark the message as modified. 1508 """ 1509 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1510 if other_field is not field: 1511 del self._fields[other_field] 1512 self._oneofs[field.containing_oneof] = field 1513 1514 cls._Modified = Modified 1515 cls.SetInParent = Modified 1516 cls._UpdateOneofState = _UpdateOneofState 1517 1518 1519class _Listener(object): 1520 1521 """MessageListener implementation that a parent message registers with its 1522 child message. 1523 1524 In order to support semantics like: 1525 1526 foo.bar.baz.moo = 23 1527 assert foo.HasField('bar') 1528 1529 ...child objects must have back references to their parents. 1530 This helper class is at the heart of this support. 
1531 """ 1532 1533 def __init__(self, parent_message): 1534 """Args: 1535 parent_message: The message whose _Modified() method we should call when 1536 we receive Modified() messages. 1537 """ 1538 # This listener establishes a back reference from a child (contained) object 1539 # to its parent (containing) object. We make this a weak reference to avoid 1540 # creating cyclic garbage when the client finishes with the 'parent' object 1541 # in the tree. 1542 if isinstance(parent_message, weakref.ProxyType): 1543 self._parent_message_weakref = parent_message 1544 else: 1545 self._parent_message_weakref = weakref.proxy(parent_message) 1546 1547 # As an optimization, we also indicate directly on the listener whether 1548 # or not the parent message is dirty. This way we can avoid traversing 1549 # up the tree in the common case. 1550 self.dirty = False 1551 1552 def Modified(self): 1553 if self.dirty: 1554 return 1555 try: 1556 # Propagate the signal to our parents iff this is the first field set. 1557 self._parent_message_weakref._Modified() 1558 except ReferenceError: 1559 # We can get here if a client has kept a reference to a child object, 1560 # and is now setting a field on it, but the child's parent has been 1561 # garbage-collected. This is not an error. 1562 pass 1563 1564 1565class _OneofListener(_Listener): 1566 """Special listener implementation for setting composite oneof fields.""" 1567 1568 def __init__(self, parent_message, field): 1569 """Args: 1570 parent_message: The message whose _Modified() method we should call when 1571 we receive Modified() messages. 1572 field: The descriptor of the field being set in the parent message. 
1573 """ 1574 super(_OneofListener, self).__init__(parent_message) 1575 self._field = field 1576 1577 def Modified(self): 1578 """Also updates the state of the containing oneof in the parent message.""" 1579 try: 1580 self._parent_message_weakref._UpdateOneofState(self._field) 1581 super(_OneofListener, self).Modified() 1582 except ReferenceError: 1583 pass 1584