• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
# This code is meant to work on Python 2.6+ (it relies on `except ... as`
# syntax and the io module) and, via six, on Python 3.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
"""Runtime construction of pure-Python protocol message classes.

This module defines GeneratedProtocolMessageType, a metaclass, together with
the helper functions it uses to turn a Descriptor into a working message
class at runtime.

A metaclass is the "type" of a class: a class relates to its metaclass the
way an instance relates to its class.  Classes emitted by the protocol
compiler use GeneratedProtocolMessageType as their metaclass, and it injects
all of the message behavior into them when those classes are created.

Consequently, the real implementation details for ALL pure-Python protocol
buffers live *in this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import struct
55import sys
56import weakref
57
58import six
59from six.moves import range
60
61# We use "as" to avoid name collisions with variables.
62from google.protobuf.internal import api_implementation
63from google.protobuf.internal import containers
64from google.protobuf.internal import decoder
65from google.protobuf.internal import encoder
66from google.protobuf.internal import enum_type_wrapper
67from google.protobuf.internal import extension_dict
68from google.protobuf.internal import message_listener as message_listener_mod
69from google.protobuf.internal import type_checkers
70from google.protobuf.internal import well_known_types
71from google.protobuf.internal import wire_format
72from google.protobuf import descriptor as descriptor_mod
73from google.protobuf import message as message_mod
74from google.protobuf import text_format
75
# Shorthand aliases for names used throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
_ExtensionDict = extension_dict._ExtensionDict

# Fully-qualified proto name of the well-known Any message type.
_AnyFullTypeName = 'google.protobuf.Any'
79
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class.  We
  also create properties to allow getting/setting all fields in the protocol
  message.  Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime.  Clients can also manually create their own
  classes at runtime, as in this example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = factory.GetPrototype(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    We override __new__ because this is apparently the only place
    where we can meaningfully set __slots__ on the class we're creating(?).
    (The interplay between metaclasses and slots is not very well-documented).

    Args:
      name: Name of the class; passed through unchanged to type.__new__.
      bases: Base classes of the class we're constructing (normally
        message.Message).  If the descriptor's full name is a key of
        well_known_types.WKTBASES, the matching class is appended.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.  Nested-extension attributes and a '__slots__' entry are
        added to it before the class is created.

    Returns:
      The descriptor's existing _concrete_class if one is already registered,
      otherwise a newly-allocated class.

    Raises:
      RuntimeError: if DESCRIPTOR is a plain string, i.e. the generated code
        only works with the python cpp extension but the pure-Python runtime
        is in use.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # If a concrete class already exists for this descriptor, don't try to
    # create another.  Doing so will break any messages that already exist with
    # the existing class.
    #
    # The C++ implementation appears to have its own internal `PyMessageFactory`
    # to achieve similar results.
    #
    # This most commonly happens in `text_format.py` when using descriptors from
    # a custom pool; it calls symbol_database.Global().getPrototype() on a
    # descriptor which already has an existing concrete class.
    new_class = getattr(descriptor, '_concrete_class', None)
    if new_class:
      return new_class

    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type.  If the descriptor already has a _concrete_class
    (the case where __new__ returned an existing class), nothing is redone.

    Args:
      name: Name of the class; forwarded to type.__init__.
      bases: Base classes of the class we're constructing; forwarded to
        type.__init__.
      dictionary: The class dictionary of the class we're
        constructing.  dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # If this is an _existing_ class looked up via `_concrete_class` in the
    # __new__ method above, then we don't need to re-initialize anything.
    existing_class = getattr(descriptor, '_concrete_class', None)
    if existing_class:
      assert existing_class is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    # Tag bytes -> (decoder, oneof field descriptor or None); filled in below
    # and by _AttachFieldHelpers.
    cls._decoders_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(descriptor), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
209
210
211# Stateless helpers for GeneratedProtocolMessageType below.
212# Outside clients should not access these directly.
213#
214# I opted not to make any of these methods on the metaclass, to make it more
215# clear that I'm not really using any state there and to keep clients from
216# thinking that they have direct access to these construction helpers.
217
218
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the Python attribute name used for it.

  The mapping is the identity.  Python keywords (e.g. "yield") are
  deliberately *not* escaped: callers rely on getattr()/setattr() with the raw
  field name to manipulate field values reflectively, and such fields remain
  reachable that way.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The name of the public property attribute through which clients get
    (and, for some fields, set) the field's value.
  """
  # TODO(kenton): Remove this function entirely once everyone agrees that
  #   field names should never be transformed.
  property_name = proto_field_name
  return property_name
246
247
def _AddSlots(message_descriptor, dictionary):
  """Installs the fixed '__slots__' list used by every generated message class.

  The slot names do not depend on the message type; message_descriptor is not
  consulted.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_unknown_field_set',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
266
267
def _IsMessageSetExtension(field):
  """Tells whether field is an optional message extension of a MessageSet.

  A MessageSet here is a message whose options set message_set_wire_format.
  Such extensions get the MessageSet item encoder/sizer instead of the
  ordinary per-type ones.
  """
  is_extension = field.is_extension
  if not is_extension:
    return is_extension
  extended_type = field.containing_type
  return (extended_type.has_options and
          extended_type.GetOptions().message_set_wire_format and
          field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
274
275
def _IsMapField(field):
  """Tells whether field is a map field.

  A map field is a message-typed field whose message type carries the
  map_entry option.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
280
281
def _IsMessageMapField(field):
  """Tells whether a map field's values are messages rather than scalars.

  Args:
    field: A map FieldDescriptor whose message_type has a 'value' field.
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
285
286
def _IsStrictUtf8Check(field):
  """Tells whether string values of field get strict UTF-8 validation.

  Strict checking applies exactly to fields whose containing message uses
  proto3 syntax.
  """
  return field.containing_type.syntax == 'proto3'
292
293
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding/decoding helpers for one field.

  Stores an encoder, a sizer and a default-value constructor on
  field_descriptor (as _encoder, _sizer and _default_constructor), and
  registers the field's decoder(s) in cls._decoders_by_tag, keyed by tag bytes.

  Args:
    cls: The message class under construction; its _decoders_by_tag dict must
      already exist.
    field_descriptor: FieldDescriptor of one of the message's fields.
  """
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))
  is_proto3 = field_descriptor.containing_type.syntax == 'proto3'
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == 'proto2':
    # proto2: packed encoding only when the 'packed' option is set.
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    # Any other syntax: packed unless the 'packed' option is explicitly false.
    has_packed_false = (field_descriptor.has_options and
                        field_descriptor.GetOptions().HasField('packed') and
                        not field_descriptor.GetOptions().packed)
    is_packed = not has_packed_false
  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  # Registers a decoder for this field under the tag built from wiretype.
  # decode_packed selects the packed decoder variant; it is independent of the
  # field's own is_packed setting used for encoding above.
  def AddDecoder(wiretype, decode_packed):
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      # Enums for which SupportsOpenEnums() holds are decoded as plain int32.
      decode_type = _FieldDescriptor.TYPE_INT32

    oneof_descriptor = None
    clear_if_default = False
    if field_descriptor.containing_oneof is not None:
      # NOTE: the field's own descriptor (not its OneofDescriptor) is stored.
      oneof_descriptor = field_descriptor
    elif (is_proto3 and not is_repeated and
          field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE):
      clear_if_default = True

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)

      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    elif decode_type == _FieldDescriptor.TYPE_STRING:
      is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
      field_decoder = decoder.StringDecoder(
          field_descriptor.number, is_repeated, decode_packed,
          field_descriptor, field_descriptor._default_constructor,
          is_strict_utf8_check, clear_if_default)
    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, decode_packed,
          field_descriptor, field_descriptor._default_constructor)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, decode_packed,
          # pylint: disable=protected-access
          field_descriptor, field_descriptor._default_constructor,
          clear_if_default)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
376
377
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes each extension declared inside this message as a class attribute.

  Args:
    descriptor: Descriptor of the message being built.
    dictionary: The class dictionary.  It receives one entry per nested
      extension, mapping the extension's name to its extension field.
  """
  for name, ext_field in descriptor.extensions_by_name.items():
    # A nested extension must not clash with anything already in the class.
    assert name not in dictionary
    dictionary[name] = ext_field
383
384
def _AddEnumValues(descriptor, cls):
  """Publishes the message's nested enums on the class.

  For every enum type nested in the message, cls gets an attribute named
  after the enum that holds an EnumTypeWrapper.  It also gets one integer
  attribute per enum value, named after that value.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for nested_enum in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(nested_enum)
    setattr(cls, nested_enum.name, wrapper)
    for value_descriptor in nested_enum.values:
      setattr(cls, value_descriptor.name, value_descriptor.number)
398
399
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for a map field.

  Args:
    field: FieldDescriptor of a map field.  It must be repeated, and its
      message_type must have 'key' and 'value' fields.

  Returns:
    A function taking the owning message.  It returns a containers.MessageMap
    when the map's values are messages, otherwise a containers.ScalarMap.
    Either container is attached to message._listener_for_children.

  Raises:
    ValueError: if field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
421
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field.  The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: if field is repeated but declares a non-empty default value,
      or (from _GetInitializeDefaultForMap) if a map field is not repeated.
  """
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # So only the message descriptor is handed to the container.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is looked up only
    # when the returned constructor is called.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      assert getattr(message_type, '_concrete_class', None), (
          'Uninitialized concrete class found for field %r (message type %r)'
          % (field.full_name, message_type.full_name))
      result = message_type._concrete_class()
      # Oneof members get a _OneofListener; other fields use the parent's
      # child listener.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
478
479
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Must be called from inside an ``except`` block.  A TypeError of exactly that
  type with a single argument is replaced by a new TypeError whose message
  ends with " for field <message_name>.<field_name>".  Any other exception is
  re-raised as-is.  The original traceback is kept either way.

  Args:
    message_name: Name of the message type, used in the amended message.
    field_name: Name of the offending field, used in the amended message.
  """
  handled = sys.exc_info()[1]
  original_tb = sys.exc_info()[2]
  if len(handled.args) == 1 and type(handled) is TypeError:
    # Plain TypeError: rebuild it with the field's name appended.
    handled = TypeError('%s for field %s.%s' % (
        str(handled), message_name, field_name))
  six.reraise(type(handled), handled, original_tb)
489
490
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets every per-instance bookkeeping slot, then
  initializes fields from keyword arguments.  Keyword arguments whose value
  is None are skipped.

  Args:
    message_descriptor: Descriptor of the message type being built.
    cls: The message class that receives the generated __init__.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: if value is a string naming no value of enum_type.
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any keyword argument at all marks the cached byte size dirty up front.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    # _unknown_field_set is None when empty for efficiency, and will be
    # turned into UnknownFieldSet struct if fields are added.
    self._unknown_field_set = None      # pylint: disable=protected-access
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName, as defined in this file, raises
      # ValueError for unknown names rather than returning None, so this
      # TypeError branch appears to be unreachable.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        # Build a fresh container and fill it from the supplied value.
        copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              # Message-valued map: merge each given value into copy[key].
              for key in field_value:
                copy[key].MergeFrom(field_value[key])
            else:
              copy.update(field_value)
          else:
            # Each element is either a dict of constructor kwargs or a
            # message to merge into a newly added element.
            for val in field_value:
              if isinstance(val, dict):
                copy.add(**val)
              else:
                copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            # Enum elements may be given by label name.
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          copy.extend(field_value)
        self._fields[field] = copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        copy = field._default_constructor(self)
        new_val = field_value
        # A dict is treated as constructor kwargs for the field's message type.
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          copy.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        # Assign through the generated property so its type checking applies.
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
576
577
def _GetFieldByName(message_descriptor, field_name):
  """Looks up one field of a message type by its name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The field descriptor registered under field_name.

  Raises:
    ValueError: if the message type has no field with that name.
  """
  fields = message_descriptor.fields_by_name
  try:
    found = fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
592
593
def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type.

  For extendable message types it also adds an ``Extensions`` property.
  Each access to that property builds a new _ExtensionDict adaptor; the
  adaptor holds no state of its own, so nothing is cached.
  """
  for field_descriptor in descriptor.fields:
    _AddPropertiesForField(field_descriptor, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensions(self):
    return _ExtensionDict(self)
  cls.Extensions = property(_GetExtensions)
603
604
def _AddPropertiesForField(field, cls):
  """Adds the public property and FIELD_NUMBER constant for one field.

  cls receives a <NAME>_FIELD_NUMBER class constant holding the field number.
  It also receives a property, chosen by field kind:
    * repeated field: read-only container property;
    * singular message field: read-only, lazily created sub-message;
    * singular scalar field: read/write property with type checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
628
629
class _FieldProperty(property):
  # A property that also records the FieldDescriptor it was generated for,
  # exposed as the DESCRIPTOR attribute.
  # Intentionally no class docstring: __slots__ has no '__doc__', so the
  # class-level __doc__ can shadow the per-property doc on instances.  Adding
  # a docstring could therefore change what help() reports for fields.
  __slots__ = ('DESCRIPTOR',)

  def __init__(self, descriptor, getter, setter, doc):
    super(_FieldProperty, self).__init__(getter, setter, doc=doc)
    self.DESCRIPTOR = descriptor
636
637
def _AddPropertiesForRepeatedField(field, cls):
  """Installs the public, read-only property for a repeated field.

  Reading the property returns the field's container, built by the field's
  default constructor on first access.  That container is a
  RepeatedScalarFieldContainer or RepeatedCompositeFieldContainer, or a map
  container for map fields.  Adding values through the container is
  type-checked for scalar fields and sets any necessary "has" bits as a side
  effect.  Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    # Build a container, then publish it with setdefault().  If another
    # thread stored one first, that one wins and ours is discarded.
    # WARNING: relies on dict.setdefault() being atomic; true in CPython, not
    #   investigated for other implementations.
    fresh = field._default_constructor(self)
    return self._fields.setdefault(field, fresh)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to give a clearer error than the default one.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
680
681
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Installs a read/write public property for a singular scalar field.

  Reading an unset field yields the descriptor's default value.  Writing
  type-checks the value and stores it, and marks the message modified if its
  cached size is not already dirty.  For oneof members it also updates the
  oneof state.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == 'proto3'

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # For proto3 fields outside a oneof, assigning a falsy value removes the
  # stored value instead of storing it.  Falsiness covers the proto3 defaults
  # (0, 0.0, enum 0, False, and empty strings/bytes).
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      new_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    if not clear_when_set_to_default or new_value:
      self._fields[field] = new_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if not field.containing_oneof:
    setter = field_setter
  else:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
739
740
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Installs a read-only public property for a singular message/group field.

  Reading the property returns the sub-message, which the field's default
  constructor creates on first access.  Assigning to the property raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message
    # Build a default sub-message and publish it with setdefault().  If
    # another thread stored one first, that one wins and ours is discarded.
    # WARNING: relies on dict.setdefault() being atomic; true in CPython, not
    #   investigated for other implementations.
    candidate = field._default_constructor(self)
    return self._fields.setdefault(field, candidate)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only to give a clearer error than the default one.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
783
784
785def _AddPropertiesForExtensions(descriptor, cls):
786  """Adds properties for all fields in this protocol message type."""
787  extensions = descriptor.extensions_by_name
788  for extension_name, extension_field in extensions.items():
789    constant_name = extension_name.upper() + '_FIELD_NUMBER'
790    setattr(cls, constant_name, extension_field.number)
791
792  # TODO(amauryfa): Migrate all users of these attributes to functions like
793  #   pool.FindExtensionByNumber(descriptor).
794  if descriptor.file is not None:
795    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
796    pool = descriptor.file.pool
797    cls._extensions_by_number = pool._extensions_by_number[descriptor]
798    cls._extensions_by_name = pool._extensions_by_name[descriptor]
799
800def _AddStaticMethods(cls):
801  # TODO(robinson): This probably needs to be thread-safe(?)
802  def RegisterExtension(extension_handle):
803    extension_handle.containing_type = cls.DESCRIPTOR
804    # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
805    # pylint: disable=protected-access
806    cls.DESCRIPTOR.file.pool._AddExtensionDescriptor(extension_handle)
807    _AttachFieldHelpers(cls, extension_handle)
808  cls.RegisterExtension = staticmethod(RegisterExtension)
809
810  def FromString(s):
811    message = cls()
812    message.MergeFromString(s)
813    return message
814  cls.FromString = staticmethod(FromString)
815
816
def _IsPresent(item):
  """Decides whether a (FieldDescriptor, value) pair from _fields is listed.

  Used by ListFields().  Repeated fields count only when non-empty, singular
  sub-messages only when marked present in their parent, and any other stored
  value always counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
827
828
829def _AddListFieldsMethod(message_descriptor, cls):
830  """Helper for _AddMessageMethods()."""
831
832  def ListFields(self):
833    all_fields = [item for item in self._fields.items() if _IsPresent(item)]
834    all_fields.sort(key = lambda item: item[0].number)
835    return all_fields
836
837  cls.ListFields = ListFields
838
839_PROTO3_ERROR_TEMPLATE = \
840  ('Protocol message %s has no non-repeated submessage field "%s" '
841   'nor marked as optional')
842_PROTO2_ERROR_TEMPLATE = 'Protocol message %s has no non-repeated field "%s"'
843
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField on cls.

  HasField(name) accepts the name of a field that tracks presence or the name
  of a oneof, and raises ValueError for any other name.
  """
  is_proto3 = message_descriptor.syntax == 'proto3'
  if is_proto3:
    error_msg = _PROTO3_ERROR_TEMPLATE
  else:
    error_msg = _PROTO2_ERROR_TEMPLATE

  def _HasPresence(field):
    # Repeated fields never track presence.  For proto3, only submessages
    # and fields inside a oneof have presence.
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      return False
    return (not is_proto3 or
            field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
            bool(field.containing_oneof))

  hassable_fields = dict((field.name, field)
                         for field in message_descriptor.fields
                         if _HasPresence(field))
  # Oneof names are accepted too; a oneof "has" a value when its currently
  # active member does.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % (message_descriptor.full_name, field_name))

    if isinstance(field, descriptor_mod.OneofDescriptor):
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        # No member of this oneof is active.
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields

    # Sub-messages can sit in _fields after a mere read, so also consult the
    # presence flag.
    sub_message = self._fields.get(field)
    return sub_message is not None and sub_message._is_present_in_parent

  cls.HasField = HasField
883
884
885def _AddClearFieldMethod(message_descriptor, cls):
886  """Helper for _AddMessageMethods()."""
887  def ClearField(self, field_name):
888    try:
889      field = message_descriptor.fields_by_name[field_name]
890    except KeyError:
891      try:
892        field = message_descriptor.oneofs_by_name[field_name]
893        if field in self._oneofs:
894          field = self._oneofs[field]
895        else:
896          return
897      except KeyError:
898        raise ValueError('Protocol message %s has no "%s" field.' %
899                         (message_descriptor.name, field_name))
900
901    if field in self._fields:
902      # To match the C++ implementation, we need to invalidate iterators
903      # for map fields when ClearField() happens.
904      if hasattr(self._fields[field], 'InvalidateIterators'):
905        self._fields[field].InvalidateIterators()
906
907      # Note:  If the field is a sub-message, its listener will still point
908      #   at us.  That's fine, because the worst than can happen is that it
909      #   will call _Modified() and invalidate our byte size.  Big deal.
910      del self._fields[field]
911
912      if self._oneofs.get(field.containing_oneof, None) is field:
913        del self._oneofs[field.containing_oneof]
914
915    # Always call _Modified() -- even if nothing was changed, this is
916    # a mutating method, and thus calling it should cause the field to become
917    # present in the parent message.
918    self._Modified()
919
920  cls.ClearField = ClearField
921
922
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension on cls."""

  def ClearExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    # Same contract as ClearField(): drop any stored value, then always mark
    # the message modified, even if nothing was stored.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
933
934
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension on cls.

  Repeated extensions are rejected with KeyError.
  """

  def HasExtension(self, extension_handle):
    extension_dict._VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields

    # A message extension may be stored without having been set.
    sub_message = self._fields.get(extension_handle)
    return sub_message is not None and sub_message._is_present_in_parent

  cls.HasExtension = HasExtension
948
949def _InternalUnpackAny(msg):
950  """Unpacks Any message and returns the unpacked message.
951
952  This internal method is different from public Any Unpack method which takes
953  the target message as argument. _InternalUnpackAny method does not have
954  target message type and need to find the message type in descriptor pool.
955
956  Args:
957    msg: An Any message to be unpacked.
958
959  Returns:
960    The unpacked message.
961  """
962  # TODO(amauryfa): Don't use the factory of generated messages.
963  # To make Any work with custom factories, use the message factory of the
964  # parent message.
965  # pylint: disable=g-import-not-at-top
966  from google.protobuf import symbol_database
967  factory = symbol_database.Default()
968
969  type_url = msg.type_url
970
971  if not type_url:
972    return None
973
974  # TODO(haberman): For now we just strip the hostname.  Better logic will be
975  # required.
976  type_name = type_url.split('/')[-1]
977  descriptor = factory.pool.FindMessageTypeByName(type_name)
978
979  if descriptor is None:
980    return None
981
982  message_class = factory.GetPrototype(descriptor)
983  message = message_class()
984
985  message.ParseFromString(msg.value)
986  return message
987
988
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls.

  Two messages are equal when they share a descriptor and have equal listed
  fields and equal unknown fields.  Any messages whose payloads can both be
  unpacked are compared by their unpacked contents instead.
  """
  def __eq__(self, other):
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False
    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if self.ListFields() != other.ListFields():
      return False

    # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    # Unknown fields are compared order-insensitively.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1017
1018
1019def _AddStrMethod(message_descriptor, cls):
1020  """Helper for _AddMessageMethods()."""
1021  def __str__(self):
1022    return text_format.MessageToString(self)
1023  cls.__str__ = __str__
1024
1025
1026def _AddReprMethod(message_descriptor, cls):
1027  """Helper for _AddMessageMethods()."""
1028  def __repr__(self):
1029    return text_format.MessageToString(self)
1030  cls.__repr__ = __repr__
1031
1032
1033def _AddUnicodeMethod(unused_message_descriptor, cls):
1034  """Helper for _AddMessageMethods()."""
1035
1036  def __unicode__(self):
1037    return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
1038  cls.__unicode__ = __unicode__
1039
1040
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    Whatever the sizer registered for field_type in
    type_checkers.TYPE_TO_BYTE_SIZE_FN returns for (field_number, value).

  Raises:
    message_mod.EncodeError: if no sizer is registered for field_type.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Call the sizer outside the try block.  Otherwise a KeyError raised by the
  # sizer itself would be misreported as an unrecognized field type.
  return fn(field_number, value)
1059
1060
1061def _AddByteSizeMethod(message_descriptor, cls):
1062  """Helper for _AddMessageMethods()."""
1063
1064  def ByteSize(self):
1065    if not self._cached_byte_size_dirty:
1066      return self._cached_byte_size
1067
1068    size = 0
1069    descriptor = self.DESCRIPTOR
1070    if descriptor.GetOptions().map_entry:
1071      # Fields of map entry should always be serialized.
1072      size = descriptor.fields_by_name['key']._sizer(self.key)
1073      size += descriptor.fields_by_name['value']._sizer(self.value)
1074    else:
1075      for field_descriptor, field_value in self.ListFields():
1076        size += field_descriptor._sizer(field_value)
1077      for tag_bytes, value_bytes in self._unknown_fields:
1078        size += len(tag_bytes) + len(value_bytes)
1079
1080    self._cached_byte_size = size
1081    self._cached_byte_size_dirty = False
1082    self._listener_for_children.dirty = False
1083    return size
1084
1085  cls.ByteSize = ByteSize
1086
1087
1088def _AddSerializeToStringMethod(message_descriptor, cls):
1089  """Helper for _AddMessageMethods()."""
1090
1091  def SerializeToString(self, **kwargs):
1092    # Check if the message has all of its required fields set.
1093    if not self.IsInitialized():
1094      raise message_mod.EncodeError(
1095          'Message %s is missing required fields: %s' % (
1096          self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
1097    return self.SerializePartialToString(**kwargs)
1098  cls.SerializeToString = SerializeToString
1099
1100
1101def _AddSerializePartialToStringMethod(message_descriptor, cls):
1102  """Helper for _AddMessageMethods()."""
1103
1104  def SerializePartialToString(self, **kwargs):
1105    out = BytesIO()
1106    self._InternalSerialize(out.write, **kwargs)
1107    return out.getvalue()
1108  cls.SerializePartialToString = SerializePartialToString
1109
1110  def InternalSerialize(self, write_bytes, deterministic=None):
1111    if deterministic is None:
1112      deterministic = (
1113          api_implementation.IsPythonDefaultSerializationDeterministic())
1114    else:
1115      deterministic = bool(deterministic)
1116
1117    descriptor = self.DESCRIPTOR
1118    if descriptor.GetOptions().map_entry:
1119      # Fields of map entry should always be serialized.
1120      descriptor.fields_by_name['key']._encoder(
1121          write_bytes, self.key, deterministic)
1122      descriptor.fields_by_name['value']._encoder(
1123          write_bytes, self.value, deterministic)
1124    else:
1125      for field_descriptor, field_value in self.ListFields():
1126        field_descriptor._encoder(write_bytes, field_value, deterministic)
1127      for tag_bytes, value_bytes in self._unknown_fields:
1128        write_bytes(tag_bytes)
1129        write_bytes(value_bytes)
1130  cls._InternalSerialize = InternalSerialize
1131
1132
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() worker it delegates to.
  """
  def MergeFromString(self, serialized):
    """Parses serialized data and merges the decoded fields into this message.

    Args:
      serialized: Serialized message data; any object memoryview() accepts.
        Under Python 2 a memoryview argument is rejected.

    Returns:
      len(serialized), i.e. the number of bytes consumed (kept for legacy
      callers).

    Raises:
      TypeError: under Python 2 when serialized is a memoryview.
      message_mod.DecodeError: when _InternalParse stops before the end of
        the data, raises IndexError/TypeError (truncated input), or raises
        struct.error.
    """
    if isinstance(serialized, memoryview) and six.PY2:
      raise TypeError(
          'memoryview not supported in Python 2 with the pure Python proto '
          'implementation: this is to maintain compatibility with the C++ '
          'implementation')

    # _InternalParse requires a memoryview (it asserts so below).
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind the decoder helpers and the tag_bytes -> (decoder, field descriptor)
  # table to closure locals once, presumably to keep lookups out of the hot
  # parse loop below.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Merges the fields encoded in buffer[pos:end] into this message.

    Known tags are dispatched to their field decoder.  Unknown tags are
    recorded both in self._unknown_field_set and in the legacy
    self._unknown_fields list of (tag_bytes, raw_bytes) pairs.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position where parsing stopped: end when the whole range was
      consumed, or the start of an unknown field's tag when decoding or
      skipping that field reported -1.  MergeFromString treats any early
      stop as an unexpected end-group tag.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and its easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    # Mark the message modified up front, even if the range is empty.
    self._Modified()
    field_dict = self._fields
    # pylint: disable=protected-access
    unknown_field_set = self._unknown_field_set
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Unknown tag.  _Clear() leaves an immutable () in _unknown_fields,
        # so swap in a list before appending to it below.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        if unknown_field_set is None:
          # pylint: disable=protected-access
          self._unknown_field_set = containers.UnknownFieldSet()
          # pylint: disable=protected-access
          unknown_field_set = self._unknown_field_set
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO(jieluo): remove old_pos.
        old_pos = new_pos
        # First pass over the payload: decode it for the UnknownFieldSet.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # pylint: disable=protected-access
        unknown_field_set._add(field_number, wire_type, data)
        # TODO(jieluo): remove _unknown_fields.
        # Second pass over the same bytes: skip from old_pos to find the end
        # of the field and keep its raw bytes for the legacy list.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # NOTE(review): field_desc appears to be set only for oneof members
        # (it feeds _UpdateOneofState); confirm in _AttachFieldHelpers.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1217
1218
1219def _AddIsInitializedMethod(message_descriptor, cls):
1220  """Adds the IsInitialized and FindInitializationError methods to the
1221  protocol message class."""
1222
1223  required_fields = [field for field in message_descriptor.fields
1224                           if field.label == _FieldDescriptor.LABEL_REQUIRED]
1225
1226  def IsInitialized(self, errors=None):
1227    """Checks if all required fields of a message are set.
1228
1229    Args:
1230      errors:  A list which, if provided, will be populated with the field
1231               paths of all missing required fields.
1232
1233    Returns:
1234      True iff the specified message has all required fields set.
1235    """
1236
1237    # Performance is critical so we avoid HasField() and ListFields().
1238
1239    for field in required_fields:
1240      if (field not in self._fields or
1241          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
1242           not self._fields[field]._is_present_in_parent)):
1243        if errors is not None:
1244          errors.extend(self.FindInitializationErrors())
1245        return False
1246
1247    for field, value in list(self._fields.items()):  # dict can change size!
1248      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1249        if field.label == _FieldDescriptor.LABEL_REPEATED:
1250          if (field.message_type.has_options and
1251              field.message_type.GetOptions().map_entry):
1252            continue
1253          for element in value:
1254            if not element.IsInitialized():
1255              if errors is not None:
1256                errors.extend(self.FindInitializationErrors())
1257              return False
1258        elif value._is_present_in_parent and not value.IsInitialized():
1259          if errors is not None:
1260            errors.extend(self.FindInitializationErrors())
1261          return False
1262
1263    return True
1264
1265  cls.IsInitialized = IsInitialized
1266
1267  def FindInitializationErrors(self):
1268    """Finds required fields which are not initialized.
1269
1270    Returns:
1271      A list of strings.  Each string is a path to an uninitialized field from
1272      the top-level message, e.g. "foo.bar[5].baz".
1273    """
1274
1275    errors = []  # simplify things
1276
1277    for field in required_fields:
1278      if not self.HasField(field.name):
1279        errors.append(field.name)
1280
1281    for field, value in self.ListFields():
1282      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1283        if field.is_extension:
1284          name = '(%s)' % field.full_name
1285        else:
1286          name = field.name
1287
1288        if _IsMapField(field):
1289          if _IsMessageMapField(field):
1290            for key in value:
1291              element = value[key]
1292              prefix = '%s[%s].' % (name, key)
1293              sub_errors = element.FindInitializationErrors()
1294              errors += [prefix + error for error in sub_errors]
1295          else:
1296            # ScalarMaps can't have any initialization errors.
1297            pass
1298        elif field.label == _FieldDescriptor.LABEL_REPEATED:
1299          for i in range(len(value)):
1300            element = value[i]
1301            prefix = '%s[%d].' % (name, i)
1302            sub_errors = element.FindInitializationErrors()
1303            errors += [prefix + error for error in sub_errors]
1304        else:
1305          prefix = name + '.'
1306          sub_errors = value.FindInitializationErrors()
1307          errors += [prefix + error for error in sub_errors]
1308
1309    return errors
1310
1311  cls.FindInitializationErrors = FindInitializationErrors
1312
1313
def _AddMergeFromMethod(cls):
  """Adds the MergeFrom method to cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _MutableFieldValue(self, field):
    # Returns self's container/sub-message for field, first constructing and
    # storing an empty one if none exists yet.
    current = self._fields.get(field)
    if current is None:
      current = field._default_constructor(self)
      self._fields[field] = current
    return current

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))

    assert msg is not self
    self._Modified()

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        # Repeated containers delegate to their own MergeFrom.
        _MutableFieldValue(self, field).MergeFrom(value)
      elif field.cpp_type != CPPTYPE_MESSAGE:
        # Scalars: the incoming value overwrites ours.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)
      elif value._is_present_in_parent:
        # Singular sub-messages merge recursively, but only if set in msg.
        _MutableFieldValue(self, field).MergeFrom(value)

    if not msg._unknown_fields:
      return
    if not self._unknown_fields:
      self._unknown_fields = []
    self._unknown_fields.extend(msg._unknown_fields)
    # pylint: disable=protected-access
    if self._unknown_field_set is None:
      self._unknown_field_set = containers.UnknownFieldSet()
    self._unknown_field_set._extend(msg._unknown_field_set)

  cls.MergeFrom = MergeFrom
1360
1361
1362def _AddWhichOneofMethod(message_descriptor, cls):
1363  def WhichOneof(self, oneof_name):
1364    """Returns the name of the currently set field inside a oneof, or None."""
1365    try:
1366      field = message_descriptor.oneofs_by_name[oneof_name]
1367    except KeyError:
1368      raise ValueError(
1369          'Protocol message has no oneof "%s" field.' % oneof_name)
1370
1371    nested_field = self._oneofs.get(field, None)
1372    if nested_field is not None and self.HasField(nested_field.name):
1373      return nested_field.name
1374    else:
1375      return None
1376
1377  cls.WhichOneof = WhichOneof
1378
1379
1380def _Clear(self):
1381  # Clear fields.
1382  self._fields = {}
1383  self._unknown_fields = ()
1384  # pylint: disable=protected-access
1385  if self._unknown_field_set is not None:
1386    self._unknown_field_set._clear()
1387    self._unknown_field_set = None
1388
1389  self._oneofs = {}
1390  self._Modified()
1391
1392
1393def _UnknownFields(self):
1394  if self._unknown_field_set is None:  # pylint: disable=protected-access
1395    # pylint: disable=protected-access
1396    self._unknown_field_set = containers.UnknownFieldSet()
1397  return self._unknown_field_set    # pylint: disable=protected-access
1398
1399
1400def _DiscardUnknownFields(self):
1401  self._unknown_fields = []
1402  self._unknown_field_set = None      # pylint: disable=protected-access
1403  for field, value in self.ListFields():
1404    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1405      if _IsMapField(field):
1406        if _IsMessageMapField(field):
1407          for key in value:
1408            value[key].DiscardUnknownFields()
1409      elif field.label == _FieldDescriptor.LABEL_REPEATED:
1410        for sub_message in value:
1411          sub_message.DiscardUnknownFields()
1412      else:
1413        value.DiscardUnknownFields()
1414
1415
1416def _SetListener(self, listener):
1417  if listener is None:
1418    self._listener = message_listener_mod.NullMessageListener()
1419  else:
1420    self._listener = listener
1421
1422
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  # Extension accessors are only installed on extendable message types.
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod,
                     _AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # These implementations do not depend on cls and are shared as-is.
  cls.Clear = _Clear
  cls.UnknownFields = _UnknownFields
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1447
1448
1449def _AddPrivateHelperMethods(message_descriptor, cls):
1450  """Adds implementation of private helper methods to cls."""
1451
1452  def Modified(self):
1453    """Sets the _cached_byte_size_dirty bit to true,
1454    and propagates this to our listener iff this was a state change.
1455    """
1456
1457    # Note:  Some callers check _cached_byte_size_dirty before calling
1458    #   _Modified() as an extra optimization.  So, if this method is ever
1459    #   changed such that it does stuff even when _cached_byte_size_dirty is
1460    #   already true, the callers need to be updated.
1461    if not self._cached_byte_size_dirty:
1462      self._cached_byte_size_dirty = True
1463      self._listener_for_children.dirty = True
1464      self._is_present_in_parent = True
1465      self._listener.Modified()
1466
1467  def _UpdateOneofState(self, field):
1468    """Sets field as the active field in its containing oneof.
1469
1470    Will also delete currently active field in the oneof, if it is different
1471    from the argument. Does not mark the message as modified.
1472    """
1473    other_field = self._oneofs.setdefault(field.containing_oneof, field)
1474    if other_field is not field:
1475      del self._fields[other_field]
1476      self._oneofs[field.containing_oneof] = field
1477
1478  cls._Modified = Modified
1479  cls.SetInParent = Modified
1480  cls._UpdateOneofState = _UpdateOneofState
1481
1482
1483class _Listener(object):
1484
1485  """MessageListener implementation that a parent message registers with its
1486  child message.
1487
1488  In order to support semantics like:
1489
1490    foo.bar.baz.qux = 23
1491    assert foo.HasField('bar')
1492
1493  ...child objects must have back references to their parents.
1494  This helper class is at the heart of this support.
1495  """
1496
1497  def __init__(self, parent_message):
1498    """Args:
1499      parent_message: The message whose _Modified() method we should call when
1500        we receive Modified() messages.
1501    """
1502    # This listener establishes a back reference from a child (contained) object
1503    # to its parent (containing) object.  We make this a weak reference to avoid
1504    # creating cyclic garbage when the client finishes with the 'parent' object
1505    # in the tree.
1506    if isinstance(parent_message, weakref.ProxyType):
1507      self._parent_message_weakref = parent_message
1508    else:
1509      self._parent_message_weakref = weakref.proxy(parent_message)
1510
1511    # As an optimization, we also indicate directly on the listener whether
1512    # or not the parent message is dirty.  This way we can avoid traversing
1513    # up the tree in the common case.
1514    self.dirty = False
1515
1516  def Modified(self):
1517    if self.dirty:
1518      return
1519    try:
1520      # Propagate the signal to our parents iff this is the first field set.
1521      self._parent_message_weakref._Modified()
1522    except ReferenceError:
1523      # We can get here if a client has kept a reference to a child object,
1524      # and is now setting a field on it, but the child's parent has been
1525      # garbage-collected.  This is not an error.
1526      pass
1527
1528
class _OneofListener(_Listener):
  """Listener for composite fields that are members of a oneof.

  Before propagating Modified() like _Listener, it makes its field the active
  member of the containing oneof in the parent message.
  """

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent has been garbage-collected; there is nothing to update.
      pass
1548