# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This module is written to run on both Python 2 and Python 3; the `six`
# package supplies the compatibility shims it needs (string types, reraise,
# range).
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

from io import BytesIO
import struct
import sys
import weakref

import six
from six.moves import range

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import extension_dict
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'
_ExtensionDict = extension_dict._ExtensionDict


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class being created.
      bases: Base classes of the class we're constructing
        (should be message.Message).  For well-known types, the matching
        mixin from well_known_types.WKTBASES is appended.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class, or the class already recorded as the
      descriptor's `_concrete_class`.

    Raises:
      RuntimeError: if the descriptor is a str rather than a Descriptor,
        i.e. the generated code only works with the python cpp extension
        but the pure python runtime is in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class being initialized.
      bases: Base classes of the class we're constructing
        (should be message.Message); passed through to type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Maps encoded tag bytes -> (decoder, oneof-related descriptor or None).
    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI.  See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If we
  #   rename the properties, then every such user has to also make sure to apply
  #   the same transformation.  Note that currently if you name a field "yield",
  #   you can still access it just fine using getattr/setattr -- it's not even
  #   that cumbersome to do so.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
  #   position.
  return proto_field_name


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_unknown_field_set',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns True if `field` is an optional message extension of a MessageSet.

  Such fields are encoded/sized with the MessageSet item helpers in
  _AttachFieldHelpers instead of the regular per-type encoders.
  """
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns True if `field` is a map field.

  A map field is a message-typed field whose message type has the
  `map_entry` option set.
  """
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns True if the map field's 'value' entry is itself a message.

  `field` must be a map field: its message type must declare a 'value' field.
  """
  value_type = field.message_type.fields_by_name['value']
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _IsStrictUtf8Check(field):
  """Returns True if string decoding for `field` must validate UTF-8 strictly.

  Only fields whose containing message uses proto3 syntax are checked.
  """
  return field.containing_type.syntax == 'proto3'


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches wire-format helpers for one field to its descriptor and class.

  Sets `_encoder`, `_sizer` and `_default_constructor` on `field_descriptor`,
  and registers the field's decoder(s) in `cls._decoders_by_tag`, keyed by the
  encoded tag bytes.

  Args:
    cls: The message class being constructed.  It must already have a
      `_decoders_by_tag` dict.
    field_descriptor: FieldDescriptor (regular field or extension) to attach
      helpers for.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  is_proto3 = field_descriptor.containing_type.syntax == 'proto3'
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: packed only when explicitly requested via the `packed` option.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Otherwise packable repeated fields are packed unless the `packed`
    # option is explicitly set to false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    # Registers a decoder for this field under the tag built from `wiretype`.
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Enums that type_checkers reports as open are decoded as plain int32.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    oneof_descriptor = None
    clear_if_default = False
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor
    elif (is_proto3 and not is_repeated and
          field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
      # proto3 singular scalars outside a oneof: the decoder is asked to
      # clear the field when it decodes the default value.
      clear_if_default = True

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          is_strict_utf8_check, clear_if_default)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes extensions declared inside this message as class attributes.

  Args:
    descriptor: Descriptor of the message being constructed.
    dictionary: Class dictionary; each extension is stored under its name,
      which must not already be present.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a function that builds an empty map container for `field`.

  Args:
    field: FieldDescriptor of a map field.

  Returns:
    A callable taking the owning message and returning a containers.MessageMap
    (message-valued maps) or a containers.ScalarMap (scalar-valued maps).

  Raises:
    ValueError: if `field` is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker,
          field.message_type)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The closure therefore reads field.message_type only when called.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Oneof members get a listener that also updates the oneof state.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added."""
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(exc), exc, sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ accepts one keyword argument per field.  Unknown
  names raise ValueError (from _GetFieldByName); a value of None is treated
  as if the keyword were absent.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any constructor kwargs mean the cached byte size is already stale.
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None  # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName raises ValueError for names the message doesn't
      # define, so `field` is always a real FieldDescriptor here.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              # Elements may be given as dicts of sub-field kwargs or as
              # message instances.
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          # Go through the property setter so type checking and oneof
          # bookkeeping apply.
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: if the message has no field named `field_name`.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Also sets a `<FIELDNAME>_FIELD_NUMBER` class constant.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + '_FIELD_NUMBER'
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


class _FieldProperty(property):
  """A property that also exposes its field's descriptor as `DESCRIPTOR`."""

  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    property.__init__(self, getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # proto3 scalars outside a oneof have no presence: setting them to the
  # default value removes them from _fields.
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, and empty strings/bytes).
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created.  If someone has preempted us, we
      # take that object and discard ours.
      # WARNING:  We are relying on setdefault() being atomic.  This is true
      #   in CPython but we haven't investigated others.  This warning appears
      #   in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds extension-related class attributes.

  Sets a `<NAME>_FIELD_NUMBER` constant for each extension declared in this
  message's scope (descriptor.extensions_by_name), and, when the descriptor
  belongs to a file, copies the pool's extension tables stored under this
  message descriptor onto the class as `_extensions_by_number` and
  `_extensions_by_name`.
  """
  extensions = descriptor.extensions_by_name
  for extension_name, extension_field in extensions.items():
    constant_name = extension_name.upper() + '_FIELD_NUMBER'
    setattr(cls, constant_name, extension_field.number)

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  if descriptor.file is not None:
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    pool = descriptor.file.pool
    cls._extensions_by_number = pool._extensions_by_number[descriptor]
    cls._extensions_by_name = pool._extensions_by_name[descriptor]


def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    # Binds the extension to this message type, records it in the descriptor
    # pool and installs its encoder/decoder helpers.
    extension_handle.containing_type = cls.DESCRIPTOR
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    # pylint: disable=protected-access
    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)
  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    # Builds a new message of this type parsed from the serialized bytes `s`.
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    # Present (descriptor, value) pairs, ordered by field number.
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields

_PROTO3_ERROR_TEMPLATE = \
  ('Protocol message %s has no non-repeated submessage field "%s" '
   'nor marked as optional')
_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'

def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  HasField() accepts the names of non-repeated fields that track presence
  and the names of oneofs; any other name raises ValueError.
  """

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _PROTO3_ERROR_TEMPLATE if is_proto3 else _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  # Has methods are supported for oneof descriptors.
  for oneof in message_descriptor.oneofs:
    hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is set iff its currently selected member field is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst than can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
918 self._Modified() 919 920 cls.ClearField = ClearField 921 922 923def _AddClearExtensionMethod(cls): 924 """Helper for _AddMessageMethods().""" 925 def ClearExtension(self, extension_handle): 926 extension_dict._VerifyExtensionHandle(self, extension_handle) 927 928 # Similar to ClearField(), above. 929 if extension_handle in self._fields: 930 del self._fields[extension_handle] 931 self._Modified() 932 cls.ClearExtension = ClearExtension 933 934 935def _AddHasExtensionMethod(cls): 936 """Helper for _AddMessageMethods().""" 937 def HasExtension(self, extension_handle): 938 extension_dict._VerifyExtensionHandle(self, extension_handle) 939 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 940 raise KeyError('"%s" is repeated.' % extension_handle.full_name) 941 942 if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 943 value = self._fields.get(extension_handle) 944 return value is not None and value._is_present_in_parent 945 else: 946 return extension_handle in self._fields 947 cls.HasExtension = HasExtension 948 949def _InternalUnpackAny(msg): 950 """Unpacks Any message and returns the unpacked message. 951 952 This internal method is different from public Any Unpack method which takes 953 the target message as argument. _InternalUnpackAny method does not have 954 target message type and need to find the message type in descriptor pool. 955 956 Args: 957 msg: An Any message to be unpacked. 958 959 Returns: 960 The unpacked message. 961 """ 962 # TODO(amauryfa): Don't use the factory of generated messages. 963 # To make Any work with custom factories, use the message factory of the 964 # parent message. 965 # pylint: disable=g-import-not-at-top 966 from google.protobuf import symbol_database 967 factory = symbol_database.Default() 968 969 type_url = msg.type_url 970 971 if not type_url: 972 return None 973 974 # TODO(haberman): For now we just strip the hostname. Better logic will be 975 # required. 
976 type_name = type_url.split('/')[-1] 977 descriptor = factory.pool.FindMessageTypeByName(type_name) 978 979 if descriptor is None: 980 return None 981 982 message_class = factory.GetPrototype(descriptor) 983 message = message_class() 984 985 message.ParseFromString(msg.value) 986 return message 987 988 989def _AddEqualsMethod(message_descriptor, cls): 990 """Helper for _AddMessageMethods().""" 991 def __eq__(self, other): 992 if (not isinstance(other, message_mod.Message) or 993 other.DESCRIPTOR != self.DESCRIPTOR): 994 return False 995 996 if self is other: 997 return True 998 999 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 1000 any_a = _InternalUnpackAny(self) 1001 any_b = _InternalUnpackAny(other) 1002 if any_a and any_b: 1003 return any_a == any_b 1004 1005 if not self.ListFields() == other.ListFields(): 1006 return False 1007 1008 # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, 1009 # then use it for the comparison. 1010 unknown_fields = list(self._unknown_fields) 1011 unknown_fields.sort() 1012 other_unknown_fields = list(other._unknown_fields) 1013 other_unknown_fields.sort() 1014 return unknown_fields == other_unknown_fields 1015 1016 cls.__eq__ = __eq__ 1017 1018 1019def _AddStrMethod(message_descriptor, cls): 1020 """Helper for _AddMessageMethods().""" 1021 def __str__(self): 1022 return text_format.MessageToString(self) 1023 cls.__str__ = __str__ 1024 1025 1026def _AddReprMethod(message_descriptor, cls): 1027 """Helper for _AddMessageMethods().""" 1028 def __repr__(self): 1029 return text_format.MessageToString(self) 1030 cls.__repr__ = __repr__ 1031 1032 1033def _AddUnicodeMethod(unused_message_descriptor, cls): 1034 """Helper for _AddMessageMethods().""" 1035 1036 def __unicode__(self): 1037 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 1038 cls.__unicode__ = __unicode__ 1039 1040 1041def _BytesForNonRepeatedElement(value, field_number, field_type): 1042 """Returns the number of bytes needed to 
serialize a non-repeated element. 1043 The returned byte count includes space for tag information and any 1044 other additional space associated with serializing value. 1045 1046 Args: 1047 value: Value we're serializing. 1048 field_number: Field number of this value. (Since the field number 1049 is stored as part of a varint-encoded tag, this has an impact 1050 on the total bytes required to serialize the value). 1051 field_type: The type of the field. One of the TYPE_* constants 1052 within FieldDescriptor. 1053 """ 1054 try: 1055 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1056 return fn(field_number, value) 1057 except KeyError: 1058 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1059 1060 1061def _AddByteSizeMethod(message_descriptor, cls): 1062 """Helper for _AddMessageMethods().""" 1063 1064 def ByteSize(self): 1065 if not self._cached_byte_size_dirty: 1066 return self._cached_byte_size 1067 1068 size = 0 1069 descriptor = self.DESCRIPTOR 1070 if descriptor.GetOptions().map_entry: 1071 # Fields of map entry should always be serialized. 1072 size = descriptor.fields_by_name['key']._sizer(self.key) 1073 size += descriptor.fields_by_name['value']._sizer(self.value) 1074 else: 1075 for field_descriptor, field_value in self.ListFields(): 1076 size += field_descriptor._sizer(field_value) 1077 for tag_bytes, value_bytes in self._unknown_fields: 1078 size += len(tag_bytes) + len(value_bytes) 1079 1080 self._cached_byte_size = size 1081 self._cached_byte_size_dirty = False 1082 self._listener_for_children.dirty = False 1083 return size 1084 1085 cls.ByteSize = ByteSize 1086 1087 1088def _AddSerializeToStringMethod(message_descriptor, cls): 1089 """Helper for _AddMessageMethods().""" 1090 1091 def SerializeToString(self, **kwargs): 1092 # Check if the message has all of its required fields set. 
1093 if not self.IsInitialized(): 1094 raise message_mod.EncodeError( 1095 'Message %s is missing required fields: %s' % ( 1096 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1097 return self.SerializePartialToString(**kwargs) 1098 cls.SerializeToString = SerializeToString 1099 1100 1101def _AddSerializePartialToStringMethod(message_descriptor, cls): 1102 """Helper for _AddMessageMethods().""" 1103 1104 def SerializePartialToString(self, **kwargs): 1105 out = BytesIO() 1106 self._InternalSerialize(out.write, **kwargs) 1107 return out.getvalue() 1108 cls.SerializePartialToString = SerializePartialToString 1109 1110 def InternalSerialize(self, write_bytes, deterministic=None): 1111 if deterministic is None: 1112 deterministic = ( 1113 api_implementation.IsPythonDefaultSerializationDeterministic()) 1114 else: 1115 deterministic = bool(deterministic) 1116 1117 descriptor = self.DESCRIPTOR 1118 if descriptor.GetOptions().map_entry: 1119 # Fields of map entry should always be serialized. 
1120 descriptor.fields_by_name['key']._encoder( 1121 write_bytes, self.key, deterministic) 1122 descriptor.fields_by_name['value']._encoder( 1123 write_bytes, self.value, deterministic) 1124 else: 1125 for field_descriptor, field_value in self.ListFields(): 1126 field_descriptor._encoder(write_bytes, field_value, deterministic) 1127 for tag_bytes, value_bytes in self._unknown_fields: 1128 write_bytes(tag_bytes) 1129 write_bytes(value_bytes) 1130 cls._InternalSerialize = InternalSerialize 1131 1132 1133def _AddMergeFromStringMethod(message_descriptor, cls): 1134 """Helper for _AddMessageMethods().""" 1135 def MergeFromString(self, serialized): 1136 if isinstance(serialized, memoryview) and six.PY2: 1137 raise TypeError( 1138 'memoryview not supported in Python 2 with the pure Python proto ' 1139 'implementation: this is to maintain compatibility with the C++ ' 1140 'implementation') 1141 1142 serialized = memoryview(serialized) 1143 length = len(serialized) 1144 try: 1145 if self._InternalParse(serialized, 0, length) != length: 1146 # The only reason _InternalParse would return early is if it 1147 # encountered an end-group tag. 1148 raise message_mod.DecodeError('Unexpected end-group tag.') 1149 except (IndexError, TypeError): 1150 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1151 raise message_mod.DecodeError('Truncated message.') 1152 except struct.error as e: 1153 raise message_mod.DecodeError(e) 1154 return length # Return this for legacy reasons. 1155 cls.MergeFromString = MergeFromString 1156 1157 local_ReadTag = decoder.ReadTag 1158 local_SkipField = decoder.SkipField 1159 decoders_by_tag = cls._decoders_by_tag 1160 1161 def InternalParse(self, buffer, pos, end): 1162 """Create a message from serialized bytes. 1163 1164 Args: 1165 self: Message, instance of the proto message object. 1166 buffer: memoryview of the serialized data. 1167 pos: int, position to start in the serialized data. 1168 end: int, end position of the serialized data. 
1169 1170 Returns: 1171 Message object. 1172 """ 1173 # Guard against internal misuse, since this function is called internally 1174 # quite extensively, and its easy to accidentally pass bytes. 1175 assert isinstance(buffer, memoryview) 1176 self._Modified() 1177 field_dict = self._fields 1178 # pylint: disable=protected-access 1179 unknown_field_set = self._unknown_field_set 1180 while pos != end: 1181 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1182 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1183 if field_decoder is None: 1184 if not self._unknown_fields: # pylint: disable=protected-access 1185 self._unknown_fields = [] # pylint: disable=protected-access 1186 if unknown_field_set is None: 1187 # pylint: disable=protected-access 1188 self._unknown_field_set = containers.UnknownFieldSet() 1189 # pylint: disable=protected-access 1190 unknown_field_set = self._unknown_field_set 1191 # pylint: disable=protected-access 1192 (tag, _) = decoder._DecodeVarint(tag_bytes, 0) 1193 field_number, wire_type = wire_format.UnpackTag(tag) 1194 if field_number == 0: 1195 raise message_mod.DecodeError('Field number 0 is illegal.') 1196 # TODO(jieluo): remove old_pos. 1197 old_pos = new_pos 1198 (data, new_pos) = decoder._DecodeUnknownField( 1199 buffer, new_pos, wire_type) # pylint: disable=protected-access 1200 if new_pos == -1: 1201 return pos 1202 # pylint: disable=protected-access 1203 unknown_field_set._add(field_number, wire_type, data) 1204 # TODO(jieluo): remove _unknown_fields. 
1205 new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) 1206 if new_pos == -1: 1207 return pos 1208 self._unknown_fields.append( 1209 (tag_bytes, buffer[old_pos:new_pos].tobytes())) 1210 pos = new_pos 1211 else: 1212 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1213 if field_desc: 1214 self._UpdateOneofState(field_desc) 1215 return pos 1216 cls._InternalParse = InternalParse 1217 1218 1219def _AddIsInitializedMethod(message_descriptor, cls): 1220 """Adds the IsInitialized and FindInitializationError methods to the 1221 protocol message class.""" 1222 1223 required_fields = [field for field in message_descriptor.fields 1224 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1225 1226 def IsInitialized(self, errors=None): 1227 """Checks if all required fields of a message are set. 1228 1229 Args: 1230 errors: A list which, if provided, will be populated with the field 1231 paths of all missing required fields. 1232 1233 Returns: 1234 True iff the specified message has all required fields set. 1235 """ 1236 1237 # Performance is critical so we avoid HasField() and ListFields(). 1238 1239 for field in required_fields: 1240 if (field not in self._fields or 1241 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1242 not self._fields[field]._is_present_in_parent)): 1243 if errors is not None: 1244 errors.extend(self.FindInitializationErrors()) 1245 return False 1246 1247 for field, value in list(self._fields.items()): # dict can change size! 
1248 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1249 if field.label == _FieldDescriptor.LABEL_REPEATED: 1250 if (field.message_type.has_options and 1251 field.message_type.GetOptions().map_entry): 1252 continue 1253 for element in value: 1254 if not element.IsInitialized(): 1255 if errors is not None: 1256 errors.extend(self.FindInitializationErrors()) 1257 return False 1258 elif value._is_present_in_parent and not value.IsInitialized(): 1259 if errors is not None: 1260 errors.extend(self.FindInitializationErrors()) 1261 return False 1262 1263 return True 1264 1265 cls.IsInitialized = IsInitialized 1266 1267 def FindInitializationErrors(self): 1268 """Finds required fields which are not initialized. 1269 1270 Returns: 1271 A list of strings. Each string is a path to an uninitialized field from 1272 the top-level message, e.g. "foo.bar[5].baz". 1273 """ 1274 1275 errors = [] # simplify things 1276 1277 for field in required_fields: 1278 if not self.HasField(field.name): 1279 errors.append(field.name) 1280 1281 for field, value in self.ListFields(): 1282 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1283 if field.is_extension: 1284 name = '(%s)' % field.full_name 1285 else: 1286 name = field.name 1287 1288 if _IsMapField(field): 1289 if _IsMessageMapField(field): 1290 for key in value: 1291 element = value[key] 1292 prefix = '%s[%s].' % (name, key) 1293 sub_errors = element.FindInitializationErrors() 1294 errors += [prefix + error for error in sub_errors] 1295 else: 1296 # ScalarMaps can't have any initialization errors. 1297 pass 1298 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1299 for i in range(len(value)): 1300 element = value[i] 1301 prefix = '%s[%d].' % (name, i) 1302 sub_errors = element.FindInitializationErrors() 1303 errors += [prefix + error for error in sub_errors] 1304 else: 1305 prefix = name + '.' 
1306 sub_errors = value.FindInitializationErrors() 1307 errors += [prefix + error for error in sub_errors] 1308 1309 return errors 1310 1311 cls.FindInitializationErrors = FindInitializationErrors 1312 1313 1314def _AddMergeFromMethod(cls): 1315 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1316 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1317 1318 def MergeFrom(self, msg): 1319 if not isinstance(msg, cls): 1320 raise TypeError( 1321 'Parameter to MergeFrom() must be instance of same class: ' 1322 'expected %s got %s.' % (cls.__name__, msg.__class__.__name__)) 1323 1324 assert msg is not self 1325 self._Modified() 1326 1327 fields = self._fields 1328 1329 for field, value in msg._fields.items(): 1330 if field.label == LABEL_REPEATED: 1331 field_value = fields.get(field) 1332 if field_value is None: 1333 # Construct a new object to represent this field. 1334 field_value = field._default_constructor(self) 1335 fields[field] = field_value 1336 field_value.MergeFrom(value) 1337 elif field.cpp_type == CPPTYPE_MESSAGE: 1338 if value._is_present_in_parent: 1339 field_value = fields.get(field) 1340 if field_value is None: 1341 # Construct a new object to represent this field. 
1342 field_value = field._default_constructor(self) 1343 fields[field] = field_value 1344 field_value.MergeFrom(value) 1345 else: 1346 self._fields[field] = value 1347 if field.containing_oneof: 1348 self._UpdateOneofState(field) 1349 1350 if msg._unknown_fields: 1351 if not self._unknown_fields: 1352 self._unknown_fields = [] 1353 self._unknown_fields.extend(msg._unknown_fields) 1354 # pylint: disable=protected-access 1355 if self._unknown_field_set is None: 1356 self._unknown_field_set = containers.UnknownFieldSet() 1357 self._unknown_field_set._extend(msg._unknown_field_set) 1358 1359 cls.MergeFrom = MergeFrom 1360 1361 1362def _AddWhichOneofMethod(message_descriptor, cls): 1363 def WhichOneof(self, oneof_name): 1364 """Returns the name of the currently set field inside a oneof, or None.""" 1365 try: 1366 field = message_descriptor.oneofs_by_name[oneof_name] 1367 except KeyError: 1368 raise ValueError( 1369 'Protocol message has no oneof "%s" field.' % oneof_name) 1370 1371 nested_field = self._oneofs.get(field, None) 1372 if nested_field is not None and self.HasField(nested_field.name): 1373 return nested_field.name 1374 else: 1375 return None 1376 1377 cls.WhichOneof = WhichOneof 1378 1379 1380def _Clear(self): 1381 # Clear fields. 
1382 self._fields = {} 1383 self._unknown_fields = () 1384 # pylint: disable=protected-access 1385 if self._unknown_field_set is not None: 1386 self._unknown_field_set._clear() 1387 self._unknown_field_set = None 1388 1389 self._oneofs = {} 1390 self._Modified() 1391 1392 1393def _UnknownFields(self): 1394 if self._unknown_field_set is None: # pylint: disable=protected-access 1395 # pylint: disable=protected-access 1396 self._unknown_field_set = containers.UnknownFieldSet() 1397 return self._unknown_field_set # pylint: disable=protected-access 1398 1399 1400def _DiscardUnknownFields(self): 1401 self._unknown_fields = [] 1402 self._unknown_field_set = None # pylint: disable=protected-access 1403 for field, value in self.ListFields(): 1404 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1405 if _IsMapField(field): 1406 if _IsMessageMapField(field): 1407 for key in value: 1408 value[key].DiscardUnknownFields() 1409 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1410 for sub_message in value: 1411 sub_message.DiscardUnknownFields() 1412 else: 1413 value.DiscardUnknownFields() 1414 1415 1416def _SetListener(self, listener): 1417 if listener is None: 1418 self._listener = message_listener_mod.NullMessageListener() 1419 else: 1420 self._listener = listener 1421 1422 1423def _AddMessageMethods(message_descriptor, cls): 1424 """Adds implementations of all Message methods to cls.""" 1425 _AddListFieldsMethod(message_descriptor, cls) 1426 _AddHasFieldMethod(message_descriptor, cls) 1427 _AddClearFieldMethod(message_descriptor, cls) 1428 if message_descriptor.is_extendable: 1429 _AddClearExtensionMethod(cls) 1430 _AddHasExtensionMethod(cls) 1431 _AddEqualsMethod(message_descriptor, cls) 1432 _AddStrMethod(message_descriptor, cls) 1433 _AddReprMethod(message_descriptor, cls) 1434 _AddUnicodeMethod(message_descriptor, cls) 1435 _AddByteSizeMethod(message_descriptor, cls) 1436 _AddSerializeToStringMethod(message_descriptor, cls) 1437 
_AddSerializePartialToStringMethod(message_descriptor, cls) 1438 _AddMergeFromStringMethod(message_descriptor, cls) 1439 _AddIsInitializedMethod(message_descriptor, cls) 1440 _AddMergeFromMethod(cls) 1441 _AddWhichOneofMethod(message_descriptor, cls) 1442 # Adds methods which do not depend on cls. 1443 cls.Clear = _Clear 1444 cls.UnknownFields = _UnknownFields 1445 cls.DiscardUnknownFields = _DiscardUnknownFields 1446 cls._SetListener = _SetListener 1447 1448 1449def _AddPrivateHelperMethods(message_descriptor, cls): 1450 """Adds implementation of private helper methods to cls.""" 1451 1452 def Modified(self): 1453 """Sets the _cached_byte_size_dirty bit to true, 1454 and propagates this to our listener iff this was a state change. 1455 """ 1456 1457 # Note: Some callers check _cached_byte_size_dirty before calling 1458 # _Modified() as an extra optimization. So, if this method is ever 1459 # changed such that it does stuff even when _cached_byte_size_dirty is 1460 # already true, the callers need to be updated. 1461 if not self._cached_byte_size_dirty: 1462 self._cached_byte_size_dirty = True 1463 self._listener_for_children.dirty = True 1464 self._is_present_in_parent = True 1465 self._listener.Modified() 1466 1467 def _UpdateOneofState(self, field): 1468 """Sets field as the active field in its containing oneof. 1469 1470 Will also delete currently active field in the oneof, if it is different 1471 from the argument. Does not mark the message as modified. 1472 """ 1473 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1474 if other_field is not field: 1475 del self._fields[other_field] 1476 self._oneofs[field.containing_oneof] = field 1477 1478 cls._Modified = Modified 1479 cls.SetInParent = Modified 1480 cls._UpdateOneofState = _UpdateOneofState 1481 1482 1483class _Listener(object): 1484 1485 """MessageListener implementation that a parent message registers with its 1486 child message. 
1487 1488 In order to support semantics like: 1489 1490 foo.bar.baz.qux = 23 1491 assert foo.HasField('bar') 1492 1493 ...child objects must have back references to their parents. 1494 This helper class is at the heart of this support. 1495 """ 1496 1497 def __init__(self, parent_message): 1498 """Args: 1499 parent_message: The message whose _Modified() method we should call when 1500 we receive Modified() messages. 1501 """ 1502 # This listener establishes a back reference from a child (contained) object 1503 # to its parent (containing) object. We make this a weak reference to avoid 1504 # creating cyclic garbage when the client finishes with the 'parent' object 1505 # in the tree. 1506 if isinstance(parent_message, weakref.ProxyType): 1507 self._parent_message_weakref = parent_message 1508 else: 1509 self._parent_message_weakref = weakref.proxy(parent_message) 1510 1511 # As an optimization, we also indicate directly on the listener whether 1512 # or not the parent message is dirty. This way we can avoid traversing 1513 # up the tree in the common case. 1514 self.dirty = False 1515 1516 def Modified(self): 1517 if self.dirty: 1518 return 1519 try: 1520 # Propagate the signal to our parents iff this is the first field set. 1521 self._parent_message_weakref._Modified() 1522 except ReferenceError: 1523 # We can get here if a client has kept a reference to a child object, 1524 # and is now setting a field on it, but the child's parent has been 1525 # garbage-collected. This is not an error. 1526 pass 1527 1528 1529class _OneofListener(_Listener): 1530 """Special listener implementation for setting composite oneof fields.""" 1531 1532 def __init__(self, parent_message, field): 1533 """Args: 1534 parent_message: The message whose _Modified() method we should call when 1535 we receive Modified() messages. 1536 field: The descriptor of the field being set in the parent message. 
1537 """ 1538 super(_OneofListener, self).__init__(parent_message) 1539 self._field = field 1540 1541 def Modified(self): 1542 """Also updates the state of the containing oneof in the parent message.""" 1543 try: 1544 self._parent_message_weakref._UpdateOneofState(self._field) 1545 super(_OneofListener, self).Modified() 1546 except ReferenceError: 1547 pass 1548