# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This code is meant to work on Python 2.4 and above only.
#
# TODO(robinson): Helpers for verbose, common checks like seeing if a
# descriptor's cpp_type is CPPTYPE_MESSAGE.

"""Contains a metaclass and helper functions used to create
protocol message classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class.
(A class is to a metaclass what an instance is to a class.)

In this case, we use the GeneratedProtocolMessageType metaclass
to inject all the useful functionality into the classes
output by the protocol compiler at compile-time.

The upshot of all this is that the real implementation
details for ALL pure-Python protocol buffers are *here in
this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'

from io import BytesIO
import sys
import struct
import weakref

import six
try:
  import six.moves.copyreg as copyreg
except ImportError:
  # On some platforms, for example gMac, we run native Python because there is
  # nothing like hermetic Python. This means lesser control on the system and
  # the six.moves package may be missing (is missing on 20150321 on gMac). Be
  # extra conservative and try to load the old replacement if it fails.
  import copy_reg as copyreg

# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import enum_type_wrapper
from google.protobuf.internal import message_listener as message_listener_mod
from google.protobuf.internal import type_checkers
from google.protobuf.internal import well_known_types
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
from google.protobuf import symbol_database
from google.protobuf import text_format

_FieldDescriptor = descriptor_mod.FieldDescriptor
_AnyFullTypeName = 'google.protobuf.Any'


class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  class MyProtoClass(Message):
    __metaclass__ = GeneratedProtocolMessageType
    DESCRIPTOR = mydescriptor
  myproto_instance = MyProtoClass()
  myproto.foo_field = 23
  ...

  The above example will not work for nested types. If you wish to include them,
  use reflection.MakeClass() instead of manually instantiating the class in
  order to create the appropriate class structure.
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). For descriptors listed in
        well_known_types.WKTBASES the matching helper base is appended.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.
    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(cls._extensions_by_number), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)
    copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)


# Stateless helpers for GeneratedProtocolMessageType below.
# Outside clients should not access these directly.
#
# I opted not to make any of these methods on the metaclass, to make it more
# clear that I'm not really using any state there and to keep clients from
# thinking that they have direct access to these construction helpers.


def _PropertyName(proto_field_name):
  """Returns the name of the public property attribute which
  clients can use to get and (in some cases) set the value
  of a protocol message field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.
  """
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz makes my day by writing:
  # """
  # FYI. See the keyword module in the stdlib. This could be as simple as:
  #
  # if keyword.iskeyword(proto_field_name):
  #   return proto_field_name + "_"
  # return proto_field_name
  # """
  # Kenton says: The above is a BAD IDEA. People rely on being able to use
  # getattr() and setattr() to reflectively manipulate field values. If we
  # rename the properties, then every such user has to also make sure to apply
  # the same transformation. Note that currently if you name a field "yield",
  # you can still access it just fine using getattr/setattr -- it's not even
  # that cumbersome to do so.
  # TODO(kenton): Remove this method entirely if/when everyone agrees with my
  # position.
  return proto_field_name


def _VerifyExtensionHandle(message, extension_handle):
  """Verify that the given extension handle is valid.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The object to validate as an extension of |message|.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other
      than that of |message|.
  """

  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  if not extension_handle.containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extension_handle.containing_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extension_handle.containing_type.full_name,
                    message.DESCRIPTOR.full_name))


def _AddSlots(message_descriptor, dictionary):
  """Adds a __slots__ entry to dictionary, containing the names of all valid
  attributes for this message type.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  dictionary['__slots__'] = ['_cached_byte_size',
                             '_cached_byte_size_dirty',
                             '_fields',
                             '_unknown_fields',
                             '_is_present_in_parent',
                             '_listener',
                             '_listener_for_children',
                             '__weakref__',
                             '_oneofs']


def _IsMessageSetExtension(field):
  """Returns whether field is an optional message-typed extension of a
  message whose options set message_set_wire_format."""
  return (field.is_extension and
          field.containing_type.has_options and
          field.containing_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)


def _IsMapField(field):
  """Returns whether field is a message field whose message type has the
  map_entry option set."""
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.message_type.has_options and
          field.message_type.GetOptions().map_entry)


def _IsMessageMapField(field):
  """Returns whether the "value" field of field's message type is itself a
  message (i.e. its cpp_type is CPPTYPE_MESSAGE)."""
  value_type = field.message_type.fields_by_name["value"]
  return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE


def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding helpers to a field and registers its decoders on cls.

  Sets _encoder, _sizer and _default_constructor on field_descriptor, and adds
  an entry to cls._decoders_by_tag for the field's wire type (plus one for the
  length-delimited wire type when the field is repeated and packable).

  Args:
    cls: The message class being constructed.
    field_descriptor: FieldDescriptor of the field (or extension) to set up.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == "proto2":
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Outside proto2, a packable field is packed unless the option is
    # explicitly present and false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField("packed") and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # Despite the name, this holds the *field* descriptor of a oneof member
    # (or None when the field is not part of a oneof).
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)


def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Stores each extension declared in descriptor in the class dictionary,
  keyed by the extension's name.

  Args:
    descriptor: Descriptor object for this message type.
    dictionary: Class dictionary of the class being constructed.
  """
  extension_dict = descriptor.extensions_by_name
  for extension_name, extension_field in extension_dict.items():
    assert extension_name not in dictionary
    dictionary[extension_name] = extension_field


def _AddEnumValues(descriptor, cls):
  """Sets class-level attributes for all enum fields defined in this message.

  Also exporting a class-level object that can name enum values.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    for enum_value in enum_type.values:
      setattr(cls, enum_value.name, enum_value.number)


def _GetInitializeDefaultForMap(field):
  """Returns a function that builds the map container for a map field.

  Args:
    field: FieldDescriptor of the map field.

  Returns:
    A one-argument function taking the owning message and returning a
    containers.MessageMap (message-valued maps) or containers.ScalarMap.

  Raises:
    ValueError: If field is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  fields_by_name = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(fields_by_name['key'])

  value_field = fields_by_name['value']
  if _IsMessageMapField(field):
    def MakeMessageMapDefault(message):
      return containers.MessageMap(
          message._listener_for_children, value_field.message_type, key_checker)
    return MakeMessageMapDefault
  else:
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker)
    return MakePrimitiveMapDefault


def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set. (Depends on order in which we initialize the classes).
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault


def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raise the currently-handled TypeError with the field name added."""
  exc = sys.exc_info()[1]
  if len(exc.args) == 1 and type(exc) is TypeError:
    # simple TypeError; add field name to exception message
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # re-raise possibly-amended exception with original traceback:
  six.reraise(type(exc), exc, sys.exc_info()[2])


def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls."""

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName raises ValueError for unknown names rather than
      # returning None, so this check is purely defensive.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _GetFieldByName(message_descriptor, field_name):
  """Returns a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: If message_descriptor has no field named field_name.
  """
  try:
    return message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if descriptor.is_extendable:
    # _ExtensionDict is just an adaptor with no state so we allocate a new one
    # every time it is accessed.
    cls.Extensions = property(lambda self: _ExtensionDict(self))


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  constant_name = field.name.upper() + "_FIELD_NUMBER"
  setattr(cls, constant_name, field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field. Clients
  can use this property to get the value of the field, which will be either a
  _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      # in CPython but we haven't investigated others. This warning appears
      # in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    new_value = type_checker.CheckValue(new_value)
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    field_value = self._fields.get(field)
    if field_value is None:
      # Construct a new object to represent this field.
      field_value = field._default_constructor(self)

      # Atomically check if another thread has preempted us and, if not, swap
      # in the new object we just created. If someone has preempted us, we
      # take that object and discard ours.
      # WARNING: We are relying on setdefault() being atomic. This is true
      # in CPython but we haven't investigated others. This warning appears
      # in several other locations in this file.
      field_value = self._fields.setdefault(field, field_value)
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each extension declared in
  this protocol message type.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  extension_dict = descriptor.extensions_by_name
  for extension_name, extension_field in extension_dict.items():
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)


def _AddStaticMethods(cls):
  """Adds the RegisterExtension and FromString static methods to cls."""
  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # Try to insert our extension, failing if an extension with the same number
    # already exists.
    actual_handle = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if actual_handle is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, actual_handle.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    handle = extension_handle  # avoid line wrapping
    if _IsMessageSetExtension(handle):
      # MessageSet extension. Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  cls.RegisterExtension = staticmethod(RegisterExtension)

  def FromString(s):
    message = cls()
    message.MergeFromString(s)
    return message
  cls.FromString = staticmethod(FromString)


def _IsPresent(item):
  """Given a (FieldDescriptor, value) tuple from _fields, return true if the
  value should be included in the list returned by ListFields()."""

  if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  else:
    return True


def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  def ListFields(self):
    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
    all_fields.sort(key=lambda item: item[0].number)
    return all_fields

  cls.ListFields = ListFields


_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
_Proto2HasError = 'Protocol message has no non-repeated field "%s"'


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""

  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and
        not field.containing_oneof):
      continue
    hassable_fields[field.name] = field

  if not is_proto3:
    # Fields inside oneofs are never repeated (enforced by the compiler).
    for oneof in message_descriptor.oneofs:
      hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False
    else:
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        value = self._fields.get(field)
        return value is not None and value._is_present_in_parent
      else:
        return field in self._fields

  cls.HasField = HasField


def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods()."""
  def ClearField(self, field_name):
    try:
      field = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        field = message_descriptor.oneofs_by_name[field_name]
        if field in self._oneofs:
          field = self._oneofs[field]
        else:
          return
      except KeyError:
        raise ValueError('Protocol message %s() has no "%s" field.' %
                         (message_descriptor.name, field_name))

    if field in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[field], 'InvalidateIterators'):
        self._fields[field].InvalidateIterators()

      # Note: If the field is a sub-message, its listener will still point
      # at us. That's fine, because the worst that can happen is that it
      # will call _Modified() and invalidate our byte size. Big deal.
      del self._fields[field]

      if self._oneofs.get(field.containing_oneof, None) is field:
        del self._oneofs[field.containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField


def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(), above.
    if extension_handle in self._fields:
      del self._fields[extension_handle]
    self._Modified()
  cls.ClearExtension = ClearExtension


def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods()."""
  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      value = self._fields.get(extension_handle)
      return value is not None and value._is_present_in_parent
    else:
      return extension_handle in self._fields
  cls.HasExtension = HasExtension


def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is differnt from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message.
  """
  type_url = msg.type_url
  db = symbol_database.Default()

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname. Better logic will be
  # required.
937 type_name = type_url.split("/")[-1] 938 descriptor = db.pool.FindMessageTypeByName(type_name) 939 940 if descriptor is None: 941 return None 942 943 message_class = db.GetPrototype(descriptor) 944 message = message_class() 945 946 message.ParseFromString(msg.value) 947 return message 948 949def _AddEqualsMethod(message_descriptor, cls): 950 """Helper for _AddMessageMethods().""" 951 def __eq__(self, other): 952 if (not isinstance(other, message_mod.Message) or 953 other.DESCRIPTOR != self.DESCRIPTOR): 954 return False 955 956 if self is other: 957 return True 958 959 if self.DESCRIPTOR.full_name == _AnyFullTypeName: 960 any_a = _InternalUnpackAny(self) 961 any_b = _InternalUnpackAny(other) 962 if any_a and any_b: 963 return any_a == any_b 964 965 if not self.ListFields() == other.ListFields(): 966 return False 967 968 # Sort unknown fields because their order shouldn't affect equality test. 969 unknown_fields = list(self._unknown_fields) 970 unknown_fields.sort() 971 other_unknown_fields = list(other._unknown_fields) 972 other_unknown_fields.sort() 973 974 return unknown_fields == other_unknown_fields 975 976 cls.__eq__ = __eq__ 977 978 979def _AddStrMethod(message_descriptor, cls): 980 """Helper for _AddMessageMethods().""" 981 def __str__(self): 982 return text_format.MessageToString(self) 983 cls.__str__ = __str__ 984 985 986def _AddReprMethod(message_descriptor, cls): 987 """Helper for _AddMessageMethods().""" 988 def __repr__(self): 989 return text_format.MessageToString(self) 990 cls.__repr__ = __repr__ 991 992 993def _AddUnicodeMethod(unused_message_descriptor, cls): 994 """Helper for _AddMessageMethods().""" 995 996 def __unicode__(self): 997 return text_format.MessageToString(self, as_utf8=True).decode('utf-8') 998 cls.__unicode__ = __unicode__ 999 1000 1001def _BytesForNonRepeatedElement(value, field_number, field_type): 1002 """Returns the number of bytes needed to serialize a non-repeated element. 
1003 The returned byte count includes space for tag information and any 1004 other additional space associated with serializing value. 1005 1006 Args: 1007 value: Value we're serializing. 1008 field_number: Field number of this value. (Since the field number 1009 is stored as part of a varint-encoded tag, this has an impact 1010 on the total bytes required to serialize the value). 1011 field_type: The type of the field. One of the TYPE_* constants 1012 within FieldDescriptor. 1013 """ 1014 try: 1015 fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] 1016 return fn(field_number, value) 1017 except KeyError: 1018 raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) 1019 1020 1021def _AddByteSizeMethod(message_descriptor, cls): 1022 """Helper for _AddMessageMethods().""" 1023 1024 def ByteSize(self): 1025 if not self._cached_byte_size_dirty: 1026 return self._cached_byte_size 1027 1028 size = 0 1029 for field_descriptor, field_value in self.ListFields(): 1030 size += field_descriptor._sizer(field_value) 1031 1032 for tag_bytes, value_bytes in self._unknown_fields: 1033 size += len(tag_bytes) + len(value_bytes) 1034 1035 self._cached_byte_size = size 1036 self._cached_byte_size_dirty = False 1037 self._listener_for_children.dirty = False 1038 return size 1039 1040 cls.ByteSize = ByteSize 1041 1042 1043def _AddSerializeToStringMethod(message_descriptor, cls): 1044 """Helper for _AddMessageMethods().""" 1045 1046 def SerializeToString(self): 1047 # Check if the message has all of its required fields set. 
1048 errors = [] 1049 if not self.IsInitialized(): 1050 raise message_mod.EncodeError( 1051 'Message %s is missing required fields: %s' % ( 1052 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) 1053 return self.SerializePartialToString() 1054 cls.SerializeToString = SerializeToString 1055 1056 1057def _AddSerializePartialToStringMethod(message_descriptor, cls): 1058 """Helper for _AddMessageMethods().""" 1059 1060 def SerializePartialToString(self): 1061 out = BytesIO() 1062 self._InternalSerialize(out.write) 1063 return out.getvalue() 1064 cls.SerializePartialToString = SerializePartialToString 1065 1066 def InternalSerialize(self, write_bytes): 1067 for field_descriptor, field_value in self.ListFields(): 1068 field_descriptor._encoder(write_bytes, field_value) 1069 for tag_bytes, value_bytes in self._unknown_fields: 1070 write_bytes(tag_bytes) 1071 write_bytes(value_bytes) 1072 cls._InternalSerialize = InternalSerialize 1073 1074 1075def _AddMergeFromStringMethod(message_descriptor, cls): 1076 """Helper for _AddMessageMethods().""" 1077 def MergeFromString(self, serialized): 1078 length = len(serialized) 1079 try: 1080 if self._InternalParse(serialized, 0, length) != length: 1081 # The only reason _InternalParse would return early is if it 1082 # encountered an end-group tag. 1083 raise message_mod.DecodeError('Unexpected end-group tag.') 1084 except (IndexError, TypeError): 1085 # Now ord(buf[p:p+1]) == ord('') gets TypeError. 1086 raise message_mod.DecodeError('Truncated message.') 1087 except struct.error as e: 1088 raise message_mod.DecodeError(e) 1089 return length # Return this for legacy reasons. 
1090 cls.MergeFromString = MergeFromString 1091 1092 local_ReadTag = decoder.ReadTag 1093 local_SkipField = decoder.SkipField 1094 decoders_by_tag = cls._decoders_by_tag 1095 is_proto3 = message_descriptor.syntax == "proto3" 1096 1097 def InternalParse(self, buffer, pos, end): 1098 self._Modified() 1099 field_dict = self._fields 1100 unknown_field_list = self._unknown_fields 1101 while pos != end: 1102 (tag_bytes, new_pos) = local_ReadTag(buffer, pos) 1103 field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) 1104 if field_decoder is None: 1105 value_start_pos = new_pos 1106 new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) 1107 if new_pos == -1: 1108 return pos 1109 if not is_proto3: 1110 if not unknown_field_list: 1111 unknown_field_list = self._unknown_fields = [] 1112 unknown_field_list.append( 1113 (tag_bytes, buffer[value_start_pos:new_pos])) 1114 pos = new_pos 1115 else: 1116 pos = field_decoder(buffer, new_pos, end, self, field_dict) 1117 if field_desc: 1118 self._UpdateOneofState(field_desc) 1119 return pos 1120 cls._InternalParse = InternalParse 1121 1122 1123def _AddIsInitializedMethod(message_descriptor, cls): 1124 """Adds the IsInitialized and FindInitializationError methods to the 1125 protocol message class.""" 1126 1127 required_fields = [field for field in message_descriptor.fields 1128 if field.label == _FieldDescriptor.LABEL_REQUIRED] 1129 1130 def IsInitialized(self, errors=None): 1131 """Checks if all required fields of a message are set. 1132 1133 Args: 1134 errors: A list which, if provided, will be populated with the field 1135 paths of all missing required fields. 1136 1137 Returns: 1138 True iff the specified message has all required fields set. 1139 """ 1140 1141 # Performance is critical so we avoid HasField() and ListFields(). 
1142 1143 for field in required_fields: 1144 if (field not in self._fields or 1145 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and 1146 not self._fields[field]._is_present_in_parent)): 1147 if errors is not None: 1148 errors.extend(self.FindInitializationErrors()) 1149 return False 1150 1151 for field, value in list(self._fields.items()): # dict can change size! 1152 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1153 if field.label == _FieldDescriptor.LABEL_REPEATED: 1154 if (field.message_type.has_options and 1155 field.message_type.GetOptions().map_entry): 1156 continue 1157 for element in value: 1158 if not element.IsInitialized(): 1159 if errors is not None: 1160 errors.extend(self.FindInitializationErrors()) 1161 return False 1162 elif value._is_present_in_parent and not value.IsInitialized(): 1163 if errors is not None: 1164 errors.extend(self.FindInitializationErrors()) 1165 return False 1166 1167 return True 1168 1169 cls.IsInitialized = IsInitialized 1170 1171 def FindInitializationErrors(self): 1172 """Finds required fields which are not initialized. 1173 1174 Returns: 1175 A list of strings. Each string is a path to an uninitialized field from 1176 the top-level message, e.g. "foo.bar[5].baz". 1177 """ 1178 1179 errors = [] # simplify things 1180 1181 for field in required_fields: 1182 if not self.HasField(field.name): 1183 errors.append(field.name) 1184 1185 for field, value in self.ListFields(): 1186 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1187 if field.is_extension: 1188 name = "(%s)" % field.full_name 1189 else: 1190 name = field.name 1191 1192 if _IsMapField(field): 1193 if _IsMessageMapField(field): 1194 for key in value: 1195 element = value[key] 1196 prefix = "%s[%s]." % (name, key) 1197 sub_errors = element.FindInitializationErrors() 1198 errors += [prefix + error for error in sub_errors] 1199 else: 1200 # ScalarMaps can't have any initialization errors. 
1201 pass 1202 elif field.label == _FieldDescriptor.LABEL_REPEATED: 1203 for i in range(len(value)): 1204 element = value[i] 1205 prefix = "%s[%d]." % (name, i) 1206 sub_errors = element.FindInitializationErrors() 1207 errors += [prefix + error for error in sub_errors] 1208 else: 1209 prefix = name + "." 1210 sub_errors = value.FindInitializationErrors() 1211 errors += [prefix + error for error in sub_errors] 1212 1213 return errors 1214 1215 cls.FindInitializationErrors = FindInitializationErrors 1216 1217 1218def _AddMergeFromMethod(cls): 1219 LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED 1220 CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE 1221 1222 def MergeFrom(self, msg): 1223 if not isinstance(msg, cls): 1224 raise TypeError( 1225 "Parameter to MergeFrom() must be instance of same class: " 1226 "expected %s got %s." % (cls.__name__, type(msg).__name__)) 1227 1228 assert msg is not self 1229 self._Modified() 1230 1231 fields = self._fields 1232 1233 for field, value in msg._fields.items(): 1234 if field.label == LABEL_REPEATED: 1235 field_value = fields.get(field) 1236 if field_value is None: 1237 # Construct a new object to represent this field. 1238 field_value = field._default_constructor(self) 1239 fields[field] = field_value 1240 field_value.MergeFrom(value) 1241 elif field.cpp_type == CPPTYPE_MESSAGE: 1242 if value._is_present_in_parent: 1243 field_value = fields.get(field) 1244 if field_value is None: 1245 # Construct a new object to represent this field. 
1246 field_value = field._default_constructor(self) 1247 fields[field] = field_value 1248 field_value.MergeFrom(value) 1249 else: 1250 self._fields[field] = value 1251 if field.containing_oneof: 1252 self._UpdateOneofState(field) 1253 1254 if msg._unknown_fields: 1255 if not self._unknown_fields: 1256 self._unknown_fields = [] 1257 self._unknown_fields.extend(msg._unknown_fields) 1258 1259 cls.MergeFrom = MergeFrom 1260 1261 1262def _AddWhichOneofMethod(message_descriptor, cls): 1263 def WhichOneof(self, oneof_name): 1264 """Returns the name of the currently set field inside a oneof, or None.""" 1265 try: 1266 field = message_descriptor.oneofs_by_name[oneof_name] 1267 except KeyError: 1268 raise ValueError( 1269 'Protocol message has no oneof "%s" field.' % oneof_name) 1270 1271 nested_field = self._oneofs.get(field, None) 1272 if nested_field is not None and self.HasField(nested_field.name): 1273 return nested_field.name 1274 else: 1275 return None 1276 1277 cls.WhichOneof = WhichOneof 1278 1279 1280def _Clear(self): 1281 # Clear fields. 
1282 self._fields = {} 1283 self._unknown_fields = () 1284 self._oneofs = {} 1285 self._Modified() 1286 1287 1288def _DiscardUnknownFields(self): 1289 self._unknown_fields = [] 1290 for field, value in self.ListFields(): 1291 if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1292 if field.label == _FieldDescriptor.LABEL_REPEATED: 1293 for sub_message in value: 1294 sub_message.DiscardUnknownFields() 1295 else: 1296 value.DiscardUnknownFields() 1297 1298 1299def _SetListener(self, listener): 1300 if listener is None: 1301 self._listener = message_listener_mod.NullMessageListener() 1302 else: 1303 self._listener = listener 1304 1305 1306def _AddMessageMethods(message_descriptor, cls): 1307 """Adds implementations of all Message methods to cls.""" 1308 _AddListFieldsMethod(message_descriptor, cls) 1309 _AddHasFieldMethod(message_descriptor, cls) 1310 _AddClearFieldMethod(message_descriptor, cls) 1311 if message_descriptor.is_extendable: 1312 _AddClearExtensionMethod(cls) 1313 _AddHasExtensionMethod(cls) 1314 _AddEqualsMethod(message_descriptor, cls) 1315 _AddStrMethod(message_descriptor, cls) 1316 _AddReprMethod(message_descriptor, cls) 1317 _AddUnicodeMethod(message_descriptor, cls) 1318 _AddByteSizeMethod(message_descriptor, cls) 1319 _AddSerializeToStringMethod(message_descriptor, cls) 1320 _AddSerializePartialToStringMethod(message_descriptor, cls) 1321 _AddMergeFromStringMethod(message_descriptor, cls) 1322 _AddIsInitializedMethod(message_descriptor, cls) 1323 _AddMergeFromMethod(cls) 1324 _AddWhichOneofMethod(message_descriptor, cls) 1325 # Adds methods which do not depend on cls. 
1326 cls.Clear = _Clear 1327 cls.DiscardUnknownFields = _DiscardUnknownFields 1328 cls._SetListener = _SetListener 1329 1330 1331def _AddPrivateHelperMethods(message_descriptor, cls): 1332 """Adds implementation of private helper methods to cls.""" 1333 1334 def Modified(self): 1335 """Sets the _cached_byte_size_dirty bit to true, 1336 and propagates this to our listener iff this was a state change. 1337 """ 1338 1339 # Note: Some callers check _cached_byte_size_dirty before calling 1340 # _Modified() as an extra optimization. So, if this method is ever 1341 # changed such that it does stuff even when _cached_byte_size_dirty is 1342 # already true, the callers need to be updated. 1343 if not self._cached_byte_size_dirty: 1344 self._cached_byte_size_dirty = True 1345 self._listener_for_children.dirty = True 1346 self._is_present_in_parent = True 1347 self._listener.Modified() 1348 1349 def _UpdateOneofState(self, field): 1350 """Sets field as the active field in its containing oneof. 1351 1352 Will also delete currently active field in the oneof, if it is different 1353 from the argument. Does not mark the message as modified. 1354 """ 1355 other_field = self._oneofs.setdefault(field.containing_oneof, field) 1356 if other_field is not field: 1357 del self._fields[other_field] 1358 self._oneofs[field.containing_oneof] = field 1359 1360 cls._Modified = Modified 1361 cls.SetInParent = Modified 1362 cls._UpdateOneofState = _UpdateOneofState 1363 1364 1365class _Listener(object): 1366 1367 """MessageListener implementation that a parent message registers with its 1368 child message. 1369 1370 In order to support semantics like: 1371 1372 foo.bar.baz.qux = 23 1373 assert foo.HasField('bar') 1374 1375 ...child objects must have back references to their parents. 1376 This helper class is at the heart of this support. 
1377 """ 1378 1379 def __init__(self, parent_message): 1380 """Args: 1381 parent_message: The message whose _Modified() method we should call when 1382 we receive Modified() messages. 1383 """ 1384 # This listener establishes a back reference from a child (contained) object 1385 # to its parent (containing) object. We make this a weak reference to avoid 1386 # creating cyclic garbage when the client finishes with the 'parent' object 1387 # in the tree. 1388 if isinstance(parent_message, weakref.ProxyType): 1389 self._parent_message_weakref = parent_message 1390 else: 1391 self._parent_message_weakref = weakref.proxy(parent_message) 1392 1393 # As an optimization, we also indicate directly on the listener whether 1394 # or not the parent message is dirty. This way we can avoid traversing 1395 # up the tree in the common case. 1396 self.dirty = False 1397 1398 def Modified(self): 1399 if self.dirty: 1400 return 1401 try: 1402 # Propagate the signal to our parents iff this is the first field set. 1403 self._parent_message_weakref._Modified() 1404 except ReferenceError: 1405 # We can get here if a client has kept a reference to a child object, 1406 # and is now setting a field on it, but the child's parent has been 1407 # garbage-collected. This is not an error. 1408 pass 1409 1410 1411class _OneofListener(_Listener): 1412 """Special listener implementation for setting composite oneof fields.""" 1413 1414 def __init__(self, parent_message, field): 1415 """Args: 1416 parent_message: The message whose _Modified() method we should call when 1417 we receive Modified() messages. 1418 field: The descriptor of the field being set in the parent message. 
1419 """ 1420 super(_OneofListener, self).__init__(parent_message) 1421 self._field = field 1422 1423 def Modified(self): 1424 """Also updates the state of the containing oneof in the parent message.""" 1425 try: 1426 self._parent_message_weakref._UpdateOneofState(self._field) 1427 super(_OneofListener, self).Modified() 1428 except ReferenceError: 1429 pass 1430 1431 1432# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... 1433# TODO(robinson): Unify error handling of "unknown extension" crap. 1434# TODO(robinson): Support iteritems()-style iteration over all 1435# extensions with the "has" bits turned on? 1436class _ExtensionDict(object): 1437 1438 """Dict-like container for supporting an indexable "Extensions" 1439 field on proto instances. 1440 1441 Note that in all cases we expect extension handles to be 1442 FieldDescriptors. 1443 """ 1444 1445 def __init__(self, extended_message): 1446 """extended_message: Message instance for which we are the Extensions dict. 1447 """ 1448 1449 self._extended_message = extended_message 1450 1451 def __getitem__(self, extension_handle): 1452 """Returns the current value of the given extension handle.""" 1453 1454 _VerifyExtensionHandle(self._extended_message, extension_handle) 1455 1456 result = self._extended_message._fields.get(extension_handle) 1457 if result is not None: 1458 return result 1459 1460 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: 1461 result = extension_handle._default_constructor(self._extended_message) 1462 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: 1463 result = extension_handle.message_type._concrete_class() 1464 try: 1465 result._SetListener(self._extended_message._listener_for_children) 1466 except ReferenceError: 1467 pass 1468 else: 1469 # Singular scalar -- just return the default without inserting into the 1470 # dict. 
1471 return extension_handle.default_value 1472 1473 # Atomically check if another thread has preempted us and, if not, swap 1474 # in the new object we just created. If someone has preempted us, we 1475 # take that object and discard ours. 1476 # WARNING: We are relying on setdefault() being atomic. This is true 1477 # in CPython but we haven't investigated others. This warning appears 1478 # in several other locations in this file. 1479 result = self._extended_message._fields.setdefault( 1480 extension_handle, result) 1481 1482 return result 1483 1484 def __eq__(self, other): 1485 if not isinstance(other, self.__class__): 1486 return False 1487 1488 my_fields = self._extended_message.ListFields() 1489 other_fields = other._extended_message.ListFields() 1490 1491 # Get rid of non-extension fields. 1492 my_fields = [ field for field in my_fields if field.is_extension ] 1493 other_fields = [ field for field in other_fields if field.is_extension ] 1494 1495 return my_fields == other_fields 1496 1497 def __ne__(self, other): 1498 return not self == other 1499 1500 def __hash__(self): 1501 raise TypeError('unhashable object') 1502 1503 # Note that this is only meaningful for non-repeated, scalar extension 1504 # fields. Note also that we may have to call _Modified() when we do 1505 # successfully set a field this way, to set any necssary "has" bits in the 1506 # ancestors of the extended message. 1507 def __setitem__(self, extension_handle, value): 1508 """If extension_handle specifies a non-repeated, scalar extension 1509 field, sets the value of that field. 1510 """ 1511 1512 _VerifyExtensionHandle(self._extended_message, extension_handle) 1513 1514 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or 1515 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): 1516 raise TypeError( 1517 'Cannot assign to extension "%s" because it is a repeated or ' 1518 'composite type.' 
% extension_handle.full_name) 1519 1520 # It's slightly wasteful to lookup the type checker each time, 1521 # but we expect this to be a vanishingly uncommon case anyway. 1522 type_checker = type_checkers.GetTypeChecker(extension_handle) 1523 # pylint: disable=protected-access 1524 self._extended_message._fields[extension_handle] = ( 1525 type_checker.CheckValue(value)) 1526 self._extended_message._Modified() 1527 1528 def _FindExtensionByName(self, name): 1529 """Tries to find a known extension with the specified name. 1530 1531 Args: 1532 name: Extension full name. 1533 1534 Returns: 1535 Extension field descriptor. 1536 """ 1537 return self._extended_message._extensions_by_name.get(name, None) 1538 1539 def _FindExtensionByNumber(self, number): 1540 """Tries to find a known extension with the field number. 1541 1542 Args: 1543 number: Extension field number. 1544 1545 Returns: 1546 Extension field descriptor. 1547 """ 1548 return self._extended_message._extensions_by_number.get(number, None) 1549