• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8# This code is meant to work on Python 2.4 and above only.
9#
10# TODO: Helpers for verbose, common checks like seeing if a
11# descriptor's cpp_type is CPPTYPE_MESSAGE.
12
13"""Contains a metaclass and helper functions used to create
14protocol message classes from Descriptor objects at runtime.
15
16Recall that a metaclass is the "type" of a class.
17(A class is to a metaclass what an instance is to a class.)
18
19In this case, we use the GeneratedProtocolMessageType metaclass
20to inject all the useful functionality into the classes
21output by the protocol compiler at compile-time.
22
23The upshot of all this is that the real implementation
24details for ALL pure-Python protocol buffers are *here in
25this file*.
26"""
27
28__author__ = 'robinson@google.com (Will Robinson)'
29
30import datetime
31from io import BytesIO
32import struct
33import sys
34import warnings
35import weakref
36
37from google.protobuf import descriptor as descriptor_mod
38from google.protobuf import message as message_mod
39from google.protobuf import text_format
40# We use "as" to avoid name collisions with variables.
41from google.protobuf.internal import api_implementation
42from google.protobuf.internal import containers
43from google.protobuf.internal import decoder
44from google.protobuf.internal import encoder
45from google.protobuf.internal import enum_type_wrapper
46from google.protobuf.internal import extension_dict
47from google.protobuf.internal import message_listener as message_listener_mod
48from google.protobuf.internal import type_checkers
49from google.protobuf.internal import well_known_types
50from google.protobuf.internal import wire_format
51
# Short alias for the FieldDescriptor class; its TYPE_*, CPPTYPE_* and LABEL_*
# constants are compared against throughout this module.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Fully-qualified names of well-known message types.  Struct and ListValue get
# special-cased handling in the code below; Any is presumably used further down
# in the file (not visible in this section).
_AnyFullTypeName = 'google.protobuf.Any'
_StructFullTypeName = 'google.protobuf.Struct'
_ListValueFullTypeName = 'google.protobuf.ListValue'
# Short alias for the extension-dict adaptor class.
_ExtensionDict = extension_dict._ExtensionDict
57
class GeneratedProtocolMessageType(type):
  """Metaclass that builds protocol message classes from Descriptors at runtime.

  Every method described by the Message class gets an implementation here,
  every field gets a property for getting/setting it, and the class receives
  __slots__ so that "setting" a nonexistent field fails loudly instead of
  creating an attribute that would never be serialized / deserialized.

  The protocol compiler relies on this metaclass for the classes it emits.
  Clients can also build their own classes at runtime, for example:

  mydescriptor = Descriptor(.....)
  factory = symbol_database.Default()
  factory.pool.AddDescriptor(mydescriptor)
  MyProtoClass = message_factory.GetMessageClass(mydescriptor)
  myproto_instance = MyProtoClass()
  myproto_instance.foo_field = 23
  ...
  """

  # Key under which the class dictionary carries the message Descriptor.
  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the message class, or returns the one that already exists.

    __new__ is overridden because it appears to be the only point at which
    __slots__ can still be injected into the class being created.

    Args:
      name: Name of the class being created (passed through to type).
      bases: Base classes of the class being created (should be
        message.Message).  A well-known-type mixin is appended when the
        descriptor's full name is registered in well_known_types.WKTBASES.
      dictionary: Class dictionary of the class being created.
        dictionary[_DESCRIPTOR_KEY] must hold the Descriptor object for this
        message type.

    Returns:
      The newly allocated class, or the descriptor's existing
      `_concrete_class` when one has already been created.

    Raises:
      RuntimeError: If dictionary[_DESCRIPTOR_KEY] is a str rather than a
        Descriptor object.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    if isinstance(descriptor, str):
      raise RuntimeError('The generated code only work with python cpp '
                         'extension, but it is using pure python runtime.')

    # Never build a second concrete class for the same descriptor: doing so
    # would break messages that were created from the first one.  (The C++
    # implementation appears to rely on its own internal `PyMessageFactory`
    # for a similar effect.)  The usual trigger is `text_format.py` working
    # with descriptors from a custom pool and calling
    # message_factory.GetMessageClass() on a descriptor that already has a
    # concrete class.
    already_built = getattr(descriptor, '_concrete_class', None)
    if already_built:
      return already_built

    if descriptor.full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    return super().__new__(cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly allocated class with all message behaviour.

    Attaches per-field helpers, enum values, an __init__ method, properties
    for fields and extensions, and the Message method implementations.

    Args:
      name: Name of the class (passed through to type).
      bases: Base classes of the class (passed through to type).
      dictionary: Class dictionary; dictionary[_DESCRIPTOR_KEY] must hold the
        Descriptor object for this message type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]

    # When __new__ handed back an existing class (found via
    # `_concrete_class`), it has been initialized before; leave it alone.
    registered = getattr(descriptor, '_concrete_class', None)
    if registered:
      assert registered is cls, (
          'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
          % (descriptor.full_name))
      return

    cls._message_set_decoders_by_tag = {}
    cls._fields_by_tag = {}
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      item_decoder = decoder.MessageSetItemDecoder(descriptor)
      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          item_decoder,
          None,
      )

    # Cache lookup helpers on every FieldDescriptor, regular fields first.
    for regular_field in descriptor.fields:
      _AttachFieldHelpers(cls, regular_field)

    # Then on every extension the pool knows about for this message.
    if descriptor.is_extendable and hasattr(descriptor.file, 'pool'):
      for extension in descriptor.file.pool.FindAllExtensions(descriptor):
        _AttachFieldHelpers(cls, extension)

    # Register the class on the descriptor before running the builders below.
    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)

    super().__init__(name, bases, dictionary)
195
196
197# Stateless helpers for GeneratedProtocolMessageType below.
198# Outside clients should not access these directly.
199#
200# I opted not to make any of these methods on the metaclass, to make it more
201# clear that I'm not really using any state there and to keep clients from
202# thinking that they have direct access to these construction helpers.
203
204
205def _PropertyName(proto_field_name):
206  """Returns the name of the public property attribute which
207  clients can use to get and (in some cases) set the value
208  of a protocol message field.
209
210  Args:
211    proto_field_name: The protocol message field name, exactly
212      as it appears (or would appear) in a .proto file.
213  """
214  # TODO: Escape Python keywords (e.g., yield), and test this support.
215  # nnorwitz makes my day by writing:
216  # """
217  # FYI.  See the keyword module in the stdlib. This could be as simple as:
218  #
219  # if keyword.iskeyword(proto_field_name):
220  #   return proto_field_name + "_"
221  # return proto_field_name
222  # """
223  # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
224  #   getattr() and setattr() to reflectively manipulate field values.  If we
225  #   rename the properties, then every such user has to also make sure to apply
226  #   the same transformation.  Note that currently if you name a field "yield",
227  #   you can still access it just fine using getattr/setattr -- it's not even
228  #   that cumbersome to do so.
229  # TODO:  Remove this method entirely if/when everyone agrees with my
230  #   position.
231  return proto_field_name
232
233
234def _AddSlots(message_descriptor, dictionary):
235  """Adds a __slots__ entry to dictionary, containing the names of all valid
236  attributes for this message type.
237
238  Args:
239    message_descriptor: A Descriptor instance describing this message type.
240    dictionary: Class dictionary to which we'll add a '__slots__' entry.
241  """
242  dictionary['__slots__'] = ['_cached_byte_size',
243                             '_cached_byte_size_dirty',
244                             '_fields',
245                             '_unknown_fields',
246                             '_is_present_in_parent',
247                             '_listener',
248                             '_listener_for_children',
249                             '__weakref__',
250                             '_oneofs']
251
252
253def _IsMessageSetExtension(field):
254  return (field.is_extension and
255          field.containing_type.has_options and
256          field.containing_type.GetOptions().message_set_wire_format and
257          field.type == _FieldDescriptor.TYPE_MESSAGE and
258          field.label == _FieldDescriptor.LABEL_OPTIONAL)
259
260
def _IsMapField(field):
  """Tells whether `field` is a message field whose type is a map entry.

  Args:
    field: A FieldDescriptor.

  Returns:
    False for non-message fields; otherwise the `_is_map_entry` flag of the
    field's message type.
  """
  if field.type == _FieldDescriptor.TYPE_MESSAGE:
    return field.message_type._is_map_entry
  return False
264
265
def _IsMessageMapField(field):
  """Tells whether the 'value' field of a map entry is itself a message.

  Args:
    field: A FieldDescriptor whose message type has a field named 'value'.

  Returns:
    True if that 'value' field has cpp_type CPPTYPE_MESSAGE.
  """
  entry_fields = field.message_type.fields_by_name
  return entry_fields['value'].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
269
def _AttachFieldHelpers(cls, field_descriptor):
  """Caches the default constructor and tag lookups for one field.

  Stores a default-value constructor on the descriptor and registers the
  field in cls._fields_by_tag, mapping encoded tag bytes to a
  (field_descriptor, is_packed) pair.

  Args:
    cls: The message class under construction; must already have a
      `_fields_by_tag` dict.
    field_descriptor: FieldDescriptor of a field or extension of the message.
  """
  repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor
  )

  # Every field is reachable through the tag built from its own wire type.
  own_wiretype = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
  own_tag = encoder.TagBytes(field_descriptor.number, own_wiretype)
  cls._fields_by_tag[own_tag] = (field_descriptor, False)

  if repeated and wire_format.IsTypePackable(field_descriptor.type):
    # To support wire compatibility of adding packed = true, also register
    # the packed (length-delimited) tag regardless of the field's options.
    packed_tag = encoder.TagBytes(
        field_descriptor.number, wire_format.WIRETYPE_LENGTH_DELIMITED
    )
    cls._fields_by_tag[packed_tag] = (field_descriptor, True)
288
289
290def _MaybeAddEncoder(cls, field_descriptor):
291  if hasattr(field_descriptor, '_encoder'):
292    return
293  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
294  is_map_entry = _IsMapField(field_descriptor)
295  is_packed = field_descriptor.is_packed
296
297  if is_map_entry:
298    field_encoder = encoder.MapEncoder(field_descriptor)
299    sizer = encoder.MapSizer(field_descriptor,
300                             _IsMessageMapField(field_descriptor))
301  elif _IsMessageSetExtension(field_descriptor):
302    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
303    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
304  else:
305    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
306        field_descriptor.number, is_repeated, is_packed)
307    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
308        field_descriptor.number, is_repeated, is_packed)
309
310  field_descriptor._sizer = sizer
311  field_descriptor._encoder = field_encoder
312
313
314def _MaybeAddDecoder(cls, field_descriptor):
315  if hasattr(field_descriptor, '_decoders'):
316    return
317
318  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
319  is_map_entry = _IsMapField(field_descriptor)
320  helper_decoders = {}
321
322  def AddDecoder(is_packed):
323    decode_type = field_descriptor.type
324    if (decode_type == _FieldDescriptor.TYPE_ENUM and
325        not field_descriptor.enum_type.is_closed):
326      decode_type = _FieldDescriptor.TYPE_INT32
327
328    oneof_descriptor = None
329    if field_descriptor.containing_oneof is not None:
330      oneof_descriptor = field_descriptor
331
332    if is_map_entry:
333      is_message_map = _IsMessageMapField(field_descriptor)
334
335      field_decoder = decoder.MapDecoder(
336          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
337          is_message_map)
338    elif decode_type == _FieldDescriptor.TYPE_STRING:
339      field_decoder = decoder.StringDecoder(
340          field_descriptor.number, is_repeated, is_packed,
341          field_descriptor, field_descriptor._default_constructor,
342          not field_descriptor.has_presence)
343    elif field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
344      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
345          field_descriptor.number, is_repeated, is_packed,
346          field_descriptor, field_descriptor._default_constructor)
347    else:
348      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
349          field_descriptor.number, is_repeated, is_packed,
350          # pylint: disable=protected-access
351          field_descriptor, field_descriptor._default_constructor,
352          not field_descriptor.has_presence)
353
354    helper_decoders[is_packed] = field_decoder
355
356  AddDecoder(False)
357
358  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
359    # To support wire compatibility of adding packed = true, add a decoder for
360    # packed values regardless of the field's options.
361    AddDecoder(True)
362
363  field_descriptor._decoders = helper_decoders
364
365
366def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
367  extensions = descriptor.extensions_by_name
368  for extension_name, extension_field in extensions.items():
369    assert extension_name not in dictionary
370    dictionary[extension_name] = extension_field
371
372
373def _AddEnumValues(descriptor, cls):
374  """Sets class-level attributes for all enum fields defined in this message.
375
376  Also exporting a class-level object that can name enum values.
377
378  Args:
379    descriptor: Descriptor object for this message type.
380    cls: Class we're constructing for this message type.
381  """
382  for enum_type in descriptor.enum_types:
383    setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
384    for enum_value in enum_type.values:
385      setattr(cls, enum_value.name, enum_value.number)
386
387
def _GetInitializeDefaultForMap(field):
  """Returns a factory for the empty map container of a map field.

  Args:
    field: FieldDescriptor of the map field.

  Returns:
    A function taking the owning message and returning a new
    containers.MessageMap (when the map's values are messages) or
    containers.ScalarMap (otherwise) wired to that message's
    `_listener_for_children`.

  Raises:
    ValueError: If the field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))

  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    # Scalar values need a type checker of their own.
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
409
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  Returns:
    A function with one argument:
      message: Message instance containing this field, or a weakref proxy
        of same.
    That function in turn returns a default value for this field: a map
    container for map fields, a repeated container for repeated fields, a new
    sub-message for singular message fields, and field.default_value for
    everything else.  The default value may refer back to |message| via a
    weak reference.

  Raises:
    ValueError: If a repeated field declares a default value other than the
      empty list, or (via _GetInitializeDefaultForMap) if a map field is not
      repeated.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have
      # been set.  (Depends on order in which we initialize the classes).
      # The container is therefore handed the message type descriptor.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      # _concrete_class may not yet be initialized.
      if not hasattr(message_type, '_concrete_class'):
        from google.protobuf import message_factory
        message_factory.GetMessageClass(message_type)
      result = message_type._concrete_class()
      # Members of a oneof get a dedicated listener; everything else shares
      # the parent's listener for children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
466
467
468def _ReraiseTypeErrorWithFieldName(message_name, field_name):
469  """Re-raise the currently-handled TypeError with the field name added."""
470  exc = sys.exc_info()[1]
471  if len(exc.args) == 1 and type(exc) is TypeError:
472    # simple TypeError; add field name to exception message
473    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))
474
475  # re-raise possibly-amended exception with original traceback:
476  raise exc.with_traceback(sys.exc_info()[2])
477
478
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ takes keyword arguments only, one per field name.
  A value of None is treated as "argument not given".  Repeated, map and
  message fields are copied into fresh containers / sub-messages; scalar
  fields are assigned through their property setters.

  Args:
    message_descriptor: Descriptor object for this message type.
    cls: The message class that receives the __init__ method.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name.  If the value is not a string, it's
    returned as-is.  (No conversion or bounds-checking is done.)

    Raises:
      ValueError: If value is a string that names no value of enum_type.
    """
    if isinstance(value, str):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE: _GetFieldByName, as defined in this module, raises ValueError
      # for unknown names rather than returning None, so this branch is only
      # a safety net and unknown keywords surface as ValueError.
      if field is None:
        raise TypeError('%s() got an unexpected keyword argument "%s"' %
                        (message_descriptor.name, field_name))
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        field_copy = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                field_copy[key].MergeFrom(field_value[key])
            else:
              field_copy.update(field_value)
          else:
            # Repeated messages: each element is either a dict of keyword
            # arguments for a new element or a message to merge from.
            for val in field_value:
              if isinstance(val, dict):
                field_copy.add(**val)
              else:
                field_copy.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          field_copy.extend(field_value)
        self._fields[field] = field_copy
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        field_copy = field._default_constructor(self)
        # new_val, when set, is a message to merge into field_copy below.
        new_val = None
        if isinstance(field_value, message_mod.Message):
          new_val = field_value
        elif isinstance(field_value, dict):
          if field.message_type.full_name == _StructFullTypeName:
            field_copy.Clear()
            # {'fields': ...} is ambiguous: it may be Struct content or the
            # keyword form of the Struct message.  Try content first.
            if len(field_value) == 1 and 'fields' in field_value:
              try:
                field_copy.update(field_value)
              except Exception:
                # Fall back to init normal message field
                field_copy.Clear()
                new_val = field.message_type._concrete_class(**field_value)
            else:
              field_copy.update(field_value)
          else:
            new_val = field.message_type._concrete_class(**field_value)
        elif hasattr(field_copy, '_internal_assign'):
          field_copy._internal_assign(field_value)
        else:
          raise TypeError(
              'Message field {0}.{1} must be initialized with a '
              'dict or instance of same class, got {2}.'.format(
                  message_descriptor.name,
                  field_name,
                  type(field_value).__name__,
              )
          )

        # Identity check: `!=` would dispatch to the message's own comparison.
        if new_val is not None:
          try:
            field_copy.MergeFrom(new_val)
          except TypeError:
            _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = field_copy
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
588
589
590def _GetFieldByName(message_descriptor, field_name):
591  """Returns a field descriptor by field name.
592
593  Args:
594    message_descriptor: A Descriptor describing all fields in message.
595    field_name: The name of the field to retrieve.
596  Returns:
597    The field descriptor associated with the field name.
598  """
599  try:
600    return message_descriptor.fields_by_name[field_name]
601  except KeyError:
602    raise ValueError('Protocol message %s has no "%s" field.' %
603                     (message_descriptor.name, field_name))
604
605
606def _AddPropertiesForFields(descriptor, cls):
607  """Adds properties for all fields in this protocol message type."""
608  for field in descriptor.fields:
609    _AddPropertiesForField(field, cls)
610
611  if descriptor.is_extendable:
612    # _ExtensionDict is just an adaptor with no state so we allocate a new one
613    # every time it is accessed.
614    cls.Extensions = property(lambda self: _ExtensionDict(self))
615
616
def _AddPropertiesForField(field, cls):
  """Adds the public property and field-number constant for one field.

  Clients use the property to get and (in the case of non-repeated scalar
  fields) directly set the value of the field.  The class also receives a
  `<FIELD_NAME>_FIELD_NUMBER` constant holding the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  # Pick the property builder that matches the kind of field.
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_properties = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_properties = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_properties = _AddPropertiesForNonRepeatedScalarField
  add_properties(field, cls)
640
641
642class _FieldProperty(property):
643  __slots__ = ('DESCRIPTOR',)
644
645  def __init__(self, descriptor, getter, setter, doc):
646    property.__init__(self, getter, setter, doc=doc)
647    self.DESCRIPTOR = descriptor
648
649
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.

  Reading the property yields the field's container, created on first access
  by the field's default constructor (a RepeatedScalarFieldContainer or
  RepeatedCompositeFieldContainer, or a map container for map fields).
  Assigning to the property always raises AttributeError.

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container

    # First access: construct a new object to represent this field.
    fresh_container = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh_container)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # The setter exists only so that assignment fails with a more helpful
  # error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
692
693
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Setting the value through the property type-checks it and, as a
  side-effect, sets all necessary "has" bits.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value

  def getter(self):
    # TODO: This may be broken since there may not be
    # default_value.  Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    try:
      checked_value = type_checker.CheckValue(new_value)
    except TypeError as e:
      raise TypeError(
          'Cannot set %s to %.1024r: %s' % (field.full_name, new_value, e))
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False): a field without presence that is set to
    # its default is simply dropped from _fields.
    if field.has_presence or checked_value:
      self._fields[field] = checked_value
    else:
      self._fields.pop(field, None)
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    # Members of a oneof also have to update the oneof bookkeeping.
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
748
749
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.

  A composite field is a "group" or "message" field.

  Reading the property yields the sub-message, created on first access.
  Assignment is only supported for a few well-known types, where the assigned
  value is converted into the existing sub-message (Timestamp, Duration,
  Struct and ListValue); for every other message type it raises
  AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO: Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    submessage = self._fields.get(field)
    if submessage is not None:
      return submessage

    # First access: construct a new object to represent this field.
    fresh_submessage = field._default_constructor(self)

    # Atomically check if another thread has preempted us and, if not, swap
    # in the new object we just created.  If someone has preempted us, we
    # take that object and discard ours.
    # WARNING:  We are relying on setdefault() being atomic.  This is true
    #   in CPython but we haven't investigated others.  This warning appears
    #   in several other locations in this file.
    return self._fields.setdefault(field, fresh_submessage)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # Each supported branch first calls getter() so that the sub-message exists
  # in self._fields, then converts the assigned value into it.
  def setter(self, new_value):
    type_name = field.message_type.full_name
    if type_name == 'google.protobuf.Timestamp':
      getter(self)
      self._fields[field].FromDatetime(new_value)
      return
    if type_name == 'google.protobuf.Duration':
      getter(self)
      self._fields[field].FromTimedelta(new_value)
      return
    if type_name == _StructFullTypeName:
      getter(self)
      struct_message = self._fields[field]
      struct_message.Clear()
      struct_message.update(new_value)
      return
    if type_name == _ListValueFullTypeName:
      getter(self)
      list_message = self._fields[field]
      list_message.Clear()
      list_message.extend(new_value)
      return
    # Everything else: fail with a more helpful error message.
    raise AttributeError(
        'Assignment not allowed to composite field '
        '"%s" in protocol message object.' % proto_field_name
    )

  # Add a property to encapsulate the getter and setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  field_property = _FieldProperty(field, getter, setter, doc=doc)
  setattr(cls, property_name, field_property)
809
810
def _AddPropertiesForExtensions(descriptor, cls):
  """Publishes a <NAME>_FIELD_NUMBER constant on cls for every extension.

  Args:
    descriptor: Descriptor whose extensions_by_name mapping is walked.
    cls: The class receiving the constants.
  """
  for name, extension in descriptor.extensions_by_name.items():
    setattr(cls, name.upper() + '_FIELD_NUMBER', extension.number)

  # TODO: Migrate all users of these attributes to functions like
  #   pool.FindExtensionByNumber(descriptor).
  containing_file = descriptor.file
  if containing_file is None:
    return
  # TODO: Use cls.MESSAGE_FACTORY.pool when available.
  # The pool is looked up but not used any further in this function.
  pool = containing_file.pool
823
def _AddStaticMethods(cls):
  """Attaches the FromString() static factory to cls."""

  def FromString(s):
    # Build an empty instance and merge the serialized payload into it.
    parsed = cls()
    parsed.MergeFromString(s)
    return parsed

  cls.FromString = staticmethod(FromString)
830
831
def _IsPresent(item):
  """Tells whether a (FieldDescriptor, value) pair from _fields should be
  included in the list returned by ListFields().

  Repeated fields count when they are non-empty, message fields when the
  sub-message is marked present in its parent, and anything else always.
  """
  descriptor = item[0]
  if descriptor.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(item[1])
  if descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return item[1]._is_present_in_parent
  return True
842
843
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs ListFields() on cls."""

  def ListFields(self):
    # Keep only the entries that count as present, ordered by field number.
    return sorted(
        (entry for entry in self._fields.items() if _IsPresent(entry)),
        key=lambda entry: entry[0].number,
    )

  cls.ListFields = ListFields
853
854
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs HasField() on cls."""

  # Only non-repeated fields that track presence can be queried.  (For proto3,
  # only submessages and fields inside a oneof have presence.)
  hassable_fields = {
      candidate.name: candidate
      for candidate in message_descriptor.fields
      if candidate.label != _FieldDescriptor.LABEL_REPEATED
      and candidate.has_presence
  }
  # Has methods are supported for oneof descriptors.
  hassable_fields.update(
      (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError as exc:
      raise ValueError('Protocol message %s has no non-repeated field "%s" '
                       'nor has presence is not available for this field.' % (
                           message_descriptor.full_name, field_name)) from exc

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when its currently active member is set; a KeyError
      # here means no member is active.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields

    # A sub-message may exist in _fields without being marked present.
    submessage = self._fields.get(field)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasField = HasField
892
893
def _AddClearFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs ClearField() on cls."""

  def ClearField(self, field_name):
    try:
      target = message_descriptor.fields_by_name[field_name]
    except KeyError:
      # Not a regular field: maybe the name designates a oneof, in which case
      # we clear whichever member is currently active (if any).
      try:
        oneof = message_descriptor.oneofs_by_name[field_name]
      except KeyError:
        raise ValueError('Protocol message %s has no "%s" field.' %
                         (message_descriptor.name, field_name))
      if oneof not in self._oneofs:
        return
      target = self._oneofs[oneof]

    if target in self._fields:
      # To match the C++ implementation, we need to invalidate iterators
      # for map fields when ClearField() happens.
      if hasattr(self._fields[target], 'InvalidateIterators'):
        self._fields[target].InvalidateIterators()

      # Note:  If the field is a sub-message, its listener will still point
      #   at us.  That's fine, because the worst that can happen is that it
      #   will call _Modified() and invalidate our byte size.  Big deal.
      del self._fields[target]

      containing_oneof = target.containing_oneof
      if self._oneofs.get(containing_oneof, None) is target:
        del self._oneofs[containing_oneof]

    # Always call _Modified() -- even if nothing was changed, this is
    # a mutating method, and thus calling it should cause the field to become
    # present in the parent message.
    self._Modified()

  cls.ClearField = ClearField
930
931
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(); installs ClearExtension() on cls."""

  def ClearExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)

    # Similar to ClearField(): drop the stored value if there is one, and
    # mark the message as modified either way.
    self._fields.pop(field_descriptor, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
942
943
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(); installs HasExtension() on cls."""

  def HasExtension(self, field_descriptor):
    extension_dict._VerifyExtensionHandle(self, field_descriptor)
    if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % field_descriptor.full_name)

    if field_descriptor.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field_descriptor in self._fields

    # A message extension may be stored without being marked present.
    submessage = self._fields.get(field_descriptor)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
957
def _InternalUnpackAny(msg):
  """Unpacks Any message and returns the unpacked message.

  This internal method is different from public Any Unpack method which takes
  the target message as argument. _InternalUnpackAny method does not have
  target message type and need to find the message type in descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message, or None when the Any has no type_url or its type
    cannot be found in the default descriptor pool.
  """
  type_url = msg.type_url

  # Nothing to resolve for an empty Any; bail out before touching the pool.
  if not type_url:
    return None

  # TODO: Don't use the factory of generated messages.
  # To make Any work with custom factories, use the message factory of the
  # parent message.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import symbol_database
  factory = symbol_database.Default()

  # TODO: For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split('/')[-1]
  try:
    descriptor = factory.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # The descriptor pool reports an unknown type by raising KeyError rather
    # than returning None.  Treat it as "cannot unpack" so callers can fall
    # back to comparing the raw Any fields.
    return None

  if descriptor is None:
    return None

  # Unable to import message_factory at top because of circular import.
  # pylint: disable=g-import-not-at-top
  from google.protobuf import message_factory
  message_class = message_factory.GetMessageClass(descriptor)
  message = message_class()

  message.ParseFromString(msg.value)
  return message
999
1000
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs __eq__ on cls."""

  def __eq__(self, other):
    own_type_name = self.DESCRIPTOR.full_name

    # ListValue and Struct can be compared against plain list / dict values.
    if own_type_name == _ListValueFullTypeName and isinstance(other, list):
      return self._internal_compare(other)
    if own_type_name == _StructFullTypeName and isinstance(other, dict):
      return self._internal_compare(other)

    # Anything that is not a message of exactly our type is left for the
    # other operand to handle.
    if not isinstance(other, message_mod.Message):
      return NotImplemented
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return NotImplemented

    if self is other:
      return True

    if own_type_name == _AnyFullTypeName:
      # Compare the payloads when both sides could be unpacked; otherwise
      # fall through to the field-by-field comparison below.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # TODO: Fix UnknownFieldSet to consider MessageSet extensions,
    # then use it for the comparison.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
1038
1039
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); str(message) is its text format."""

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
1045
1046
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); repr(message) is its text format."""

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
1052
1053
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs __unicode__ on cls."""

  def __unicode__(self):
    # Render as UTF-8 encoded text format, then decode to a text string.
    encoded = text_format.MessageToString(self, as_utf8=True)
    return encoded.decode('utf-8')

  cls.__unicode__ = __unicode__
1060
1061
def _AddContainsMethod(message_descriptor, cls):
  """Installs __contains__ on cls, so that `x in message` works.

  Struct tests for a key of its `fields`, ListValue for a member of its
  items(), and every other message defers to HasField().
  """
  type_name = message_descriptor.full_name

  if type_name == 'google.protobuf.Struct':

    def __contains__(self, key):
      return key in self.fields

  elif type_name == 'google.protobuf.ListValue':

    def __contains__(self, value):
      return value in self.items()

  else:

    def __contains__(self, field):
      return self.HasField(field)

  cls.__contains__ = __contains__
1075
1076
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.
  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Raises:
    EncodeError: If field_type has no entry in the byte-size function table.
  """
  # Only the table lookup is guarded: a KeyError raised while the sizer itself
  # runs must not be reported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError as exc:
    raise message_mod.EncodeError(
        'Unrecognized field type: %d' % field_type) from exc
  return fn(field_number, value)
1095
1096
def _AddByteSizeMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs ByteSize() on cls."""

  def ByteSize(self):
    # Serve the cached value as long as nothing changed since it was computed.
    if not self._cached_byte_size_dirty:
      return self._cached_byte_size

    total = 0
    descriptor = self.DESCRIPTOR
    if descriptor._is_map_entry:
      # Fields of map entry should always be serialized.
      for entry_name in ('key', 'value'):
        entry_field = descriptor.fields_by_name[entry_name]
        _MaybeAddEncoder(cls, entry_field)
        total += entry_field._sizer(getattr(self, entry_name))
    else:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        total += field_descriptor._sizer(field_value)
      # Unknown fields are stored as raw (tag bytes, value bytes) pairs.
      total += sum(
          len(tag_bytes) + len(value_bytes)
          for tag_bytes, value_bytes in self._unknown_fields)

    # Store the result and clear the dirty markers.
    self._cached_byte_size = total
    self._cached_byte_size_dirty = False
    self._listener_for_children.dirty = False
    return total

  cls.ByteSize = ByteSize
1127
1128
def _AddSerializeToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(); installs SerializeToString() on cls."""

  def SerializeToString(self, **kwargs):
    # Serialization is only allowed once every required field is set.
    if self.IsInitialized():
      return self.SerializePartialToString(**kwargs)

    raise message_mod.EncodeError(
        'Message %s is missing required fields: %s' % (
            self.DESCRIPTOR.full_name,
            ','.join(self.FindInitializationErrors())))

  cls.SerializeToString = SerializeToString
1140
1141
def _AddSerializePartialToStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs SerializePartialToString() and the _InternalSerialize() routine it
  is built on.
  """

  def SerializePartialToString(self, **kwargs):
    # Collect everything _InternalSerialize() writes into one bytes object.
    sink = BytesIO()
    self._InternalSerialize(sink.write, **kwargs)
    return sink.getvalue()

  cls.SerializePartialToString = SerializePartialToString

  def InternalSerialize(self, write_bytes, deterministic=None):
    # Fall back to the process-wide default when the caller did not choose.
    deterministic = (
        api_implementation.IsPythonDefaultSerializationDeterministic()
        if deterministic is None
        else bool(deterministic)
    )

    descriptor = self.DESCRIPTOR
    if not descriptor._is_map_entry:
      for field_descriptor, field_value in self.ListFields():
        _MaybeAddEncoder(cls, field_descriptor)
        field_descriptor._encoder(write_bytes, field_value, deterministic)
      # Unknown fields are replayed as the raw bytes that were kept for them.
      for tag_bytes, value_bytes in self._unknown_fields:
        write_bytes(tag_bytes)
        write_bytes(value_bytes)
      return

    # Fields of map entry should always be serialized.
    for entry_name in ('key', 'value'):
      entry_field = descriptor.fields_by_name[entry_name]
      _MaybeAddEncoder(cls, entry_field)
      entry_field._encoder(
          write_bytes, getattr(self, entry_name), deterministic)

  cls._InternalSerialize = InternalSerialize
1175
1176
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() routine it relies on.
  """
  def MergeFromString(self, serialized):
    """Parses `serialized` and merges its contents into this message.

    Args:
      serialized: The serialized data; it is wrapped in a memoryview before
        parsing.

    Returns:
      The length of the input (kept for legacy reasons).

    Raises:
      DecodeError: If parsing stops before the end of the input, if the input
        is truncated, or if a struct.error occurs while decoding.
    """
    serialized = memoryview(serialized)
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind frequently used callables and tables to closure variables once, so
  # the parse loop does not repeat the attribute lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  fields_by_tag = cls._fields_by_tag
  message_set_decoders_by_tag = cls._message_set_decoders_by_tag

  def InternalParse(self, buffer, pos, end):
    """Parses serialized bytes from `buffer` into this message.

    Args:
      self: Message, instance of the proto message object.
      buffer: memoryview of the serialized data.
      pos: int, position to start in the serialized data.
      end: int, end position of the serialized data.

    Returns:
      int, the position at which parsing stopped.  This equals `end` when the
      whole range was consumed; when skipping an unknown field reports -1, the
      position of that field's tag is returned instead.
    """
    # Guard against internal misuse, since this function is called internally
    # quite extensively, and it's easy to accidentally pass bytes.
    assert isinstance(buffer, memoryview)
    self._Modified()
    field_dict = self._fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      # Decoders registered in the message-set table take precedence over the
      # regular per-field decoders.
      field_decoder, field_des = message_set_decoders_by_tag.get(
          tag_bytes, (None, None)
      )
      if field_decoder:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        continue
      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
      if field_des is None:
        # Unknown tag: keep the raw bytes of the field in _unknown_fields.
        if not self._unknown_fields:   # pylint: disable=protected-access
          self._unknown_fields = []    # pylint: disable=protected-access
        # pylint: disable=protected-access
        (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
        field_number, wire_type = wire_format.UnpackTag(tag)
        if field_number == 0:
          raise message_mod.DecodeError('Field number 0 is illegal.')
        # TODO: remove old_pos.
        old_pos = new_pos
        # NOTE: `data` is not used below; only the returned position is
        # checked before the field is skipped a second time.
        (data, new_pos) = decoder._DecodeUnknownField(
            buffer, new_pos, wire_type)  # pylint: disable=protected-access
        if new_pos == -1:
          return pos
        # TODO: remove _unknown_fields.
        new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        # Store the tag together with the raw bytes that followed it.
        self._unknown_fields.append(
            (tag_bytes, buffer[old_pos:new_pos].tobytes()))
        pos = new_pos
      else:
        # Known field: make sure its decoders exist, then pick the one that
        # matches the packedness recorded for this tag.
        _MaybeAddDecoder(cls, field_des)
        field_decoder = field_des._decoders[is_packed]
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        if field_des.containing_oneof:
          self._UpdateOneofState(field_des)
    return pos
  cls._InternalParse = InternalParse
1255
1256
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  required_fields = [
      candidate for candidate in message_descriptor.fields
      if candidate.label == _FieldDescriptor.LABEL_REQUIRED
  ]

  def _Fail(message, errors):
    """Fills `errors` (when given) with the missing paths; returns False."""
    if errors is not None:
      errors.extend(message.FindInitializationErrors())
    return False

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    # First pass: every required field of this message must be set.
    for field in required_fields:
      if field not in self._fields:
        return _Fail(self, errors)
      if (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
          not self._fields[field]._is_present_in_parent):
        return _Fail(self, errors)

    # Second pass: every sub-message must itself be initialized.  Iterate a
    # snapshot, since the dict can change size!
    for field, value in list(self._fields.items()):
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        if field.message_type._is_map_entry:
          continue
        if not all(element.IsInitialized() for element in value):
          return _Fail(self, errors)
      elif value._is_present_in_parent and not value.IsInitialized():
        return _Fail(self, errors)

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    # Start with the required fields missing on this very message.
    errors = [
        field.name for field in required_fields
        if not self.HasField(field.name)
    ]

    def _AddPrefixed(prefix, submessage):
      # Re-root the sub-message's error paths under `prefix`.
      errors.extend(
          prefix + error for error in submessage.FindInitializationErrors())

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue

      name = '(%s)' % field.full_name if field.is_extension else field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors.
        if _IsMessageMapField(field):
          for key in value:
            _AddPrefixed('%s[%s].' % (name, key), value[key])
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          _AddPrefixed('%s[%d].' % (name, index), element)
      else:
        _AddPrefixed(name + '.', value)

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1349
1350
def _FullyQualifiedClassName(klass):
  """Returns the class name, prefixed with its module unless it is a builtin.

  The qualified name (__qualname__) is preferred over __name__ when the class
  has one.
  """
  owner = klass.__module__
  short_name = getattr(klass, '__qualname__', klass.__name__)
  is_builtin = owner in (None, 'builtins', '__builtin__')
  return short_name if is_builtin else owner + '.' + short_name
1357
1358
def _AddMergeFromMethod(cls):
  """Installs MergeFrom() on cls."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _MergeComposite(owner, fields, field, value):
    # Merge `value` into owner's container / sub-message for `field`,
    # constructing a new object to represent the field first if needed.
    target = fields.get(field)
    if target is None:
      target = field._default_constructor(owner)
      fields[field] = target
    target.MergeFrom(value)

  def MergeFrom(self, msg):
    if not isinstance(msg, cls):
      raise TypeError(
          'Parameter to MergeFrom() must be instance of same class: '
          'expected %s got %s.' % (_FullyQualifiedClassName(cls),
                                   _FullyQualifiedClassName(msg.__class__)))

    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _MergeComposite(self, fields, field, value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages that were never marked present are skipped.
        if value._is_present_in_parent:
          _MergeComposite(self, fields, field, value)
      else:
        # Scalars are overwritten; a oneof member also becomes the active one.
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    incoming_unknown = msg._unknown_fields
    if incoming_unknown:
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(incoming_unknown)

  cls.MergeFrom = MergeFrom
1402
1403
def _AddWhichOneofMethod(message_descriptor, cls):
  """Installs WhichOneof() on cls."""

  def WhichOneof(self, oneof_name):
    """Returns the name of the currently set field inside a oneof, or None."""
    try:
      oneof = message_descriptor.oneofs_by_name[oneof_name]
    except KeyError:
      raise ValueError(
          'Protocol message has no oneof "%s" field.' % oneof_name)

    # The recorded member only counts while HasField() confirms it.
    active = self._oneofs.get(oneof, None)
    if active is None or not self.HasField(active.name):
      return None
    return active.name

  cls.WhichOneof = WhichOneof
1420
1421
def _Clear(self):
  """Drops all fields, unknown fields and oneof state, then marks the
  message as modified."""
  self._oneofs = {}
  self._unknown_fields = ()
  self._fields = {}

  self._Modified()
1429
1430
def _UnknownFields(self):
  """Always raises: unknown fields are exposed through an add-on instead.

  Raises:
    NotImplementedError: Unconditionally, pointing the caller at
      unknown_fields.UnknownFieldSet(message).
  """
  raise NotImplementedError('Please use the add-on feature '
                            'unknown_fields.UnknownFieldSet(message) in '
                            'unknown_fields.py instead.')
1435
1436
def _DiscardUnknownFields(self):
  """Drops this message's unknown fields and, recursively, those of every
  sub-message reported by ListFields()."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Only maps whose values are messages need to be visited.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for element in value:
        element.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1450
1451
def _SetListener(self, listener):
  """Stores `listener` on the message; None selects a NullMessageListener."""
  self._listener = (
      message_listener_mod.NullMessageListener()
      if listener is None
      else listener
  )
1457
1458
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls."""
  for install in (
      _AddListFieldsMethod,
      _AddHasFieldMethod,
      _AddClearFieldMethod,
  ):
    install(message_descriptor, cls)

  # Extension accessors only exist on extendable messages.
  if message_descriptor.is_extendable:
    for install in (_AddClearExtensionMethod, _AddHasExtensionMethod):
      install(cls)

  for install in (
      _AddEqualsMethod,
      _AddStrMethod,
      _AddReprMethod,
      _AddUnicodeMethod,
      _AddContainsMethod,
      _AddByteSizeMethod,
      _AddSerializeToStringMethod,
      _AddSerializePartialToStringMethod,
      _AddMergeFromStringMethod,
      _AddIsInitializedMethod,
  ):
    install(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)

  # Adds methods which do not depend on cls.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
1483
1484
def _AddPrivateHelperMethods(message_descriptor, cls):
  """Adds implementation of private helper methods to cls."""

  def Modified(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """

    # Note:  Some callers check _cached_byte_size_dirty before calling
    #   _Modified() as an extra optimization.  So, if this method is ever
    #   changed such that it does stuff even when _cached_byte_size_dirty is
    #   already true, the callers need to be updated.
    if self._cached_byte_size_dirty:
      return
    self._cached_byte_size_dirty = True
    self._listener_for_children.dirty = True
    self._is_present_in_parent = True
    self._listener.Modified()

  def _UpdateOneofState(self, field):
    """Sets field as the active field in its containing oneof.

    Will also delete currently active field in the oneof, if it is different
    from the argument. Does not mark the message as modified.
    """
    oneof = field.containing_oneof
    previous = self._oneofs.setdefault(oneof, field)
    if previous is field:
      return
    # Another member was active: drop its value and take over the oneof.
    del self._fields[previous]
    self._oneofs[oneof] = field

  cls._Modified = Modified
  cls.SetInParent = Modified
  cls._UpdateOneofState = _UpdateOneofState
1517
1518
class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz.moo = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
    """
    # The back reference from a child (contained) object to its parent
    # (containing) object is kept weak, to avoid creating cyclic garbage when
    # the client finishes with the 'parent' object in the tree.  A value that
    # already is a weak proxy is stored unchanged.
    self._parent_message_weakref = (
        parent_message
        if isinstance(parent_message, weakref.ProxyType)
        else weakref.proxy(parent_message)
    )

    # As an optimization, we also indicate directly on the listener whether
    # or not the parent message is dirty.  This way we can avoid traversing
    # up the tree in the common case.
    self.dirty = False

  def Modified(self):
    if not self.dirty:
      try:
        # Propagate the signal to our parents iff this is the first field set.
        self._parent_message_weakref._Modified()
      except ReferenceError:
        # We can get here if a client has kept a reference to a child object,
        # and is now setting a field on it, but the child's parent has been
        # garbage-collected.  This is not an error.
        pass
1563
1564
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super().__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    try:
      # Make our field the active member of its oneof first, then let the
      # base listener forward the modification signal.
      self._parent_message_weakref._UpdateOneofState(self._field)
      super().Modified()
    except ReferenceError:
      # The parent message no longer exists; there is nothing to update.
      pass
1584