• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides DescriptorPool to use as a container for proto2 descriptors.
9
The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
11a collection of protocol buffer descriptors for use when dynamically creating
12message types at runtime.
13
14For most applications protocol buffers should be used via modules generated by
15the protocol buffer compiler tool. This should only be used when the type of
16protocol buffers used in an application or library cannot be predetermined.
17
18Below is a straightforward example on how to use this class::
19
20  pool = DescriptorPool()
21  file_descriptor_protos = [ ... ]
22  for file_descriptor_proto in file_descriptor_protos:
23    pool.Add(file_descriptor_proto)
24  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
25
26The message descriptor can be used in conjunction with the message_factory
27module in order to create a protocol buffer class that can be encoded and
28decoded.
29
30If you want to get a Python class for the specified proto, use the
31helper functions inside google.protobuf.message_factory
32directly instead of this class.
33"""
34
35__author__ = 'matthewtoia@google.com (Matt Toia)'
36
37import collections
38import threading
39import warnings
40
41from google.protobuf import descriptor
42from google.protobuf import descriptor_database
43from google.protobuf import text_encoding
44from google.protobuf.internal import python_edition_defaults
45from google.protobuf.internal import python_message
46
# Mirrors descriptor._USE_C_DESCRIPTORS. When it is truthy, DescriptorPool
# defines a __new__ that returns descriptor._message.DescriptorPool instead.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
48
49
50def _NormalizeFullyQualifiedName(name):
51  """Remove leading period from fully-qualified type name.
52
53  Due to b/13860351 in descriptor_database.py, types in the root namespace are
54  generated with a leading period. This function removes that prefix.
55
56  Args:
57    name (str): The fully-qualified symbol name.
58
59  Returns:
60    str: The normalized fully-qualified symbol name.
61  """
62  return name.lstrip('.')
63
64
65def _OptionsOrNone(descriptor_proto):
66  """Returns the value of the field `options`, or None if it is not set."""
67  if descriptor_proto.HasField('options'):
68    return descriptor_proto.options
69  else:
70    return None
71
72
73def _IsMessageSetExtension(field):
74  return (field.is_extension and
75          field.containing_type.has_options and
76          field.containing_type.GetOptions().message_set_wire_format and
77          field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
78          field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL)
79
# Module-level lock, so it is shared by every DescriptorPool instance. It is
# held while a pool lazily parses its edition defaults and while a pool reads
# or updates its feature cache.
_edition_defaults_lock = threading.Lock()
81
82
83class DescriptorPool(object):
84  """A collection of protobufs dynamically constructed by descriptor protos."""
85
  if _USE_C_DESCRIPTORS:

   def __new__(cls, descriptor_db=None):
     """Returns a descriptor._message.DescriptorPool instead of this class.

     Only defined when _USE_C_DESCRIPTORS is truthy.

     Args:
       descriptor_db: Passed through to descriptor._message.DescriptorPool.
     """
     # pylint: disable=protected-access
     return descriptor._message.DescriptorPool(descriptor_db)
91
92  def __init__(
93      self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False
94  ):
95    """Initializes a Pool of proto buffs.
96
97    The descriptor_db argument to the constructor is provided to allow
98    specialized file descriptor proto lookup code to be triggered on demand. An
99    example would be an implementation which will read and compile a file
100    specified in a call to FindFileByName() and not require the call to Add()
101    at all. Results from this database will be cached internally here as well.
102
103    Args:
104      descriptor_db: A secondary source of file descriptors.
105      use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with
106        C++.
107    """
108
109    self._internal_db = descriptor_database.DescriptorDatabase()
110    self._descriptor_db = descriptor_db
111    self._descriptors = {}
112    self._enum_descriptors = {}
113    self._service_descriptors = {}
114    self._file_descriptors = {}
115    self._toplevel_extensions = {}
116    self._top_enum_values = {}
117    # We store extensions in two two-level mappings: The first key is the
118    # descriptor of the message being extended, the second key is the extension
119    # full name or its tag number.
120    self._extensions_by_name = collections.defaultdict(dict)
121    self._extensions_by_number = collections.defaultdict(dict)
122    self._serialized_edition_defaults = (
123        python_edition_defaults._PROTOBUF_INTERNAL_PYTHON_EDITION_DEFAULTS
124    )
125    self._edition_defaults = None
126    self._feature_cache = dict()
127
128  def _CheckConflictRegister(self, desc, desc_name, file_name):
129    """Check if the descriptor name conflicts with another of the same name.
130
131    Args:
132      desc: Descriptor of a message, enum, service, extension or enum value.
133      desc_name (str): the full name of desc.
134      file_name (str): The file name of descriptor.
135    """
136    for register, descriptor_type in [
137        (self._descriptors, descriptor.Descriptor),
138        (self._enum_descriptors, descriptor.EnumDescriptor),
139        (self._service_descriptors, descriptor.ServiceDescriptor),
140        (self._toplevel_extensions, descriptor.FieldDescriptor),
141        (self._top_enum_values, descriptor.EnumValueDescriptor)]:
142      if desc_name in register:
143        old_desc = register[desc_name]
144        if isinstance(old_desc, descriptor.EnumValueDescriptor):
145          old_file = old_desc.type.file.name
146        else:
147          old_file = old_desc.file.name
148
149        if not isinstance(desc, descriptor_type) or (
150            old_file != file_name):
151          error_msg = ('Conflict register for file "' + file_name +
152                       '": ' + desc_name +
153                       ' is already defined in file "' +
154                       old_file + '". Please fix the conflict by adding '
155                       'package name on the proto file, or use different '
156                       'name for the duplication.')
157          if isinstance(desc, descriptor.EnumValueDescriptor):
158            error_msg += ('\nNote: enum values appear as '
159                          'siblings of the enum type instead of '
160                          'children of it.')
161
162          raise TypeError(error_msg)
163
164        return
165
166  def Add(self, file_desc_proto):
167    """Adds the FileDescriptorProto and its types to this pool.
168
169    Args:
170      file_desc_proto (FileDescriptorProto): The file descriptor to add.
171    """
172
173    self._internal_db.Add(file_desc_proto)
174
175  def AddSerializedFile(self, serialized_file_desc_proto):
176    """Adds the FileDescriptorProto and its types to this pool.
177
178    Args:
179      serialized_file_desc_proto (bytes): A bytes string, serialization of the
180        :class:`FileDescriptorProto` to add.
181
182    Returns:
183      FileDescriptor: Descriptor for the added file.
184    """
185
186    # pylint: disable=g-import-not-at-top
187    from google.protobuf import descriptor_pb2
188    file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
189        serialized_file_desc_proto)
190    file_desc = self._ConvertFileProtoToFileDescriptor(file_desc_proto)
191    file_desc.serialized_pb = serialized_file_desc_proto
192    return file_desc
193
194  # Never call this method. It is for internal usage only.
195  def _AddDescriptor(self, desc):
196    """Adds a Descriptor to the pool, non-recursively.
197
198    If the Descriptor contains nested messages or enums, the caller must
199    explicitly register them. This method also registers the FileDescriptor
200    associated with the message.
201
202    Args:
203      desc: A Descriptor.
204    """
205    if not isinstance(desc, descriptor.Descriptor):
206      raise TypeError('Expected instance of descriptor.Descriptor.')
207
208    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
209
210    self._descriptors[desc.full_name] = desc
211    self._AddFileDescriptor(desc.file)
212
213  # Never call this method. It is for internal usage only.
214  def _AddEnumDescriptor(self, enum_desc):
215    """Adds an EnumDescriptor to the pool.
216
217    This method also registers the FileDescriptor associated with the enum.
218
219    Args:
220      enum_desc: An EnumDescriptor.
221    """
222
223    if not isinstance(enum_desc, descriptor.EnumDescriptor):
224      raise TypeError('Expected instance of descriptor.EnumDescriptor.')
225
226    file_name = enum_desc.file.name
227    self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
228    self._enum_descriptors[enum_desc.full_name] = enum_desc
229
230    # Top enum values need to be indexed.
231    # Count the number of dots to see whether the enum is toplevel or nested
232    # in a message. We cannot use enum_desc.containing_type at this stage.
233    if enum_desc.file.package:
234      top_level = (enum_desc.full_name.count('.')
235                   - enum_desc.file.package.count('.') == 1)
236    else:
237      top_level = enum_desc.full_name.count('.') == 0
238    if top_level:
239      file_name = enum_desc.file.name
240      package = enum_desc.file.package
241      for enum_value in enum_desc.values:
242        full_name = _NormalizeFullyQualifiedName(
243            '.'.join((package, enum_value.name)))
244        self._CheckConflictRegister(enum_value, full_name, file_name)
245        self._top_enum_values[full_name] = enum_value
246    self._AddFileDescriptor(enum_desc.file)
247
248  # Never call this method. It is for internal usage only.
249  def _AddServiceDescriptor(self, service_desc):
250    """Adds a ServiceDescriptor to the pool.
251
252    Args:
253      service_desc: A ServiceDescriptor.
254    """
255
256    if not isinstance(service_desc, descriptor.ServiceDescriptor):
257      raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
258
259    self._CheckConflictRegister(service_desc, service_desc.full_name,
260                                service_desc.file.name)
261    self._service_descriptors[service_desc.full_name] = service_desc
262
263  # Never call this method. It is for internal usage only.
264  def _AddExtensionDescriptor(self, extension):
265    """Adds a FieldDescriptor describing an extension to the pool.
266
267    Args:
268      extension: A FieldDescriptor.
269
270    Raises:
271      AssertionError: when another extension with the same number extends the
272        same message.
273      TypeError: when the specified extension is not a
274        descriptor.FieldDescriptor.
275    """
276    if not (isinstance(extension, descriptor.FieldDescriptor) and
277            extension.is_extension):
278      raise TypeError('Expected an extension descriptor.')
279
280    if extension.extension_scope is None:
281      self._CheckConflictRegister(
282          extension, extension.full_name, extension.file.name)
283      self._toplevel_extensions[extension.full_name] = extension
284
285    try:
286      existing_desc = self._extensions_by_number[
287          extension.containing_type][extension.number]
288    except KeyError:
289      pass
290    else:
291      if extension is not existing_desc:
292        raise AssertionError(
293            'Extensions "%s" and "%s" both try to extend message type "%s" '
294            'with field number %d.' %
295            (extension.full_name, existing_desc.full_name,
296             extension.containing_type.full_name, extension.number))
297
298    self._extensions_by_number[extension.containing_type][
299        extension.number] = extension
300    self._extensions_by_name[extension.containing_type][
301        extension.full_name] = extension
302
303    # Also register MessageSet extensions with the type name.
304    if _IsMessageSetExtension(extension):
305      self._extensions_by_name[extension.containing_type][
306          extension.message_type.full_name] = extension
307
308    if hasattr(extension.containing_type, '_concrete_class'):
309      python_message._AttachFieldHelpers(
310          extension.containing_type._concrete_class, extension)
311
312  # Never call this method. It is for internal usage only.
313  def _InternalAddFileDescriptor(self, file_desc):
314    """Adds a FileDescriptor to the pool, non-recursively.
315
316    If the FileDescriptor contains messages or enums, the caller must explicitly
317    register them.
318
319    Args:
320      file_desc: A FileDescriptor.
321    """
322
323    self._AddFileDescriptor(file_desc)
324
325  def _AddFileDescriptor(self, file_desc):
326    """Adds a FileDescriptor to the pool, non-recursively.
327
328    If the FileDescriptor contains messages or enums, the caller must explicitly
329    register them.
330
331    Args:
332      file_desc: A FileDescriptor.
333    """
334
335    if not isinstance(file_desc, descriptor.FileDescriptor):
336      raise TypeError('Expected instance of descriptor.FileDescriptor.')
337    self._file_descriptors[file_desc.name] = file_desc
338
339  def FindFileByName(self, file_name):
340    """Gets a FileDescriptor by file name.
341
342    Args:
343      file_name (str): The path to the file to get a descriptor for.
344
345    Returns:
346      FileDescriptor: The descriptor for the named file.
347
348    Raises:
349      KeyError: if the file cannot be found in the pool.
350    """
351
352    try:
353      return self._file_descriptors[file_name]
354    except KeyError:
355      pass
356
357    try:
358      file_proto = self._internal_db.FindFileByName(file_name)
359    except KeyError as error:
360      if self._descriptor_db:
361        file_proto = self._descriptor_db.FindFileByName(file_name)
362      else:
363        raise error
364    if not file_proto:
365      raise KeyError('Cannot find a file named %s' % file_name)
366    return self._ConvertFileProtoToFileDescriptor(file_proto)
367
368  def FindFileContainingSymbol(self, symbol):
369    """Gets the FileDescriptor for the file containing the specified symbol.
370
371    Args:
372      symbol (str): The name of the symbol to search for.
373
374    Returns:
375      FileDescriptor: Descriptor for the file that contains the specified
376      symbol.
377
378    Raises:
379      KeyError: if the file cannot be found in the pool.
380    """
381
382    symbol = _NormalizeFullyQualifiedName(symbol)
383    try:
384      return self._InternalFindFileContainingSymbol(symbol)
385    except KeyError:
386      pass
387
388    try:
389      # Try fallback database. Build and find again if possible.
390      self._FindFileContainingSymbolInDb(symbol)
391      return self._InternalFindFileContainingSymbol(symbol)
392    except KeyError:
393      raise KeyError('Cannot find a file containing %s' % symbol)
394
395  def _InternalFindFileContainingSymbol(self, symbol):
396    """Gets the already built FileDescriptor containing the specified symbol.
397
398    Args:
399      symbol (str): The name of the symbol to search for.
400
401    Returns:
402      FileDescriptor: Descriptor for the file that contains the specified
403      symbol.
404
405    Raises:
406      KeyError: if the file cannot be found in the pool.
407    """
408    try:
409      return self._descriptors[symbol].file
410    except KeyError:
411      pass
412
413    try:
414      return self._enum_descriptors[symbol].file
415    except KeyError:
416      pass
417
418    try:
419      return self._service_descriptors[symbol].file
420    except KeyError:
421      pass
422
423    try:
424      return self._top_enum_values[symbol].type.file
425    except KeyError:
426      pass
427
428    try:
429      return self._toplevel_extensions[symbol].file
430    except KeyError:
431      pass
432
433    # Try fields, enum values and nested extensions inside a message.
434    top_name, _, sub_name = symbol.rpartition('.')
435    try:
436      message = self.FindMessageTypeByName(top_name)
437      assert (sub_name in message.extensions_by_name or
438              sub_name in message.fields_by_name or
439              sub_name in message.enum_values_by_name)
440      return message.file
441    except (KeyError, AssertionError):
442      raise KeyError('Cannot find a file containing %s' % symbol)
443
444  def FindMessageTypeByName(self, full_name):
445    """Loads the named descriptor from the pool.
446
447    Args:
448      full_name (str): The full name of the descriptor to load.
449
450    Returns:
451      Descriptor: The descriptor for the named type.
452
453    Raises:
454      KeyError: if the message cannot be found in the pool.
455    """
456
457    full_name = _NormalizeFullyQualifiedName(full_name)
458    if full_name not in self._descriptors:
459      self._FindFileContainingSymbolInDb(full_name)
460    return self._descriptors[full_name]
461
462  def FindEnumTypeByName(self, full_name):
463    """Loads the named enum descriptor from the pool.
464
465    Args:
466      full_name (str): The full name of the enum descriptor to load.
467
468    Returns:
469      EnumDescriptor: The enum descriptor for the named type.
470
471    Raises:
472      KeyError: if the enum cannot be found in the pool.
473    """
474
475    full_name = _NormalizeFullyQualifiedName(full_name)
476    if full_name not in self._enum_descriptors:
477      self._FindFileContainingSymbolInDb(full_name)
478    return self._enum_descriptors[full_name]
479
480  def FindFieldByName(self, full_name):
481    """Loads the named field descriptor from the pool.
482
483    Args:
484      full_name (str): The full name of the field descriptor to load.
485
486    Returns:
487      FieldDescriptor: The field descriptor for the named field.
488
489    Raises:
490      KeyError: if the field cannot be found in the pool.
491    """
492    full_name = _NormalizeFullyQualifiedName(full_name)
493    message_name, _, field_name = full_name.rpartition('.')
494    message_descriptor = self.FindMessageTypeByName(message_name)
495    return message_descriptor.fields_by_name[field_name]
496
497  def FindOneofByName(self, full_name):
498    """Loads the named oneof descriptor from the pool.
499
500    Args:
501      full_name (str): The full name of the oneof descriptor to load.
502
503    Returns:
504      OneofDescriptor: The oneof descriptor for the named oneof.
505
506    Raises:
507      KeyError: if the oneof cannot be found in the pool.
508    """
509    full_name = _NormalizeFullyQualifiedName(full_name)
510    message_name, _, oneof_name = full_name.rpartition('.')
511    message_descriptor = self.FindMessageTypeByName(message_name)
512    return message_descriptor.oneofs_by_name[oneof_name]
513
514  def FindExtensionByName(self, full_name):
515    """Loads the named extension descriptor from the pool.
516
517    Args:
518      full_name (str): The full name of the extension descriptor to load.
519
520    Returns:
521      FieldDescriptor: The field descriptor for the named extension.
522
523    Raises:
524      KeyError: if the extension cannot be found in the pool.
525    """
526    full_name = _NormalizeFullyQualifiedName(full_name)
527    try:
528      # The proto compiler does not give any link between the FileDescriptor
529      # and top-level extensions unless the FileDescriptorProto is added to
530      # the DescriptorDatabase, but this can impact memory usage.
531      # So we registered these extensions by name explicitly.
532      return self._toplevel_extensions[full_name]
533    except KeyError:
534      pass
535    message_name, _, extension_name = full_name.rpartition('.')
536    try:
537      # Most extensions are nested inside a message.
538      scope = self.FindMessageTypeByName(message_name)
539    except KeyError:
540      # Some extensions are defined at file scope.
541      scope = self._FindFileContainingSymbolInDb(full_name)
542    return scope.extensions_by_name[extension_name]
543
544  def FindExtensionByNumber(self, message_descriptor, number):
545    """Gets the extension of the specified message with the specified number.
546
547    Extensions have to be registered to this pool by calling :func:`Add` or
548    :func:`AddExtensionDescriptor`.
549
550    Args:
551      message_descriptor (Descriptor): descriptor of the extended message.
552      number (int): Number of the extension field.
553
554    Returns:
555      FieldDescriptor: The descriptor for the extension.
556
557    Raises:
558      KeyError: when no extension with the given number is known for the
559        specified message.
560    """
561    try:
562      return self._extensions_by_number[message_descriptor][number]
563    except KeyError:
564      self._TryLoadExtensionFromDB(message_descriptor, number)
565      return self._extensions_by_number[message_descriptor][number]
566
567  def FindAllExtensions(self, message_descriptor):
568    """Gets all the known extensions of a given message.
569
570    Extensions have to be registered to this pool by build related
571    :func:`Add` or :func:`AddExtensionDescriptor`.
572
573    Args:
574      message_descriptor (Descriptor): Descriptor of the extended message.
575
576    Returns:
577      list[FieldDescriptor]: Field descriptors describing the extensions.
578    """
579    # Fallback to descriptor db if FindAllExtensionNumbers is provided.
580    if self._descriptor_db and hasattr(
581        self._descriptor_db, 'FindAllExtensionNumbers'):
582      full_name = message_descriptor.full_name
583      all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
584      for number in all_numbers:
585        if number in self._extensions_by_number[message_descriptor]:
586          continue
587        self._TryLoadExtensionFromDB(message_descriptor, number)
588
589    return list(self._extensions_by_number[message_descriptor].values())
590
591  def _TryLoadExtensionFromDB(self, message_descriptor, number):
592    """Try to Load extensions from descriptor db.
593
594    Args:
595      message_descriptor: descriptor of the extended message.
596      number: the extension number that needs to be loaded.
597    """
598    if not self._descriptor_db:
599      return
600    # Only supported when FindFileContainingExtension is provided.
601    if not hasattr(
602        self._descriptor_db, 'FindFileContainingExtension'):
603      return
604
605    full_name = message_descriptor.full_name
606    file_proto = self._descriptor_db.FindFileContainingExtension(
607        full_name, number)
608
609    if file_proto is None:
610      return
611
612    try:
613      self._ConvertFileProtoToFileDescriptor(file_proto)
614    except:
615      warn_msg = ('Unable to load proto file %s for extension number %d.' %
616                  (file_proto.name, number))
617      warnings.warn(warn_msg, RuntimeWarning)
618
619  def FindServiceByName(self, full_name):
620    """Loads the named service descriptor from the pool.
621
622    Args:
623      full_name (str): The full name of the service descriptor to load.
624
625    Returns:
626      ServiceDescriptor: The service descriptor for the named service.
627
628    Raises:
629      KeyError: if the service cannot be found in the pool.
630    """
631    full_name = _NormalizeFullyQualifiedName(full_name)
632    if full_name not in self._service_descriptors:
633      self._FindFileContainingSymbolInDb(full_name)
634    return self._service_descriptors[full_name]
635
636  def FindMethodByName(self, full_name):
637    """Loads the named service method descriptor from the pool.
638
639    Args:
640      full_name (str): The full name of the method descriptor to load.
641
642    Returns:
643      MethodDescriptor: The method descriptor for the service method.
644
645    Raises:
646      KeyError: if the method cannot be found in the pool.
647    """
648    full_name = _NormalizeFullyQualifiedName(full_name)
649    service_name, _, method_name = full_name.rpartition('.')
650    service_descriptor = self.FindServiceByName(service_name)
651    return service_descriptor.methods_by_name[method_name]
652
653  def SetFeatureSetDefaults(self, defaults):
654    """Sets the default feature mappings used during the build.
655
656    Args:
657      defaults: a FeatureSetDefaults message containing the new mappings.
658    """
659    if self._edition_defaults is not None:
660      raise ValueError(
661          "Feature set defaults can't be changed once the pool has started"
662          ' building!'
663      )
664
665    # pylint: disable=g-import-not-at-top
666    from google.protobuf import descriptor_pb2
667
668    if not isinstance(defaults, descriptor_pb2.FeatureSetDefaults):
669      raise TypeError('SetFeatureSetDefaults called with invalid type')
670
671
672    if defaults.minimum_edition > defaults.maximum_edition:
673      raise ValueError(
674          'Invalid edition range %s to %s'
675          % (
676              descriptor_pb2.Edition.Name(defaults.minimum_edition),
677              descriptor_pb2.Edition.Name(defaults.maximum_edition),
678          )
679      )
680
681    prev_edition = descriptor_pb2.Edition.EDITION_UNKNOWN
682    for d in defaults.defaults:
683      if d.edition == descriptor_pb2.Edition.EDITION_UNKNOWN:
684        raise ValueError('Invalid edition EDITION_UNKNOWN specified')
685      if prev_edition >= d.edition:
686        raise ValueError(
687            'Feature set defaults are not strictly increasing.  %s is greater'
688            ' than or equal to %s'
689            % (
690                descriptor_pb2.Edition.Name(prev_edition),
691                descriptor_pb2.Edition.Name(d.edition),
692            )
693        )
694      prev_edition = d.edition
695    self._edition_defaults = defaults
696
697  def _CreateDefaultFeatures(self, edition):
698    """Creates a FeatureSet message with defaults for a specific edition.
699
700    Args:
701      edition: the edition to generate defaults for.
702
703    Returns:
704      A FeatureSet message with defaults for a specific edition.
705    """
706    # pylint: disable=g-import-not-at-top
707    from google.protobuf import descriptor_pb2
708
709    with _edition_defaults_lock:
710      if not self._edition_defaults:
711        self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
712        self._edition_defaults.ParseFromString(
713            self._serialized_edition_defaults
714        )
715
716    if edition < self._edition_defaults.minimum_edition:
717      raise TypeError(
718          'Edition %s is earlier than the minimum supported edition %s!'
719          % (
720              descriptor_pb2.Edition.Name(edition),
721              descriptor_pb2.Edition.Name(
722                  self._edition_defaults.minimum_edition
723              ),
724          )
725      )
726    if edition > self._edition_defaults.maximum_edition:
727      raise TypeError(
728          'Edition %s is later than the maximum supported edition %s!'
729          % (
730              descriptor_pb2.Edition.Name(edition),
731              descriptor_pb2.Edition.Name(
732                  self._edition_defaults.maximum_edition
733              ),
734          )
735      )
736    found = None
737    for d in self._edition_defaults.defaults:
738      if d.edition > edition:
739        break
740      found = d
741    if found is None:
742      raise TypeError(
743          'No valid default found for edition %s!'
744          % descriptor_pb2.Edition.Name(edition)
745      )
746
747    defaults = descriptor_pb2.FeatureSet()
748    defaults.CopyFrom(found.fixed_features)
749    defaults.MergeFrom(found.overridable_features)
750    return defaults
751
752  def _InternFeatures(self, features):
753    serialized = features.SerializeToString()
754    with _edition_defaults_lock:
755      cached = self._feature_cache.get(serialized)
756      if cached is None:
757        self._feature_cache[serialized] = features
758        cached = features
759    return cached
760
761  def _FindFileContainingSymbolInDb(self, symbol):
762    """Finds the file in descriptor DB containing the specified symbol.
763
764    Args:
765      symbol (str): The name of the symbol to search for.
766
767    Returns:
768      FileDescriptor: The file that contains the specified symbol.
769
770    Raises:
771      KeyError: if the file cannot be found in the descriptor database.
772    """
773    try:
774      file_proto = self._internal_db.FindFileContainingSymbol(symbol)
775    except KeyError as error:
776      if self._descriptor_db:
777        file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
778      else:
779        raise error
780    if not file_proto:
781      raise KeyError('Cannot find a file containing %s' % symbol)
782    return self._ConvertFileProtoToFileDescriptor(file_proto)
783
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    The file is only built when no FileDescriptor is cached under
    file_proto.name. The extensions of the resulting file are registered with
    the pool on every call, whether or not the file was built by this call.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields the built files this one depends on -- confirm whether that
      # includes transitive dependencies.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # FindFileByName builds each direct dependency if it is not cached yet.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      # pylint: disable=g-import-not-at-top
      from google.protobuf import descriptor_pb2

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          edition=descriptor_pb2.Edition.Name(file_proto.edition),
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key,
      )
      # Symbols visible while building this file. It is seeded from the
      # dependencies below and passed to every conversion helper.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Build the file's top-level messages and index them by short name.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      # Build the file's top-level enums and index them by short name.
      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # Build the file-level extensions. Each one has its extended message
      # resolved from the scope and its field type set before being indexed.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Field types are set in a separate pass, after every message and enum
      # of the file has been converted.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Store each top-level message again, this time as resolved from the
      # scope.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      # Build the file's services and index them by short name.
      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      # Cache the finished file under its name.
      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool
    def AddExtensionForNested(message_type):
      # Depth-first: extensions of nested messages are registered before the
      # extensions declared directly in message_type.
      for nested in message_type.nested_types:
        AddExtensionForNested(nested)
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      AddExtensionForNested(message_type)

    return file_desc
886
887  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
888                                scope=None, syntax=None):
889    """Adds the proto to the pool in the specified package.
890
891    Args:
892      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
893      package: The package the proto should be located in.
894      file_desc: The file containing this message.
895      scope: Dict mapping short and full symbols to message and enum types.
896      syntax: string indicating syntax of the file ("proto2" or "proto3")
897
898    Returns:
899      The added descriptor.
900    """
901
902    if package:
903      desc_name = '.'.join((package, desc_proto.name))
904    else:
905      desc_name = desc_proto.name
906
907    if file_desc is None:
908      file_name = None
909    else:
910      file_name = file_desc.name
911
912    if scope is None:
913      scope = {}
914
915    nested = [
916        self._ConvertMessageDescriptor(
917            nested, desc_name, file_desc, scope, syntax)
918        for nested in desc_proto.nested_type]
919    enums = [
920        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
921                                    scope, False)
922        for enum in desc_proto.enum_type]
923    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
924              for index, field in enumerate(desc_proto.field)]
925    extensions = [
926        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
927                                  is_extension=True)
928        for index, extension in enumerate(desc_proto.extension)]
929    oneofs = [
930        # pylint: disable=g-complex-comprehension
931        descriptor.OneofDescriptor(
932            desc.name,
933            '.'.join((desc_name, desc.name)),
934            index,
935            None,
936            [],
937            _OptionsOrNone(desc),
938            # pylint: disable=protected-access
939            create_key=descriptor._internal_create_key)
940        for index, desc in enumerate(desc_proto.oneof_decl)
941    ]
942    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
943    if extension_ranges:
944      is_extendable = True
945    else:
946      is_extendable = False
947    desc = descriptor.Descriptor(
948        name=desc_proto.name,
949        full_name=desc_name,
950        filename=file_name,
951        containing_type=None,
952        fields=fields,
953        oneofs=oneofs,
954        nested_types=nested,
955        enum_types=enums,
956        extensions=extensions,
957        options=_OptionsOrNone(desc_proto),
958        is_extendable=is_extendable,
959        extension_ranges=extension_ranges,
960        file=file_desc,
961        serialized_start=None,
962        serialized_end=None,
963        is_map_entry=desc_proto.options.map_entry,
964        # pylint: disable=protected-access
965        create_key=descriptor._internal_create_key,
966    )
967    for nested in desc.nested_types:
968      nested.containing_type = desc
969    for enum in desc.enum_types:
970      enum.containing_type = desc
971    for field_index, field_desc in enumerate(desc_proto.field):
972      if field_desc.HasField('oneof_index'):
973        oneof_index = field_desc.oneof_index
974        oneofs[oneof_index].fields.append(fields[field_index])
975        fields[field_index].containing_oneof = oneofs[oneof_index]
976
977    scope[_PrefixWithDot(desc_name)] = desc
978    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
979    self._descriptors[desc_name] = desc
980    return desc
981
982  def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
983                             containing_type=None, scope=None, top_level=False):
984    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
985
986    Args:
987      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
988      package: Optional package name for the new message EnumDescriptor.
989      file_desc: The file containing the enum descriptor.
990      containing_type: The type containing this enum.
991      scope: Scope containing available types.
992      top_level: If True, the enum is a top level symbol. If False, the enum
993          is defined inside a message.
994
995    Returns:
996      The added descriptor
997    """
998
999    if package:
1000      enum_name = '.'.join((package, enum_proto.name))
1001    else:
1002      enum_name = enum_proto.name
1003
1004    if file_desc is None:
1005      file_name = None
1006    else:
1007      file_name = file_desc.name
1008
1009    values = [self._MakeEnumValueDescriptor(value, index)
1010              for index, value in enumerate(enum_proto.value)]
1011    desc = descriptor.EnumDescriptor(name=enum_proto.name,
1012                                     full_name=enum_name,
1013                                     filename=file_name,
1014                                     file=file_desc,
1015                                     values=values,
1016                                     containing_type=containing_type,
1017                                     options=_OptionsOrNone(enum_proto),
1018                                     # pylint: disable=protected-access
1019                                     create_key=descriptor._internal_create_key)
1020    scope['.%s' % enum_name] = desc
1021    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1022    self._enum_descriptors[enum_name] = desc
1023
1024    # Add top level enum values.
1025    if top_level:
1026      for value in values:
1027        full_name = _NormalizeFullyQualifiedName(
1028            '.'.join((package, value.name)))
1029        self._CheckConflictRegister(value, full_name, file_name)
1030        self._top_enum_values[full_name] = value
1031
1032    return desc
1033
1034  def _MakeFieldDescriptor(self, field_proto, message_name, index,
1035                           file_desc, is_extension=False):
1036    """Creates a field descriptor from a FieldDescriptorProto.
1037
1038    For message and enum type fields, this method will do a look up
1039    in the pool for the appropriate descriptor for that type. If it
1040    is unavailable, it will fall back to the _source function to
1041    create it. If this type is still unavailable, construction will
1042    fail.
1043
1044    Args:
1045      field_proto: The proto describing the field.
1046      message_name: The name of the containing message.
1047      index: Index of the field
1048      file_desc: The file containing the field descriptor.
1049      is_extension: Indication that this field is for an extension.
1050
1051    Returns:
1052      An initialized FieldDescriptor object
1053    """
1054
1055    if message_name:
1056      full_name = '.'.join((message_name, field_proto.name))
1057    else:
1058      full_name = field_proto.name
1059
1060    if field_proto.json_name:
1061      json_name = field_proto.json_name
1062    else:
1063      json_name = None
1064
1065    return descriptor.FieldDescriptor(
1066        name=field_proto.name,
1067        full_name=full_name,
1068        index=index,
1069        number=field_proto.number,
1070        type=field_proto.type,
1071        cpp_type=None,
1072        message_type=None,
1073        enum_type=None,
1074        containing_type=None,
1075        label=field_proto.label,
1076        has_default_value=False,
1077        default_value=None,
1078        is_extension=is_extension,
1079        extension_scope=None,
1080        options=_OptionsOrNone(field_proto),
1081        json_name=json_name,
1082        file=file_desc,
1083        # pylint: disable=protected-access
1084        create_key=descriptor._internal_create_key)
1085
1086  def _SetAllFieldTypes(self, package, desc_proto, scope):
1087    """Sets all the descriptor's fields's types.
1088
1089    This method also sets the containing types on any extensions.
1090
1091    Args:
1092      package: The current package of desc_proto.
1093      desc_proto: The message descriptor to update.
1094      scope: Enclosing scope of available types.
1095    """
1096
1097    package = _PrefixWithDot(package)
1098
1099    main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1100
1101    if package == '.':
1102      nested_package = _PrefixWithDot(desc_proto.name)
1103    else:
1104      nested_package = '.'.join([package, desc_proto.name])
1105
1106    for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1107      self._SetFieldType(field_proto, field_desc, nested_package, scope)
1108
1109    for extension_proto, extension_desc in (
1110        zip(desc_proto.extension, main_desc.extensions)):
1111      extension_desc.containing_type = self._GetTypeFromScope(
1112          nested_package, extension_proto.extendee, scope)
1113      self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1114
1115    for nested_type in desc_proto.nested_type:
1116      self._SetAllFieldTypes(nested_package, nested_type, scope)
1117
1118  def _SetFieldType(self, field_proto, field_desc, package, scope):
1119    """Sets the field's type, cpp_type, message_type and enum_type.
1120
1121    Args:
1122      field_proto: Data about the field in proto format.
1123      field_desc: The descriptor to modify.
1124      package: The package the field's container is in.
1125      scope: Enclosing scope of available types.
1126    """
1127    if field_proto.type_name:
1128      desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1129    else:
1130      desc = None
1131
1132    if not field_proto.HasField('type'):
1133      if isinstance(desc, descriptor.Descriptor):
1134        field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1135      else:
1136        field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1137
1138    field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1139        field_proto.type)
1140
1141    if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1142        or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1143      field_desc.message_type = desc
1144
1145    if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1146      field_desc.enum_type = desc
1147
1148    if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1149      field_desc.has_default_value = False
1150      field_desc.default_value = []
1151    elif field_proto.HasField('default_value'):
1152      field_desc.has_default_value = True
1153      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1154          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1155        field_desc.default_value = float(field_proto.default_value)
1156      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1157        field_desc.default_value = field_proto.default_value
1158      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1159        field_desc.default_value = field_proto.default_value.lower() == 'true'
1160      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1161        field_desc.default_value = field_desc.enum_type.values_by_name[
1162            field_proto.default_value].number
1163      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1164        field_desc.default_value = text_encoding.CUnescape(
1165            field_proto.default_value)
1166      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1167        field_desc.default_value = None
1168      else:
1169        # All other types are of the "int" type.
1170        field_desc.default_value = int(field_proto.default_value)
1171    else:
1172      field_desc.has_default_value = False
1173      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1174          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1175        field_desc.default_value = 0.0
1176      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1177        field_desc.default_value = u''
1178      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1179        field_desc.default_value = False
1180      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1181        field_desc.default_value = field_desc.enum_type.values[0].number
1182      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1183        field_desc.default_value = b''
1184      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1185        field_desc.default_value = None
1186      elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP:
1187        field_desc.default_value = None
1188      else:
1189        # All other types are of the "int" type.
1190        field_desc.default_value = 0
1191
1192    field_desc.type = field_proto.type
1193
1194  def _MakeEnumValueDescriptor(self, value_proto, index):
1195    """Creates a enum value descriptor object from a enum value proto.
1196
1197    Args:
1198      value_proto: The proto describing the enum value.
1199      index: The index of the enum value.
1200
1201    Returns:
1202      An initialized EnumValueDescriptor object.
1203    """
1204
1205    return descriptor.EnumValueDescriptor(
1206        name=value_proto.name,
1207        index=index,
1208        number=value_proto.number,
1209        options=_OptionsOrNone(value_proto),
1210        type=None,
1211        # pylint: disable=protected-access
1212        create_key=descriptor._internal_create_key)
1213
1214  def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1215                             package, file_desc):
1216    """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1217
1218    Args:
1219      service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1220      service_index: The index of the service in the File.
1221      scope: Dict mapping short and full symbols to message and enum types.
1222      package: Optional package name for the new message EnumDescriptor.
1223      file_desc: The file containing the service descriptor.
1224
1225    Returns:
1226      The added descriptor.
1227    """
1228
1229    if package:
1230      service_name = '.'.join((package, service_proto.name))
1231    else:
1232      service_name = service_proto.name
1233
1234    methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1235                                          scope, index)
1236               for index, method_proto in enumerate(service_proto.method)]
1237    desc = descriptor.ServiceDescriptor(
1238        name=service_proto.name,
1239        full_name=service_name,
1240        index=service_index,
1241        methods=methods,
1242        options=_OptionsOrNone(service_proto),
1243        file=file_desc,
1244        # pylint: disable=protected-access
1245        create_key=descriptor._internal_create_key)
1246    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1247    self._service_descriptors[service_name] = desc
1248    return desc
1249
1250  def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1251                            index):
1252    """Creates a method descriptor from a MethodDescriptorProto.
1253
1254    Args:
1255      method_proto: The proto describing the method.
1256      service_name: The name of the containing service.
1257      package: Optional package name to look up for types.
1258      scope: Scope containing available types.
1259      index: Index of the method in the service.
1260
1261    Returns:
1262      An initialized MethodDescriptor object.
1263    """
1264    full_name = '.'.join((service_name, method_proto.name))
1265    input_type = self._GetTypeFromScope(
1266        package, method_proto.input_type, scope)
1267    output_type = self._GetTypeFromScope(
1268        package, method_proto.output_type, scope)
1269    return descriptor.MethodDescriptor(
1270        name=method_proto.name,
1271        full_name=full_name,
1272        index=index,
1273        containing_service=None,
1274        input_type=input_type,
1275        output_type=output_type,
1276        client_streaming=method_proto.client_streaming,
1277        server_streaming=method_proto.server_streaming,
1278        options=_OptionsOrNone(method_proto),
1279        # pylint: disable=protected-access
1280        create_key=descriptor._internal_create_key)
1281
1282  def _ExtractSymbols(self, descriptors):
1283    """Pulls out all the symbols from descriptor protos.
1284
1285    Args:
1286      descriptors: The messages to extract descriptors from.
1287    Yields:
1288      A two element tuple of the type name and descriptor object.
1289    """
1290
1291    for desc in descriptors:
1292      yield (_PrefixWithDot(desc.full_name), desc)
1293      for symbol in self._ExtractSymbols(desc.nested_types):
1294        yield symbol
1295      for enum in desc.enum_types:
1296        yield (_PrefixWithDot(enum.full_name), enum)
1297
1298  def _GetDeps(self, dependencies, visited=None):
1299    """Recursively finds dependencies for file protos.
1300
1301    Args:
1302      dependencies: The names of the files being depended on.
1303      visited: The names of files already found.
1304
1305    Yields:
1306      Each direct and indirect dependency.
1307    """
1308
1309    visited = visited or set()
1310    for dependency in dependencies:
1311      if dependency not in visited:
1312        visited.add(dependency)
1313        dep_desc = self.FindFileByName(dependency)
1314        yield dep_desc
1315        public_files = [d.name for d in dep_desc.public_dependencies]
1316        yield from self._GetDeps(public_files, visited)
1317
1318  def _GetTypeFromScope(self, package, type_name, scope):
1319    """Finds a given type name in the current scope.
1320
1321    Args:
1322      package: The package the proto should be located in.
1323      type_name: The name of the type to be found in the scope.
1324      scope: Dict mapping short and full symbols to message and enum types.
1325
1326    Returns:
1327      The descriptor for the requested type.
1328    """
1329    if type_name not in scope:
1330      components = _PrefixWithDot(package).split('.')
1331      while components:
1332        possible_match = '.'.join(components + [type_name])
1333        if possible_match in scope:
1334          type_name = possible_match
1335          break
1336        else:
1337          components.pop(-1)
1338    return scope[type_name]
1339
1340
1341def _PrefixWithDot(name):
1342  return name if name.startswith('.') else '.%s' % name
1343
1344
# Module-wide default pool, returned by Default().
if not _USE_C_DESCRIPTORS:
  # Pure-Python descriptors: use an instance of this module's DescriptorPool.
  _DEFAULT = DescriptorPool()
else:
  # C descriptors: reuse the pool exposed by the descriptor module's
  # extension.
  # TODO: This pool could be constructed from Python code, when we
  # support a flag like 'use_cpp_generated_pool=True'.
  # pylint: disable=protected-access
  _DEFAULT = descriptor._message.default_pool
1352
1353
def Default():
  """Returns the module-level default pool stored in `_DEFAULT`."""
  return _DEFAULT
1356