• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
36"""Contains a metaclass and helper functions used to create
37protocol message classes from Descriptor objects at runtime.
38
39Recall that a metaclass is the "type" of a class.
40(A class is to a metaclass what an instance is to a class.)
41
42In this case, we use the GeneratedProtocolMessageType metaclass
43to inject all the useful functionality into the classes
44output by the protocol compiler at compile-time.
45
46The upshot of all this is that the real implementation
47details for ALL pure-Python protocol buffers are *here in
48this file*.
49"""
50
51__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import struct
55import sys
56import weakref
57
58import six
59from six.moves import range
60
61# We use "as" to avoid name collisions with variables.
62from google.protobuf.internal import api_implementation
63from google.protobuf.internal import containers
64from google.protobuf.internal import decoder
65from google.protobuf.internal import encoder
66from google.protobuf.internal import enum_type_wrapper
67from google.protobuf.internal import extension_dict
68from google.protobuf.internal import message_listener as message_listener_mod
69from google.protobuf.internal import type_checkers
70from google.protobuf.internal import well_known_types
71from google.protobuf.internal import wire_format
72from google.protobuf import descriptor as descriptor_mod
73from google.protobuf import message as message_mod
74from google.protobuf import text_format
75
# Short module-level aliases.  _FieldDescriptor is the one the helpers in this
# file compare field labels, types and cpp_types against.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Full type name of the Any message (presumably the well-known type; its use
# is outside the part of the file reviewed here).
_AnyFullTypeName = 'google.protobuf.Any'
# Adaptor class returned by the generated `Extensions` property.
_ExtensionDict = extension_dict._ExtensionDict
79
class GeneratedProtocolMessageType(type):

  """Metaclass building protocol message classes from Descriptors at runtime.

  Every method described by the Message class gets an implementation, every
  field of the message gets a property for getting/setting it, and __slots__
  are installed so that users cannot accidentally "set" a nonexistent field
  (which would then never be serialized / deserialized properly).

  The protocol compiler currently relies on this metaclass to create protocol
  message classes at runtime.  Clients may also build their own classes at
  runtime by hand, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class, or returns the one the descriptor already has.

    __new__ is overridden because it appears to be the only place where
    __slots__ can meaningfully be set on the class being created.

    Args:
      name: Name of the class; required by the metaclass protocol and passed
        through to type.__new__.
      bases: Base classes of the class being constructed (should be
        message.Message).  When the descriptor's full name is listed in
        well_known_types.WKTBASES, the matching base is appended.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.

    Returns:
      The newly-allocated class, or the descriptor's existing _concrete_class.
    """
    msg_desc = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # Never build a second concrete class for a descriptor that already has
    # one: doing so would break messages already created from the existing
    # class.  (The C++ implementation appears to have its own internal
    # `PyMessageFactory` to achieve similar results.)  This most commonly
    # happens in `text_format.py` with descriptors from a custom pool, when a
    # prototype is requested for a descriptor that already has a class.
    concrete = getattr(msg_desc, '_concrete_class', None)
    if concrete:
      return concrete

    wkt_bases = well_known_types.WKTBASES
    if msg_desc.full_name in wkt_bases:
      bases += (wkt_bases[msg_desc.full_name],)
    _AddClassAttributesForNestedExtensions(msg_desc, dictionary)
    _AddSlots(msg_desc, dictionary)

    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Does the bulk of the work of turning cls into a message class.

    Adds enum getters, an __init__ method, implementations of all Message
    methods, and properties for all fields in the protocol type.

    Args:
      name: Name of the class; required by the metaclass protocol and passed
        through to type.__init__.
      bases: Base classes of the class being constructed (should be
        message.Message); passed through to type.__init__.
      dictionary: The class dictionary of the class being constructed.
        dictionary[_DESCRIPTOR_KEY] must contain a Descriptor object
        describing this protocol message type.
    """
    msg_desc = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # A descriptor that already carries a concrete class means __new__ handed
    # back that existing class; it is fully set up, so there is nothing to do.
    already_built = getattr(msg_desc, '_concrete_class', None)
    if already_built:
      assert already_built is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (msg_desc.full_name))
      return

    cls._decoders_by_tag = {}
    uses_message_set = (msg_desc.has_options and
                        msg_desc.GetOptions().message_set_wire_format)
    if uses_message_set:
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(msg_desc), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field_desc in msg_desc.fields:
      _AttachFieldHelpers(cls, field_desc)

    # From here on the descriptor points at its concrete class; the helpers
    # below then populate that class.
    msg_desc._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(msg_desc, cls)
    _AddInitMethod(msg_desc, cls)
    _AddPropertiesForFields(msg_desc, cls)
    _AddPropertiesForExtensions(msg_desc, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(msg_desc, cls)
    _AddPrivateHelperMethods(msg_desc, cls)

    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
202
203
204# Stateless helpers for GeneratedProtocolMessageType below.
205# Outside clients should not access these directly.
206#
207# I opted not to make any of these methods on the metaclass, to make it more
208# clear that I'm not really using any state there and to keep clients from
209# thinking that they have direct access to these construction helpers.
210
211
def _PropertyName(proto_field_name):
  """Maps a proto field name to the public attribute name on message classes.

  Clients read (and, for some kinds of field, write) the field's value through
  an attribute with the returned name.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name; currently always proto_field_name itself.
  """
  # This is deliberately the identity mapping.
  #
  # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
  # nnorwitz suggested using the stdlib keyword module, i.e. returning
  # proto_field_name + "_" whenever keyword.iskeyword(proto_field_name).
  #
  # Kenton says:  That is a BAD IDEA.  People rely on being able to use
  #   getattr() and setattr() to reflectively manipulate field values.  If the
  #   properties were renamed, every such user would also have to apply the
  #   same transformation.  A field named "yield" can still be accessed just
  #   fine using getattr/setattr -- it's not even that cumbersome to do so.
  # TODO(kenton):  Remove this method entirely if/when everyone agrees with
  #   that position.
  property_name = proto_field_name
  return property_name
239
240
def _AddSlots(message_descriptor, dictionary):
  """Stores the '__slots__' entry for a message class in its class dictionary.

  The entry is a fresh list naming the fixed set of instance attributes used
  by the message implementation.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
      (Not consulted; every message type gets the same slots.)
    dictionary: Class dictionary to which the '__slots__' entry is added.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
259
260
def _IsMessageSetExtension(field):
  """Tells whether field is an optional message extension of a MessageSet.

  That is: field is an extension, its containing type has the
  message_set_wire_format option set, and field is an optional field of
  message type.
  """
  if not field.is_extension:
    return False
  extended_type = field.containing_type
  if not (extended_type.has_options and
          extended_type.GetOptions().message_set_wire_format):
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
267
268
def _IsMapField(field):
  """Tells whether field is a map field.

  A map field is a message-typed field whose message type has the map_entry
  option set.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
273
274
def _IsMessageMapField(field):
  """Tells whether the values of the map field `field` are messages.

  Looks at the 'value' field of the field's message type, so `field` must be
  message-typed (callers in this file check _IsMapField first).
  """
  value_field = field.message_type.fields_by_name['value']
  values_are_messages = (
      value_field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE)
  return values_are_messages
278
279
def _IsStrictUtf8Check(field):
  """Returns True exactly when field's containing type uses proto3 syntax.

  The result is what _AttachFieldHelpers hands to the string decoder as its
  strict-UTF-8 flag.
  """
  return field.containing_type.syntax == 'proto3'
285
286
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches the encoder, sizer, default constructor and decoders of a field.

  The encoder, sizer and default-value constructor are stored on the
  FieldDescriptor itself (as _encoder, _sizer and _default_constructor).  The
  decoders are registered in cls._decoders_by_tag, keyed by the encoded tag
  bytes, as (decoder, oneof marker) pairs.

  Args:
    cls: The message class being set up; it must already have a
      _decoders_by_tag dict.
    field_descriptor: FieldDescriptor of the field to set up.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: follow the field's `packed` option (unset options mean unpacked).
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Any other syntax: packed unless `packed` was explicitly set to false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, is_packed):
    """Registers a decoder for this field under the given wire type."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    # Enum fields that support open enums are decoded as plain int32 values.
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # For a field inside a oneof, the FieldDescriptor itself is what gets
    # recorded next to the decoder; other fields record None.
    oneof_descriptor = None
    if field_descriptor.containing_oneof is not None:
      oneof_descriptor = field_descriptor

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor,
          is_strict_utf8_check)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, is_packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
358
359
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Copies the message's nested extensions into its class dictionary.

  Each entry of descriptor.extensions_by_name is stored in dictionary under the
  extension's name; a name already present in dictionary trips an assertion.

  Args:
    descriptor: Descriptor object for this message type.
    dictionary: Class dictionary of the class being constructed.
  """
  for name, extension in descriptor.extensions_by_name.items():
    assert name not in dictionary
    dictionary[name] = extension
365
366
def _AddEnumValues(descriptor, cls):
  """Exposes the message's nested enums as class-level attributes.

  For every enum type, cls gets an attribute named after the enum holding an
  EnumTypeWrapper for it, plus one attribute per enum value (named after the
  value and holding its number).

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_desc in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_desc)
    setattr(cls, enum_desc.name, wrapper)
    for value_desc in enum_desc.values:
      setattr(cls, value_desc.name, value_desc.number)
380
381
def _GetInitializeDefaultForMap(field):
  """Returns a function that builds the map container for a map field.

  Args:
    field: FieldDescriptor of a map field; its message type must have 'key'
      and 'value' fields.

  Returns:
    A one-argument function taking the owning message and returning a
    containers.MessageMap (message values) or containers.ScalarMap (all other
    values) wired to that message's _listener_for_children.

  Raises:
    ValueError: field is not a repeated field.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)
    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
403
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
    value may refer back to |message| via a weak reference.

  Raises:
    ValueError: field is repeated and declares a default value other than the
      empty list (or is a map field that is not repeated).
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # Only the message type's descriptor is handed to the container.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is only looked up when
    # the default is actually requested.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Members of a oneof get a dedicated _OneofListener; every other
      # submessage shares the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
460
461
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, naming the field when possible.

  Must be called while an exception is being handled.  A plain TypeError that
  carries exactly one argument is replaced by a new TypeError whose message
  additionally names the field; any other exception is re-raised unchanged.
  The original traceback is used in both cases.

  Args:
    message_name: Name of the message type owning the field.
    field_name: Name of the field that was being set.
  """
  _, exc, tb = sys.exc_info()
  # Exact type match: subclasses of TypeError are re-raised untouched.
  if len(exc.args) == 1 and type(exc) is TypeError:
    amended_text = '%s for field %s.%s' % (str(exc), message_name, field_name)
    exc = TypeError(amended_text)

  six.reraise(type(exc), exc, tb)
471
472
def _AddInitMethod(message_descriptor, cls):
  """Installs the generated __init__(self, **kwargs) on cls.

  The installed initializer resets the instance's internal bookkeeping and then
  populates fields from keyword arguments named after them.  A keyword whose
  value is None is treated as if it had not been passed.

  Args:
    message_descriptor: Descriptor of the message type cls represents.
    cls: The message class being constructed.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Maps an enum label to its number; other values pass through untouched.

    Non-string values are returned as-is, without conversion or
    bounds-checking.

    Raises:
      ValueError: value is a string naming no value of enum_type.
    """
    if not isinstance(value, six.string_types):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def _FillRepeated(container, field, values):
    """Copies the constructor argument `values` into a fresh container."""
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      # Repeated scalars; enum labels are first turned into numbers.
      if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
        values = [_GetIntegerEnumValue(field.enum_type, item)
                  for item in values]
      container.extend(values)
    elif not _IsMapField(field):
      # Repeated messages; each item is either a dict of kwargs or a message.
      for item in values:
        if isinstance(item, dict):
          container.add(**item)
        else:
          container.add().MergeFrom(item)
    elif _IsMessageMapField(field):
      for key in values:
        container[key].MergeFrom(values[key])
    else:
      container.update(values)

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = bool(kwargs)
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # _GetFieldByName as written in this file raises ValueError for unknown
      # names rather than returning None, so this check is purely defensive.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        _FillRepeated(container, field, field_value)
        self._fields[field] = container
        continue

      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        submessage = field._default_constructor(self)
        source = field_value
        if isinstance(field_value, dict):
          # A dict is taken as keyword arguments for the submessage's class.
          source = field.message_type._concrete_class(**field_value)
        try:
          submessage.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = submessage
        continue

      # Singular scalar: go through the generated property setter.
      if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
        field_value = _GetIntegerEnumValue(field.enum_type, field_value)
      try:
        setattr(self, field_name, field_value)
      except TypeError:
        _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
558
559
def _GetFieldByName(message_descriptor, field_name):
  """Looks up a field descriptor by field name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.
  Returns:
    The field descriptor associated with the field name.
  Raises:
    ValueError: the message has no field called field_name.
  """
  try:
    found = message_descriptor.fields_by_name[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
574
575
def _AddPropertiesForFields(descriptor, cls):
  """Adds a property per field, plus `Extensions` for extendable messages.

  Args:
    descriptor: Descriptor object for this message type.
    cls: The class we're constructing.
  """
  for field_desc in descriptor.fields:
    _AddPropertiesForField(field_desc, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is just an adaptor with no state so we allocate a new one
  # every time it is accessed.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
585
586
def _AddPropertiesForField(field, cls):
  """Adds the public property and the field-number constant for one field.

  Clients use the property to get and (in the case of non-repeated scalar
  fields) directly set the value of the protocol message field.  The constant
  is named <FIELD NAME IN UPPER CASE>_FIELD_NUMBER and holds the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + '_FIELD_NUMBER', field.number)

  # Pick the property builder matching the kind of field, then apply it.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
610
611
class _FieldProperty(property):
  # A property that additionally carries, as DESCRIPTOR, the descriptor it was
  # created for.  (Documented with comments rather than a class docstring so
  # that the class's __doc__ attribute stays exactly as it was.)
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    # getter/setter/doc are forwarded to property; no deleter is installed.
    super(_FieldProperty, self).__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
618
619
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the field's container, which is created by the
  field's _default_constructor on first access and cached in the message's
  _fields dict.  Assigning to the property is always rejected.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field, then atomically check
    # if another thread has preempted us: setdefault() swaps in our object
    # only if none is stored yet, otherwise we take the stored one and discard
    # ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    created = field._default_constructor(self)
    return self._fields.setdefault(field, created)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  setattr(
      cls,
      _PropertyName(proto_field_name),
      _FieldProperty(
          field, getter, setter,
          doc='Magic attribute generated for "%s" proto field.' %
          proto_field_name))
662
663
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  For a proto3 field outside any oneof, setting a falsy value removes the field
  from the message's _fields dict instead of storing the value.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    # Oneof members additionally record themselves as the member now set.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
722
723
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Reading the property yields the field's value, which is created by the
  field's _default_constructor on first access and cached in the message's
  _fields dict.  Assigning to the property is always rejected.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name

  def getter(self):
    existing = self._fields.get(field)
    if existing is not None:
      return existing

    # Construct a new object to represent this field, then atomically check
    # if another thread has preempted us: setdefault() swaps in our object
    # only if none is stored yet, otherwise we take the stored one and discard
    # ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    created = field._default_constructor(self)
    return self._fields.setdefault(field, created)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  setattr(
      cls,
      _PropertyName(proto_field_name),
      _FieldProperty(
          field, getter, setter,
          doc='Magic attribute generated for "%s" proto field.' %
          proto_field_name))
766
767
def _AddPropertiesForExtensions(descriptor, cls):
  """Exposes extension field numbers and extension lookup tables on cls.

  For every entry of descriptor.extensions_by_name, a class attribute named
  after the upper-cased extension name plus '_FIELD_NUMBER' is set to the
  extension's number.  When the descriptor has a file, cls additionally gets
  _extensions_by_number and _extensions_by_name, looked up for this
  descriptor in that file's pool.

  Args:
    descriptor: Descriptor of the message type being built.
    cls: The class we're constructing.
  """
  for name, extension in descriptor.extensions_by_name.items():
    setattr(cls, name.upper() + '_FIELD_NUMBER', extension.number)

  if descriptor.file is None:
    return

  # TODO(amauryfa): Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
  pool = descriptor.file.pool
  cls._extensions_by_number = pool._extensions_by_number[descriptor]
  cls._extensions_by_name = pool._extensions_by_name[descriptor]
782
def _AddStaticMethods(cls):
  """Attaches the RegisterExtension and FromString static methods to cls."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    # Tie the extension to this message type, add it to the pool of the file
    # declaring the message, then run _AttachFieldHelpers() for it.
    message_descriptor = cls.DESCRIPTOR
    extension_handle.containing_type = message_descriptor
    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
    message_descriptor.file.pool.AddExtensionDescriptor(extension_handle)
    _AttachFieldHelpers(cls, extension_handle)

  def FromString(s):
    # Build a new instance and merge the serialized data into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
797
798
def _IsPresent(item):
  """Decides whether a _fields entry belongs in the ListFields() result.

  Args:
    item: A (FieldDescriptor, value) tuple taken from _fields.

  Returns:
    For a repeated field, whether the value is non-empty; for a message
    field, the value's _is_present_in_parent flag; True for anything else.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
809
810
def _AddListFieldsMethod(message_descriptor, cls):
  """Installs ListFields() on cls.  Helper for _AddMessageMethods()."""

  def ListFields(self):
    # Keep only the entries _IsPresent() accepts, ordered by field number.
    present = filter(_IsPresent, self._fields.items())
    return sorted(present, key=lambda pair: pair[0].number)

  cls.ListFields = ListFields
820
# Templates for the ValueError raised by HasField(); the two placeholders are
# filled with the message's full name and the requested field name.
_PROTO3_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated submessage field "%s"')
_PROTO2_ERROR_TEMPLATE = (
    'Protocol message %s has no non-repeated field "%s"')
824
def _AddHasFieldMethod(message_descriptor, cls):
  """Installs HasField() on cls.  Helper for _AddMessageMethods().

  The set of names HasField() accepts is computed once here: every
  non-repeated field that tracks presence and, for non-proto3 messages, every
  oneof.
  """

  is_proto3 = (message_descriptor.syntax == "proto3")
  if is_proto3:
    error_msg = _PROTO3_ERROR_TEMPLATE
  else:
    error_msg = _PROTO2_ERROR_TEMPLATE

  hassable_fields = {}
  for field in message_descriptor.fields:
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      continue
    # For proto3, only submessages and fields inside a oneof have presence.
    tracks_presence = (
        not is_proto3 or
        field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
        field.containing_oneof)
    if tracks_presence:
      hassable_fields[field.name] = field

  if not is_proto3:
    # Fields inside oneofs are never repeated (enforced by the compiler).
    for oneof in message_descriptor.oneofs:
      hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      target = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(target, descriptor_mod.OneofDescriptor):
      # Asking about a oneof: answer for whichever member is currently
      # recorded in _oneofs; no recorded member means "not set".
      try:
        return HasField(self, self._oneofs[target].name)
      except KeyError:
        return False

    if target.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return target in self._fields

    # A stored submessage only counts once it is flagged present.
    submessage = self._fields.get(target)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
865
866
def _AddClearFieldMethod(message_descriptor, cls):
  """Installs ClearField() on cls.  Helper for _AddMessageMethods()."""

  def ClearField(self, field_name):
    # Resolve field_name: a regular field first, otherwise a oneof, in which
    # case the member currently recorded for that oneof is the one to clear.
    try:
      target = message_descriptor.fields_by_name[field_name]
    except KeyError:
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
        if oneof not in self._oneofs:
          # Nothing is set in this oneof, so there is nothing to clear.
          return
        target = self._oneofs[oneof]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))

    stored = self._fields
    if target in stored:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      current = stored[target]
      if hasattr(current, 'InvalidateIterators'):
        current.InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del stored[target]

      # Forget the oneof selection if it pointed at the field just removed.
      containing = target.containing_oneof
      if self._oneofs.get(containing, None) is target:
        del self._oneofs[containing]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
903
904
def _AddClearExtensionMethod(cls):
  """Installs ClearExtension() on cls.  Helper for _AddMessageMethods()."""

  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)

    # Similar to ClearField(): drop the stored value if there is one, and
    # mark the message as modified either way.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
915
916
def _AddHasExtensionMethod(cls):
  """Installs HasExtension() on cls.  Helper for _AddMessageMethods()."""

  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields

    # A stored submessage only counts once it is flagged present.
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
930
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when msg has an empty type_url or when the
    type named by the type_url cannot be found in the default symbol
    database's descriptor pool.
  """
  # TODO(amauryfa): Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  type_url = msg.type_url

  if not type_url:
    return None

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The descriptor pool signals an unknown type by raising KeyError rather
    # than returning None.  Report "cannot unpack" to the caller instead of
    # letting the lookup failure escape (e.g. out of message equality).
    return None

  if descriptor is None:
    return None

  message_class = factory.GetPrototype(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
969
970
def _AddEqualsMethod(message_descriptor, cls):
  """Installs __eq__ on cls.  Helper for _AddMessageMethods()."""

  def __eq__(self, other):
    # Only messages with the same descriptor can be equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    # Any messages are compared by their unpacked payloads when both sides
    # can be unpacked; otherwise they fall through to the generic comparison.
    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared without regard to their order.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
999
1000
def _AddStrMethod(message_descriptor, cls):
  """Installs __str__ on cls.  Helper for _AddMessageMethods()."""

  def __str__(self):
    # str(message) is whatever text_format.MessageToString() produces.
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1006
1007
def _AddReprMethod(message_descriptor, cls):
  """Installs __repr__ on cls.  Helper for _AddMessageMethods()."""

  def __repr__(self):
    # repr(message) is whatever text_format.MessageToString() produces.
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1013
1014
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Installs __unicode__ on cls.  Helper for _AddMessageMethods()."""

  def __unicode__(self):
    # Render with as_utf8=True, then decode the result as UTF-8.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')

  cls.__unicode__ = __unicode__
1021
1022
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    message_mod.EncodeError: If field_type has no entry in
      type_checkers.TYPE_TO_BYTE_SIZE_FN.
  """
  # Only the table lookup is guarded: a KeyError raised inside the sizing
  # function itself must not be misreported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
1041
1042
def _AddByteSizeMethod(message_descriptor, cls):
  """Installs ByteSize() on cls.  Helper for _AddMessageMethods()."""

  def ByteSize(self):
    # Serve the cached size for as long as it has not been marked dirty.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      total = by_name['key']._sizer(self.key)
      total += by_name['value']._sizer(self.value)
    else:
      # Listed fields are sized by their descriptors; each unknown field
      # contributes its raw tag and value bytes.
      total = sum(fd._sizer(val) for fd, val in self.ListFields())
      total += sum(len(tag_bytes) + len(value_bytes)
                   for tag_bytes, value_bytes in self._unknown_fields)

    # Remember the result and clear the dirty flags.
    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1068
1069
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Installs SerializeToString() on cls.  Helper for _AddMessageMethods()."""

  def SerializeToString(self, **kwargs):
    # Refuse to serialize while required fields are missing; otherwise defer
    # to SerializePartialToString() with the same keyword arguments.
    if not self.IsInitialized():
      raise message_mod.EncodeError(
          'Message %s is missing required fields: %s' % (
              self.DESCRIPTOR.full_name,
              ','.join(self.FindInitializationErrors())))
    return self.SerializePartialToString(**kwargs)

  cls.SerializeToString = SerializeToString
1082
1083
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Installs SerializePartialToString() and _InternalSerialize() on cls.

  Helper for _AddMessageMethods().
  """

  def SerializePartialToString(self, **kwargs):
    # Gather everything _InternalSerialize() writes into one bytes object.
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  def InternalSerialize(self, write_bytes, deterministic=None):
    # An explicit setting is coerced to bool; None defers to
    # api_implementation for the default.
    if deterministic is not None:
      deterministic = bool(deterministic)
    else:
      deterministic = (
          api_implementation.IsPythonDefaultSerializationDeterministic())

    descriptor = self.DESCRIPTOR
    if descriptor.GetOptions().map_entry:
      # Fields of map entry should always be serialized.
      by_name = descriptor.fields_by_name
      by_name['key']._encoder(write_bytes, self.key, deterministic)
      by_name['value']._encoder(write_bytes, self.value, deterministic)
      return

    # Listed fields first, then the raw tag/value bytes of unknown fields.
    for field_descriptor, field_value in self.ListFields():
      field_descriptor._encoder(write_bytes, field_value, deterministic)
    for tag_bytes, value_bytes in self._unknown_fields:
      write_bytes(tag_bytes)
      write_bytes(value_bytes)

  cls.SerializePartialToString = SerializePartialToString
  cls._InternalSerialize = InternalSerialize
1114
1115
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and _InternalParse() on cls.
  """
  def MergeFromString(self, serialized):
    """Parses serialized data and merges it into this message.

    Args:
      serialized: The serialized data; it is wrapped in a memoryview before
        parsing.  Passing a memoryview directly is rejected on Python 2.

    Returns:
      The length of the serialized data.

    Raises:
      TypeError: If serialized is a memoryview and six.PY2 is true.
      message_mod.DecodeError: If _InternalParse() stops before the end of
        the data, or if parsing raises IndexError, TypeError or struct.error.
    """
    if isinstance(serialized, memoryview) and six.PY2:
      raise TypeError(
          'memoryview not supported in Python 2 with the pure Python proto '
          'implementation: this is to maintain compatibility with the C++ '
          'implementation')

    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bound once here so the parse loop below uses plain local/closure lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Merges serialized bytes from buffer[pos:end] into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This is the position
      reached when the loop ends (pos == end), or the position of an unknown
      field's tag when decoder._DecodeUnknownField() or decoder.SkipField()
      returned -1 for that field.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # No decoder is registered for this tag: keep the field as an unknown
        # field, both in the UnknownFieldSet and in the _unknown_fields list
        # of (tag_bytes, raw_bytes) tuples.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        # TODO(jieluo): remove old_pos.
        # old_pos marks where the field's payload starts; the payload is
        # walked twice below, once to decode it and once to skip over it.
        old_pos = new_pos
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          # Stop here and report the position of this tag to the caller.
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(jieluo): remove _unknown_fields.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        # Known field: the decoder stores the value in field_dict and returns
        # the position to continue from.
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1198
1199
def _AddIsInitializedMethod(message_descriptor, cls):
  """Installs IsInitialized() and FindInitializationErrors() on cls."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.label == _FieldDescriptor.LABEL_REQUIRED]

  def _ReportUninitialized(message, errors):
    # Shared failure exit of IsInitialized(): record the missing-field paths
    # when the caller asked for them, then answer False.
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().
    stored = self._fields

    # Pass 1: this message's own required fields.
    for field in required_fields:
      if field not in stored:
        return _ReportUninitialized(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not stored[field]._is_present_in_parent):
        return _ReportUninitialized(self, errors)

    # Pass 2: submessages stored in this message.
    for field, value in list(stored.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      if field.label != _FieldDescriptor.LABEL_REPEATED:
        if value._is_present_in_parent and not value.IsInitialized():
          return _ReportUninitialized(self, errors)
        continue

      # Repeated fields whose message type is a map entry are skipped.
      if (field.message_type.has_options and
          field.message_type.GetOptions().map_entry):
        continue
      for element in value:
        if not element.IsInitialized():
          return _ReportUninitialized(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      # Extensions are reported by their parenthesized full name.
      name = ('(%s)' % field.full_name) if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            element = value[key]
            prefix = '%s[%s].' % (name, key)
            errors.extend(
                prefix + error for error in element.FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = '%s[%d].' % (name, index)
          errors.extend(
              prefix + error for error in element.FindInitializationErrors())
      else:
        prefix = name + '.'
        errors.extend(
            prefix + error for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1293
1294
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      is_repeated = field.label == LABEL_REPEATED

      if not is_repeated and field.cpp_type != CPPTYPE_MESSAGE:
        # Singular scalar: the source value simply overwrites ours.
        fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)
        continue

      # A singular submessage is merged only when the source marks it present.
      if not is_repeated and not value._is_present_in_parent:
        continue

      # Repeated fields and submessages merge into our own container,
      # which is constructed and stored first if we do not have one yet.
      target = fields.get(field)
      if target is None:
        target = field._default_constructor(self)
        fields[field] = target
      target.MergeFrom(value)

    incoming_unknown = msg._unknown_fields
    if not incoming_unknown:
      return

    # Carry the source's unknown fields over into both of our containers.
    if not self._unknown_fields:
      self._unknown_fields = []
    self._unknown_fields.extend(incoming_unknown)
    # pylint: disable=protected-access
    if self._unknown_field_set is None:
      self._unknown_field_set = containers.UnknownFieldSet()
    self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
1341
1342
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The recorded member only counts if HasField() confirms it is set.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1359
1360
def _AddReduceMethod(cls):
  """Installs __reduce__ on cls."""

  def __reduce__(self):  # pylint: disable=invalid-name
    # Rebuild as: call the concrete type with no arguments, then restore the
    # state returned by __getstate__().
    factory = type(self)
    state = self.__getstate__()
    return (factory, (), state)

  cls.__reduce__ = __reduce__
1365
1366
def _Clear(self):
  """Empties the message: fields, unknown fields and oneof state are dropped.

  Finishes by calling _Modified().
  """
  self._fields = {}
  self._unknown_fields = ()

  # pylint: disable=protected-access
  unknown_set = self._unknown_field_set
  if unknown_set is not None:
    # Empty the existing set before letting go of it.
    unknown_set._clear()
    self._unknown_field_set = None

  self._oneofs = {}
  self._Modified()
1378
1379
def _UnknownFields(self):
  """Returns the message's unknown field set, creating it on first use."""
  # pylint: disable=protected-access
  if self._unknown_field_set is not None:
    return self._unknown_field_set
  self._unknown_field_set = containers.UnknownFieldSet()
  return self._unknown_field_set
1385
1386
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and from its submessages.

  Submessages are reached through ListFields(): singular message fields,
  elements of repeated message fields and values of message-valued maps.
  """
  self._unknown_fields = []
  self._unknown_field_set = None      # pylint: disable=protected-access

  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue

    # Work out which child messages this field holds.
    if _IsMapField(field):
      if not _IsMessageMapField(field):
        continue
      children = (value[key] for key in value)
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      children = value
    else:
      children = (value,)

    for child in children:
      child.DiscardUnknownFields()
1401
1402
def _SetListener(self, listener):
  """Stores listener on the message; None is replaced by a null listener."""
  self._listener = (
      message_listener_mod.NullMessageListener() if listener is None
      else listener)
1408
1409
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for install in (_AddListFieldsMethod,
                  _AddHasFieldMethod,
                  _AddClearFieldMethod):
    install(message_descriptor, cls)

  # ClearExtension()/HasExtension() are only installed on extendable messages.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for install in (_AddEqualsMethod,
                  _AddStrMethod,
                  _AddReprMethod,
                  _AddUnicodeMethod,
                  _AddByteSizeMethod,
                  _AddSerializeToStringMethod,
                  _AddSerializePartialToStringMethod,
                  _AddMergeFromStringMethod,
                  _AddIsInitializedMethod):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
  _AddReduceMethod(cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1435
1436
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Marks the cached byte size dirty and notifies our listener.

    Does nothing when the message is already dirty, so the listener only
    hears about the change from clean to dirty.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Records field as the active member of its containing oneof.

    If a different member was active, its value is removed from _fields.
    Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1469
1470
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  It lets a change deep inside a message tree become visible further up:

    foo.bar.baz.qux = 23
    assert foo.HasField('bar')

  For this to work a child needs a back reference to its parent, and this
  class holds that reference.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from child to parent is kept weak so that no cyclic
    # garbage is created when the client finishes with the 'parent' object
    # in the tree.  A parent that is already a weak proxy is stored as is.
    if not isinstance(parent_message, weakref.ProxyType):
      parent_message = weakref.proxy(parent_message)
    self._parent_message_weakref = parent_message

    # As an optimization, the listener itself also records whether the parent
    # message is dirty, so the common case needs no walk up the tree.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
1515
1516
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      # Make our field the parent's active oneof member first, then pass the
      # notification on as a plain listener would.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent message no longer exists; there is nothing to update.
      pass
1536