# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code uses `six` to run on both Python 2 and Python 3. It requires at
# least Python 2.6 (it uses `except ... as ...` syntax and the `io` module).
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

from io import BytesIO
import struct
import sys
import weakref

import six
from six.moves import range

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import extension_dict
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class, or the class already attached to the
      descriptor as `_concrete_class` if one exists.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    # Well-known types (e.g. Any, Timestamp) get an extra mixin base class.
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message).  We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI.  See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If we
  #   rename the properties, then every such user has to also make sure to apply
  #   the same transformation.  Note that currently if you name a field "yield",
  #   you can still access it just fine using getattr/setattr -- it's not even
  #   that cumbersome to do so.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
  #   position.
  return proto_field_name


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_unknown_field_set',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns True if `field` is an optional message-typed extension of a
  message that uses the MessageSet wire format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns True if `field` is a map field, i.e. a message field whose
  message type has the `map_entry` option set."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns True if the map field's 'value' entry is a message type."""
  value_type = field.message_type.fields_by_name['value']
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _IsStrictUtf8Check(field):
  """Returns True if string values of `field` must be strictly UTF-8 checked.

  Strict checking is enforced only for fields of proto3 messages.
  """
  return field.containing_type.syntax == 'proto3'


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoder, sizer, default constructor and decoders for a field.

  The encoder, sizer and default-value constructor are stored as private
  attributes on `field_descriptor`; the decoders are registered in
  `cls._decoders_by_tag`, keyed by tag bytes.

  Args:
    cls: The message class being constructed.  Must already have a
      `_decoders_by_tag` dict.
    field_descriptor: The FieldDescriptor (regular field or extension) to
      prepare.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: packed encoding only when the `packed` option is set.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # proto3: packed encoding unless `packed = false` is explicitly set.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the given wire type."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Open enums are decoded as plain int32 so unknown values are kept.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          is_strict_utf8_check)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside this message as a class attribute.

  Args:
    descriptor: Descriptor object for this message type.
    dictionary: Class dictionary being populated; must not already contain
      any of the extension names.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exports, for each enum type, a class-level EnumTypeWrapper that can
  name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a function that builds the empty map container for a map field.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A function taking the owning message and returning a MessageMap (for
    message-valued maps) or a ScalarMap (for scalar-valued maps).

  Raises:
    ValueError: if the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value, or a
      map field is not repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # So the message type is only read when the default is constructed.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Sub-messages inside a oneof need a listener that also updates the
      # parent's oneof state when they are modified.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added."""
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(exc), exc, sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls."""

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if `value` is a string that names no value of enum_type.
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None  # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names, so
      # this branch appears unreachable; kept as a defensive check.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              # Each element may be a dict of constructor kwargs or a message.
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: if the message has no field named `field_name`.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also sets a `<FIELD_NAME>_FIELD_NUMBER` class constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + '_FIELD_NUMBER'
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


class _FieldProperty(property):
  """A property that also records the FieldDescriptor it exposes."""

  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    property.__init__(self, getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # proto3 scalars outside a oneof have no presence: setting them to the
  # default value is the same as clearing them.
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, '' and b'').
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds extension-related class attributes to this message type.

  Sets a `<EXTENSION_NAME>_FIELD_NUMBER` constant for every extension declared
  inside this message, and, when the descriptor belongs to a file, exposes the
  pool's `_extensions_by_number` / `_extensions_by_name` tables for this
  message as class attributes.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is not None:
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    pool = descriptor.file.pool
    cls._extensions_by_number = pool._extensions_by_number[descriptor]
    cls._extensions_by_name = pool._extensions_by_name[descriptor]


def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)
  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields


_PROTO3_ERROR_TEMPLATE = \
  'Protocol message %s has no non-repeated submessage field "%s"'
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _PROTO3_ERROR_TEMPLATE if is_proto3 else _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  if not is_proto3:
    # Fields inside oneofs are never repeated (enforced by the compiler).
    for oneof in message_descriptor.oneofs:
      hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof "has" a value when its currently-set member field does.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        # Clearing a oneof clears whichever of its fields is currently set.
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
958 type_name = type_url.split('/')[-1] 959 descriptor = factory.pool.FindMessageTypeByName(type_name) 960 961 if descriptor is None: 962 return None 963 964 message_class = factory.GetPrototype(descriptor) 965 message = message_class() 966 967 message.ParseFromString(msg.value) 968 return message 969 970 971def _AddEqualsMethod(message_descriptor, cls): 972 """Helper for _AddMessageMethods().""" 973 def __eq__(self, other): 974 if (not isinstance(other, message_mod.Message) or 975 other.DESCRIPTOR != self.DESCRIPTOR): 976 return False 977 978 if self is other: 979 return True 980 981 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 982 any_a = _InternalUnpackAny(self) 983 any_b = _InternalUnpackAny(other) 984 if any_a and any_b: 985 return any_a == any_b 986 987 if not self.ListFields() == other.ListFields(): 988 return False 989 990 # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, 991 # then use it for the comparison. 992 unknown_fields = list(self._unknown_fields) 993 unknown_fields.sort() 994 other_unknown_fields = list(other._unknown_fields) 995 other_unknown_fields.sort() 996 return unknown_fields == other_unknown_fields 997 998 cls.__eq__ = __eq__ 999 1000 1001def _AddStrMethod(message_descriptor, cls): 1002 """Helper for _AddMessageMethods().""" 1003 def __str__(self): 1004 return text_format.MessageToString(self) 1005 cls.__str__ = __str__ 1006 1007 1008def _AddReprMethod(message_descriptor, cls): 1009 """Helper for _AddMessageMethods().""" 1010 def __repr__(self): 1011 return text_format.MessageToString(self) 1012 cls.__repr__ = __repr__ 1013 1014 1015def _AddUnicodeMethod(unused_message_descriptor, cls): 1016 """Helper for _AddMessageMethods().""" 1017 1018 def __unicode__(self): 1019 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1020 cls.__unicode__ = __unicode__ 1021 1022 1023def _BytesForNonRepeatedElement(value, field_number, field_type): 1024 """Returns the number of bytes needed to serialize a 
non-repeated element. 1025 The returned byte count includes space for tag information and any 1026 other additional space associated with serializing value. 1027 1028 Args: 1029 value: Value we're serializing. 1030 field_number: Field number of this value. (Since the field number 1031 is stored as part of a varint-encoded tag, this has an impact 1032 on the total bytes required to serialize the value). 1033 field_type: The type of the field. One of the TYPE_* constants 1034 within FieldDescriptor. 1035 """ 1036 try: 1037 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1038 return fn(field_number, value) 1039 except KeyError: 1040 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1041 1042 1043def _AddByteSizeMethod(message_descriptor, cls): 1044 """Helper for _AddMessageMethods().""" 1045 1046 def ByteSize(self): 1047 if not self._cached_byte_size_dirty: 1048 return self._cached_byte_size 1049 1050 size = 0 1051 descriptor = self.DESCRIPTOR 1052 if descriptor.GetOptions().map_entry: 1053 # Fields of map entry should always be serialized. 1054 size = descriptor.fields_by_name['key']._sizer(self.key) 1055 size += descriptor.fields_by_name['value']._sizer(self.value) 1056 else: 1057 for field_descriptor, field_value in self.ListFields(): 1058 size += field_descriptor._sizer(field_value) 1059 for tag_bytes, value_bytes in self._unknown_fields: 1060 size += len(tag_bytes) + len(value_bytes) 1061 1062 self._cached_byte_size = size 1063 self._cached_byte_size_dirty = False 1064 self._listener_for_children.dirty = False 1065 return size 1066 1067 cls.ByteSize = ByteSize 1068 1069 1070def _AddSerializeToStringMethod(message_descriptor, cls): 1071 """Helper for _AddMessageMethods().""" 1072 1073 def SerializeToString(self, **kwargs): 1074 # Check if the message has all of its required fields set. 
1075 errors = [] 1076 if not self.IsInitialized(): 1077 raise message_mod.EncodeError( 1078 'Message %s is missing required fields: %s' % ( 1079 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1080 return self.SerializePartialToString(**kwargs) 1081 cls.SerializeToString = SerializeToString 1082 1083 1084def _AddSerializePartialToStringMethod(message_descriptor, cls): 1085 """Helper for _AddMessageMethods().""" 1086 1087 def SerializePartialToString(self, **kwargs): 1088 out = BytesIO() 1089 self._InternalSerialize(out.write, **kwargs) 1090 return out.getvalue() 1091 cls.SerializePartialToString = SerializePartialToString 1092 1093 def InternalSerialize(self, write_bytes, deterministic=None): 1094 if deterministic is None: 1095 deterministic = ( 1096 api_implementation.IsPythonDefaultSerializationDeterministic()) 1097 else: 1098 deterministic = bool(deterministic) 1099 1100 descriptor = self.DESCRIPTOR 1101 if descriptor.GetOptions().map_entry: 1102 # Fields of map entry should always be serialized. 
1103 descriptor.fields_by_name['key']._encoder( 1104 write_bytes, self.key, deterministic) 1105 descriptor.fields_by_name['value']._encoder( 1106 write_bytes, self.value, deterministic) 1107 else: 1108 for field_descriptor, field_value in self.ListFields(): 1109 field_descriptor._encoder(write_bytes, field_value, deterministic) 1110 for tag_bytes, value_bytes in self._unknown_fields: 1111 write_bytes(tag_bytes) 1112 write_bytes(value_bytes) 1113 cls._InternalSerialize = InternalSerialize 1114 1115 1116def _AddMergeFromStringMethod(message_descriptor, cls): 1117 """Helper for _AddMessageMethods().""" 1118 def MergeFromString(self, serialized): 1119 if isinstance(serialized, memoryview) and six.PY2: 1120 raise TypeError( 1121 'memoryview not supported in Python 2 with the pure Python proto ' 1122 'implementation: this is to maintain compatibility with the C++ ' 1123 'implementation') 1124 1125 serialized = memoryview(serialized) 1126 length = len(serialized) 1127 try: 1128 if self._InternalParse(serialized, 0, length) != length: 1129 # The only reason _InternalParse would return early is if it 1130 # encountered an end-group tag. 1131 raise message_mod.DecodeError('Unexpected end-group tag.') 1132 except (IndexError, TypeError): 1133 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1134 raise message_mod.DecodeError('Truncated message.') 1135 except struct.error as e: 1136 raise message_mod.DecodeError(e) 1137 return length # Return this for legacy reasons. 1138 cls.MergeFromString = MergeFromString 1139 1140 local_ReadTag = decoder.ReadTag 1141 local_SkipField = decoder.SkipField 1142 decoders_by_tag = cls._decoders_by_tag 1143 1144 def InternalParse(self, buffer, pos, end): 1145 """Create a message from serialized bytes. 1146 1147 Args: 1148 self: Message, instance of the proto message object. 1149 buffer: memoryview of the serialized data. 1150 pos: int, position to start in the serialized data. 1151 end: int, end position of the serialized data. 
1152 1153 Returns: 1154 Message object. 1155 """ 1156 # Guard against internal misuse, since this function is called internally 1157 # quite extensively, and its easy to accidentally pass bytes. 1158 assert isinstance(buffer, memoryview) 1159 self._Modified() 1160 field_dict = self._fields 1161 # pylint: disable=protected-access 1162 unknown_field_set = self._unknown_field_set 1163 while pos != end: 1164 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1165 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1166 if field_decoder is None: 1167 if not self._unknown_fields: # pylint: disable=protected-access 1168 self._unknown_fields = [] # pylint: disable=protected-access 1169 if unknown_field_set is None: 1170 # pylint: disable=protected-access 1171 self._unknown_field_set = containers.UnknownFieldSet() 1172 # pylint: disable=protected-access 1173 unknown_field_set = self._unknown_field_set 1174 # pylint: disable=protected-access 1175 (tag, _) = decoder._DecodeVarint(tag_bytes, 0) 1176 field_number, wire_type = wire_format.UnpackTag(tag) 1177 # TODO(jieluo): remove old_pos. 1178 old_pos = new_pos 1179 (data, new_pos) = decoder._DecodeUnknownField( 1180 buffer, new_pos, wire_type) # pylint: disable=protected-access 1181 if new_pos == -1: 1182 return pos 1183 # pylint: disable=protected-access 1184 unknown_field_set._add(field_number, wire_type, data) 1185 # TODO(jieluo): remove _unknown_fields. 
1186 new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) 1187 if new_pos == -1: 1188 return pos 1189 self._unknown_fields.append( 1190 (tag_bytes, buffer[old_pos:new_pos].tobytes())) 1191 pos = new_pos 1192 else: 1193 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1194 if field_desc: 1195 self._UpdateOneofState(field_desc) 1196 return pos 1197 cls._InternalParse = InternalParse 1198 1199 1200def _AddIsInitializedMethod(message_descriptor, cls): 1201 """Adds the IsInitialized and FindInitializationError methods to the 1202 protocol message class.""" 1203 1204 required_fields = [field for field in message_descriptor.fields 1205 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1206 1207 def IsInitialized(self, errors=None): 1208 """Checks if all required fields of a message are set. 1209 1210 Args: 1211 errors: A list which, if provided, will be populated with the field 1212 paths of all missing required fields. 1213 1214 Returns: 1215 True iff the specified message has all required fields set. 1216 """ 1217 1218 # Performance is critical so we avoid HasField() and ListFields(). 1219 1220 for field in required_fields: 1221 if (field not in self._fields or 1222 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1223 not self._fields[field]._is_present_in_parent)): 1224 if errors is not None: 1225 errors.extend(self.FindInitializationErrors()) 1226 return False 1227 1228 for field, value in list(self._fields.items()): # dict can change size! 
1229 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1230 if field.label == _FieldDescriptor.LABEL_REPEATED: 1231 if (field.message_type.has_options and 1232 field.message_type.GetOptions().map_entry): 1233 continue 1234 for element in value: 1235 if not element.IsInitialized(): 1236 if errors is not None: 1237 errors.extend(self.FindInitializationErrors()) 1238 return False 1239 elif value._is_present_in_parent and not value.IsInitialized(): 1240 if errors is not None: 1241 errors.extend(self.FindInitializationErrors()) 1242 return False 1243 1244 return True 1245 1246 cls.IsInitialized = IsInitialized 1247 1248 def FindInitializationErrors(self): 1249 """Finds required fields which are not initialized. 1250 1251 Returns: 1252 A list of strings. Each string is a path to an uninitialized field from 1253 the top-level message, e.g. "foo.bar[5].baz". 1254 """ 1255 1256 errors = [] # simplify things 1257 1258 for field in required_fields: 1259 if not self.HasField(field.name): 1260 errors.append(field.name) 1261 1262 for field, value in self.ListFields(): 1263 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1264 if field.is_extension: 1265 name = '(%s)' % field.full_name 1266 else: 1267 name = field.name 1268 1269 if _IsMapField(field): 1270 if _IsMessageMapField(field): 1271 for key in value: 1272 element = value[key] 1273 prefix = '%s[%s].' % (name, key) 1274 sub_errors = element.FindInitializationErrors() 1275 errors += [prefix + error for error in sub_errors] 1276 else: 1277 # ScalarMaps can't have any initialization errors. 1278 pass 1279 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1280 for i in range(len(value)): 1281 element = value[i] 1282 prefix = '%s[%d].' % (name, i) 1283 sub_errors = element.FindInitializationErrors() 1284 errors += [prefix + error for error in sub_errors] 1285 else: 1286 prefix = name + '.' 
1287 sub_errors = value.FindInitializationErrors() 1288 errors += [prefix + error for error in sub_errors] 1289 1290 return errors 1291 1292 cls.FindInitializationErrors = FindInitializationErrors 1293 1294 1295def _AddMergeFromMethod(cls): 1296 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1297 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1298 1299 def MergeFrom(self, msg): 1300 if not isinstance(msg, cls): 1301 raise TypeError( 1302 'Parameter to MergeFrom() must be instance of same class: ' 1303 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) 1304 1305 assert msg is not self 1306 self._Modified() 1307 1308 fields = self._fields 1309 1310 for field, value in msg._fields.items(): 1311 if field.label == LABEL_REPEATED: 1312 field_value = fields.get(field) 1313 if field_value is None: 1314 # Construct a new object to represent this field. 1315 field_value = field._default_constructor(self) 1316 fields[field] = field_value 1317 field_value.MergeFrom(value) 1318 elif field.cpp_type == CPPTYPE_MESSAGE: 1319 if value._is_present_in_parent: 1320 field_value = fields.get(field) 1321 if field_value is None: 1322 # Construct a new object to represent this field. 
1323 field_value = field._default_constructor(self) 1324 fields[field] = field_value 1325 field_value.MergeFrom(value) 1326 else: 1327 self._fields[field] = value 1328 if field.containing_oneof: 1329 self._UpdateOneofState(field) 1330 1331 if msg._unknown_fields: 1332 if not self._unknown_fields: 1333 self._unknown_fields = [] 1334 self._unknown_fields.extend(msg._unknown_fields) 1335 # pylint: disable=protected-access 1336 if self._unknown_field_set is None: 1337 self._unknown_field_set = containers.UnknownFieldSet() 1338 self._unknown_field_set._extend(msg._unknown_field_set) 1339 1340 cls.MergeFrom = MergeFrom 1341 1342 1343def _AddWhichOneofMethod(message_descriptor, cls): 1344 def WhichOneof(self, oneof_name): 1345 """Returns the name of the currently set field inside a oneof, or None.""" 1346 try: 1347 field = message_descriptor.oneofs_by_name[oneof_name] 1348 except KeyError: 1349 raise ValueError( 1350 'Protocol message has no oneof "%s" field.' % oneof_name) 1351 1352 nested_field = self._oneofs.get(field, None) 1353 if nested_field is not None and self.HasField(nested_field.name): 1354 return nested_field.name 1355 else: 1356 return None 1357 1358 cls.WhichOneof = WhichOneof 1359 1360 1361def _AddReduceMethod(cls): 1362 def __reduce__(self): # pylint: disable=invalid-name 1363 return (type(self), (), self.__getstate__()) 1364 cls.__reduce__ = __reduce__ 1365 1366 1367def _Clear(self): 1368 # Clear fields. 
1369 self._fields = {} 1370 self._unknown_fields = () 1371 # pylint: disable=protected-access 1372 if self._unknown_field_set is not None: 1373 self._unknown_field_set._clear() 1374 self._unknown_field_set = None 1375 1376 self._oneofs = {} 1377 self._Modified() 1378 1379 1380def _UnknownFields(self): 1381 if self._unknown_field_set is None: # pylint: disable=protected-access 1382 # pylint: disable=protected-access 1383 self._unknown_field_set = containers.UnknownFieldSet() 1384 return self._unknown_field_set # pylint: disable=protected-access 1385 1386 1387def _DiscardUnknownFields(self): 1388 self._unknown_fields = [] 1389 self._unknown_field_set = None # pylint: disable=protected-access 1390 for field, value in self.ListFields(): 1391 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1392 if _IsMapField(field): 1393 if _IsMessageMapField(field): 1394 for key in value: 1395 value[key].DiscardUnknownFields() 1396 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1397 for sub_message in value: 1398 sub_message.DiscardUnknownFields() 1399 else: 1400 value.DiscardUnknownFields() 1401 1402 1403def _SetListener(self, listener): 1404 if listener is None: 1405 self._listener = message_listener_mod.NullMessageListener() 1406 else: 1407 self._listener = listener 1408 1409 1410def _AddMessageMethods(message_descriptor, cls): 1411 """Adds implementations of all Message methods to cls.""" 1412 _AddListFieldsMethod(message_descriptor, cls) 1413 _AddHasFieldMethod(message_descriptor, cls) 1414 _AddClearFieldMethod(message_descriptor, cls) 1415 if message_descriptor.is_extendable: 1416 _AddClearExtensionMethod(cls) 1417 _AddHasExtensionMethod(cls) 1418 _AddEqualsMethod(message_descriptor, cls) 1419 _AddStrMethod(message_descriptor, cls) 1420 _AddReprMethod(message_descriptor, cls) 1421 _AddUnicodeMethod(message_descriptor, cls) 1422 _AddByteSizeMethod(message_descriptor, cls) 1423 _AddSerializeToStringMethod(message_descriptor, cls) 1424 
_AddSerializePartialToStringMethod(message_descriptor, cls) 1425 _AddMergeFromStringMethod(message_descriptor, cls) 1426 _AddIsInitializedMethod(message_descriptor, cls) 1427 _AddMergeFromMethod(cls) 1428 _AddWhichOneofMethod(message_descriptor, cls) 1429 _AddReduceMethod(cls) 1430 # Adds methods which do not depend on cls. 1431 cls.Clear = _Clear 1432 cls.UnknownFields = _UnknownFields 1433 cls.DiscardUnknownFields = _DiscardUnknownFields 1434 cls._SetListener = _SetListener 1435 1436 1437def _AddPrivateHelperMethods(message_descriptor, cls): 1438 """Adds implementation of private helper methods to cls.""" 1439 1440 def Modified(self): 1441 """Sets the _cached_byte_size_dirty bit to true, 1442 and propagates this to our listener iff this was a state change. 1443 """ 1444 1445 # Note: Some callers check _cached_byte_size_dirty before calling 1446 # _Modified() as an extra optimization. So, if this method is ever 1447 # changed such that it does stuff even when _cached_byte_size_dirty is 1448 # already true, the callers need to be updated. 1449 if not self._cached_byte_size_dirty: 1450 self._cached_byte_size_dirty = True 1451 self._listener_for_children.dirty = True 1452 self._is_present_in_parent = True 1453 self._listener.Modified() 1454 1455 def _UpdateOneofState(self, field): 1456 """Sets field as the active field in its containing oneof. 1457 1458 Will also delete currently active field in the oneof, if it is different 1459 from the argument. Does not mark the message as modified. 1460 """ 1461 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1462 if other_field is not field: 1463 del self._fields[other_field] 1464 self._oneofs[field.containing_oneof] = field 1465 1466 cls._Modified = Modified 1467 cls.SetInParent = Modified 1468 cls._UpdateOneofState = _UpdateOneofState 1469 1470 1471class _Listener(object): 1472 1473 """MessageListener implementation that a parent message registers with its 1474 child message. 
1475 1476 In order to support semantics like: 1477 1478 foo.bar.baz.qux = 23 1479 assert foo.HasField('bar') 1480 1481 ...child objects must have back references to their parents. 1482 This helper class is at the heart of this support. 1483 """ 1484 1485 def __init__(self, parent_message): 1486 """Args: 1487 parent_message: The message whose _Modified() method we should call when 1488 we receive Modified() messages. 1489 """ 1490 # This listener establishes a back reference from a child (contained) object 1491 # to its parent (containing) object. We make this a weak reference to avoid 1492 # creating cyclic garbage when the client finishes with the 'parent' object 1493 # in the tree. 1494 if isinstance(parent_message, weakref.ProxyType): 1495 self._parent_message_weakref = parent_message 1496 else: 1497 self._parent_message_weakref = weakref.proxy(parent_message) 1498 1499 # As an optimization, we also indicate directly on the listener whether 1500 # or not the parent message is dirty. This way we can avoid traversing 1501 # up the tree in the common case. 1502 self.dirty = False 1503 1504 def Modified(self): 1505 if self.dirty: 1506 return 1507 try: 1508 # Propagate the signal to our parents iff this is the first field set. 1509 self._parent_message_weakref._Modified() 1510 except ReferenceError: 1511 # We can get here if a client has kept a reference to a child object, 1512 # and is now setting a field on it, but the child's parent has been 1513 # garbage-collected. This is not an error. 1514 pass 1515 1516 1517class _OneofListener(_Listener): 1518 """Special listener implementation for setting composite oneof fields.""" 1519 1520 def __init__(self, parent_message, field): 1521 """Args: 1522 parent_message: The message whose _Modified() method we should call when 1523 we receive Modified() messages. 1524 field: The descriptor of the field being set in the parent message. 
1525 """ 1526 super(_OneofListener, self).__init__(parent_message) 1527 self._field = field 1528 1529 def Modified(self): 1530 """Also updates the state of the containing oneof in the parent message.""" 1531 try: 1532 self._parent_message_weakref._UpdateOneofState(self._field) 1533 super(_OneofListener, self).Modified() 1534 except ReferenceError: 1535 pass 1536