• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Provides DescriptorPool to use as a container for proto2 descriptors.

The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
a collection of protocol buffer descriptors for use when dynamically creating
message types at runtime.

For most applications protocol buffers should be used via modules generated by
the protocol buffer compiler tool. This should only be used when the type of
protocol buffers used in an application or library cannot be predetermined.

Below is a straightforward example of how to use this class::

  pool = DescriptorPool()
  file_descriptor_protos = [ ... ]
  for file_descriptor_proto in file_descriptor_protos:
    pool.Add(file_descriptor_proto)
  my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')

The message descriptor can be used in conjunction with the message_factory
module in order to create a protocol buffer class that can be encoded and
decoded.

If you want to get a Python class for the specified proto, use the
helper functions inside google.protobuf.message_factory
directly instead of this class.
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'

import collections
import warnings

from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding


# Copied from descriptor._USE_C_DESCRIPTORS. When true, DescriptorPool defines
# a __new__ that returns descriptor._message.DescriptorPool instead of an
# instance of the pure-Python class in this module.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
69
70
def _Deprecated(func):
  """Mark functions as deprecated.

  Wraps ``func`` so that every call first emits a DeprecationWarning and then
  delegates to ``func`` with the same arguments.

  Args:
    func: The function or method to wrap.

  Returns:
    A wrapper carrying ``func``'s metadata (``__name__``, ``__qualname__``,
    ``__module__``, ``__doc__``, attribute dict) plus ``__wrapped__``.
  """
  import functools  # pylint: disable=g-import-not-at-top

  @functools.wraps(func)
  def NewFunc(*args, **kwargs):
    warnings.warn(
        'Call to deprecated function %s(). Note: Do add unlinked descriptors '
        'to descriptor_pool is wrong. Use Add() or AddSerializedFile() '
        'instead.' % func.__name__,
        category=DeprecationWarning,
        # Attribute the warning to the code calling the deprecated API rather
        # than to this wrapper, so users can find the call site.
        stacklevel=2)
    return func(*args, **kwargs)

  return NewFunc
85
86
def _NormalizeFullyQualifiedName(name):
  """Strips any leading '.' characters from a fully-qualified name.

  Because of b/13860351 in descriptor_database.py, root-namespace types can
  show up with a leading period; lookups in this module want the bare form.

  Args:
    name (str): A fully-qualified symbol name, possibly starting with '.'.

  Returns:
    str: ``name`` with its leading periods removed; everything else is kept.
  """
  stripped = name.lstrip('.')
  return stripped
100
101
def _OptionsOrNone(descriptor_proto):
  """Returns ``descriptor_proto.options`` when that field is set, else None.

  Args:
    descriptor_proto: A descriptor proto message exposing ``HasField`` and an
      ``options`` field.

  Returns:
    The ``options`` value if ``HasField('options')`` is true; otherwise None.
  """
  has_options = descriptor_proto.HasField('options')
  return descriptor_proto.options if has_options else None
108
109
def _IsMessageSetExtension(field):
  """Tells whether ``field`` is an optional message extension of a MessageSet.

  Every condition must hold: the field is an extension; the message it extends
  has options with ``message_set_wire_format`` set; the field's type is
  TYPE_MESSAGE; and its label is LABEL_OPTIONAL.

  Args:
    field: A FieldDescriptor.

  Returns:
    The last check's result when all conditions hold; otherwise the first
    failing (falsy) check result, exactly like a short-circuit ``and`` chain.
  """
  is_extension = field.is_extension
  if not is_extension:
    return is_extension

  has_options = field.containing_type.has_options
  if not has_options:
    return has_options

  uses_message_set = (
      field.containing_type.GetOptions().message_set_wire_format)
  if not uses_message_set:
    return uses_message_set

  is_message = field.type == descriptor.FieldDescriptor.TYPE_MESSAGE
  if not is_message:
    return is_message

  return field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL
116
117
118class DescriptorPool(object):
119  """A collection of protobufs dynamically constructed by descriptor protos."""
120
121  if _USE_C_DESCRIPTORS:
122
123    def __new__(cls, descriptor_db=None):
124      # pylint: disable=protected-access
125      return descriptor._message.DescriptorPool(descriptor_db)
126
127  def __init__(self, descriptor_db=None):
128    """Initializes a Pool of proto buffs.
129
130    The descriptor_db argument to the constructor is provided to allow
131    specialized file descriptor proto lookup code to be triggered on demand. An
132    example would be an implementation which will read and compile a file
133    specified in a call to FindFileByName() and not require the call to Add()
134    at all. Results from this database will be cached internally here as well.
135
136    Args:
137      descriptor_db: A secondary source of file descriptors.
138    """
139
140    self._internal_db = descriptor_database.DescriptorDatabase()
141    self._descriptor_db = descriptor_db
142    self._descriptors = {}
143    self._enum_descriptors = {}
144    self._service_descriptors = {}
145    self._file_descriptors = {}
146    self._toplevel_extensions = {}
147    # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
148    # maybe year 2020 for compatibility issue (with 3.4.1 only).
149    self._file_desc_by_toplevel_extension = {}
150    self._top_enum_values = {}
151    # We store extensions in two two-level mappings: The first key is the
152    # descriptor of the message being extended, the second key is the extension
153    # full name or its tag number.
154    self._extensions_by_name = collections.defaultdict(dict)
155    self._extensions_by_number = collections.defaultdict(dict)
156
157  def _CheckConflictRegister(self, desc, desc_name, file_name):
158    """Check if the descriptor name conflicts with another of the same name.
159
160    Args:
161      desc: Descriptor of a message, enum, service, extension or enum value.
162      desc_name (str): the full name of desc.
163      file_name (str): The file name of descriptor.
164    """
165    for register, descriptor_type in [
166        (self._descriptors, descriptor.Descriptor),
167        (self._enum_descriptors, descriptor.EnumDescriptor),
168        (self._service_descriptors, descriptor.ServiceDescriptor),
169        (self._toplevel_extensions, descriptor.FieldDescriptor),
170        (self._top_enum_values, descriptor.EnumValueDescriptor)]:
171      if desc_name in register:
172        old_desc = register[desc_name]
173        if isinstance(old_desc, descriptor.EnumValueDescriptor):
174          old_file = old_desc.type.file.name
175        else:
176          old_file = old_desc.file.name
177
178        if not isinstance(desc, descriptor_type) or (
179            old_file != file_name):
180          error_msg = ('Conflict register for file "' + file_name +
181                       '": ' + desc_name +
182                       ' is already defined in file "' +
183                       old_file + '". Please fix the conflict by adding '
184                       'package name on the proto file, or use different '
185                       'name for the duplication.')
186          if isinstance(desc, descriptor.EnumValueDescriptor):
187            error_msg += ('\nNote: enum values appear as '
188                          'siblings of the enum type instead of '
189                          'children of it.')
190
191          raise TypeError(error_msg)
192
193        return
194
195  def Add(self, file_desc_proto):
196    """Adds the FileDescriptorProto and its types to this pool.
197
198    Args:
199      file_desc_proto (FileDescriptorProto): The file descriptor to add.
200    """
201
202    self._internal_db.Add(file_desc_proto)
203
204  def AddSerializedFile(self, serialized_file_desc_proto):
205    """Adds the FileDescriptorProto and its types to this pool.
206
207    Args:
208      serialized_file_desc_proto (bytes): A bytes string, serialization of the
209        :class:`FileDescriptorProto` to add.
210    """
211
212    # pylint: disable=g-import-not-at-top
213    from google.protobuf import descriptor_pb2
214    file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
215        serialized_file_desc_proto)
216    self.Add(file_desc_proto)
217
218  # Add Descriptor to descriptor pool is dreprecated. Please use Add()
219  # or AddSerializedFile() to add a FileDescriptorProto instead.
220  @_Deprecated
221  def AddDescriptor(self, desc):
222    self._AddDescriptor(desc)
223
224  # Never call this method. It is for internal usage only.
225  def _AddDescriptor(self, desc):
226    """Adds a Descriptor to the pool, non-recursively.
227
228    If the Descriptor contains nested messages or enums, the caller must
229    explicitly register them. This method also registers the FileDescriptor
230    associated with the message.
231
232    Args:
233      desc: A Descriptor.
234    """
235    if not isinstance(desc, descriptor.Descriptor):
236      raise TypeError('Expected instance of descriptor.Descriptor.')
237
238    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
239
240    self._descriptors[desc.full_name] = desc
241    self._AddFileDescriptor(desc.file)
242
243  # Add EnumDescriptor to descriptor pool is dreprecated. Please use Add()
244  # or AddSerializedFile() to add a FileDescriptorProto instead.
245  @_Deprecated
246  def AddEnumDescriptor(self, enum_desc):
247    self._AddEnumDescriptor(enum_desc)
248
249  # Never call this method. It is for internal usage only.
250  def _AddEnumDescriptor(self, enum_desc):
251    """Adds an EnumDescriptor to the pool.
252
253    This method also registers the FileDescriptor associated with the enum.
254
255    Args:
256      enum_desc: An EnumDescriptor.
257    """
258
259    if not isinstance(enum_desc, descriptor.EnumDescriptor):
260      raise TypeError('Expected instance of descriptor.EnumDescriptor.')
261
262    file_name = enum_desc.file.name
263    self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
264    self._enum_descriptors[enum_desc.full_name] = enum_desc
265
266    # Top enum values need to be indexed.
267    # Count the number of dots to see whether the enum is toplevel or nested
268    # in a message. We cannot use enum_desc.containing_type at this stage.
269    if enum_desc.file.package:
270      top_level = (enum_desc.full_name.count('.')
271                   - enum_desc.file.package.count('.') == 1)
272    else:
273      top_level = enum_desc.full_name.count('.') == 0
274    if top_level:
275      file_name = enum_desc.file.name
276      package = enum_desc.file.package
277      for enum_value in enum_desc.values:
278        full_name = _NormalizeFullyQualifiedName(
279            '.'.join((package, enum_value.name)))
280        self._CheckConflictRegister(enum_value, full_name, file_name)
281        self._top_enum_values[full_name] = enum_value
282    self._AddFileDescriptor(enum_desc.file)
283
284  # Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add()
285  # or AddSerializedFile() to add a FileDescriptorProto instead.
286  @_Deprecated
287  def AddServiceDescriptor(self, service_desc):
288    self._AddServiceDescriptor(service_desc)
289
290  # Never call this method. It is for internal usage only.
291  def _AddServiceDescriptor(self, service_desc):
292    """Adds a ServiceDescriptor to the pool.
293
294    Args:
295      service_desc: A ServiceDescriptor.
296    """
297
298    if not isinstance(service_desc, descriptor.ServiceDescriptor):
299      raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
300
301    self._CheckConflictRegister(service_desc, service_desc.full_name,
302                                service_desc.file.name)
303    self._service_descriptors[service_desc.full_name] = service_desc
304
305  # Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add()
306  # or AddSerializedFile() to add a FileDescriptorProto instead.
307  @_Deprecated
308  def AddExtensionDescriptor(self, extension):
309    self._AddExtensionDescriptor(extension)
310
311  # Never call this method. It is for internal usage only.
312  def _AddExtensionDescriptor(self, extension):
313    """Adds a FieldDescriptor describing an extension to the pool.
314
315    Args:
316      extension: A FieldDescriptor.
317
318    Raises:
319      AssertionError: when another extension with the same number extends the
320        same message.
321      TypeError: when the specified extension is not a
322        descriptor.FieldDescriptor.
323    """
324    if not (isinstance(extension, descriptor.FieldDescriptor) and
325            extension.is_extension):
326      raise TypeError('Expected an extension descriptor.')
327
328    if extension.extension_scope is None:
329      self._toplevel_extensions[extension.full_name] = extension
330
331    try:
332      existing_desc = self._extensions_by_number[
333          extension.containing_type][extension.number]
334    except KeyError:
335      pass
336    else:
337      if extension is not existing_desc:
338        raise AssertionError(
339            'Extensions "%s" and "%s" both try to extend message type "%s" '
340            'with field number %d.' %
341            (extension.full_name, existing_desc.full_name,
342             extension.containing_type.full_name, extension.number))
343
344    self._extensions_by_number[extension.containing_type][
345        extension.number] = extension
346    self._extensions_by_name[extension.containing_type][
347        extension.full_name] = extension
348
349    # Also register MessageSet extensions with the type name.
350    if _IsMessageSetExtension(extension):
351      self._extensions_by_name[extension.containing_type][
352          extension.message_type.full_name] = extension
353
354  @_Deprecated
355  def AddFileDescriptor(self, file_desc):
356    self._InternalAddFileDescriptor(file_desc)
357
358  # Never call this method. It is for internal usage only.
359  def _InternalAddFileDescriptor(self, file_desc):
360    """Adds a FileDescriptor to the pool, non-recursively.
361
362    If the FileDescriptor contains messages or enums, the caller must explicitly
363    register them.
364
365    Args:
366      file_desc: A FileDescriptor.
367    """
368
369    self._AddFileDescriptor(file_desc)
370    # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
371    # FieldDescriptor.file is added in code gen. Remove this solution after
372    # maybe 2020 for compatibility reason (with 3.4.1 only).
373    for extension in file_desc.extensions_by_name.values():
374      self._file_desc_by_toplevel_extension[
375          extension.full_name] = file_desc
376
377  def _AddFileDescriptor(self, file_desc):
378    """Adds a FileDescriptor to the pool, non-recursively.
379
380    If the FileDescriptor contains messages or enums, the caller must explicitly
381    register them.
382
383    Args:
384      file_desc: A FileDescriptor.
385    """
386
387    if not isinstance(file_desc, descriptor.FileDescriptor):
388      raise TypeError('Expected instance of descriptor.FileDescriptor.')
389    self._file_descriptors[file_desc.name] = file_desc
390
391  def FindFileByName(self, file_name):
392    """Gets a FileDescriptor by file name.
393
394    Args:
395      file_name (str): The path to the file to get a descriptor for.
396
397    Returns:
398      FileDescriptor: The descriptor for the named file.
399
400    Raises:
401      KeyError: if the file cannot be found in the pool.
402    """
403
404    try:
405      return self._file_descriptors[file_name]
406    except KeyError:
407      pass
408
409    try:
410      file_proto = self._internal_db.FindFileByName(file_name)
411    except KeyError as error:
412      if self._descriptor_db:
413        file_proto = self._descriptor_db.FindFileByName(file_name)
414      else:
415        raise error
416    if not file_proto:
417      raise KeyError('Cannot find a file named %s' % file_name)
418    return self._ConvertFileProtoToFileDescriptor(file_proto)
419
420  def FindFileContainingSymbol(self, symbol):
421    """Gets the FileDescriptor for the file containing the specified symbol.
422
423    Args:
424      symbol (str): The name of the symbol to search for.
425
426    Returns:
427      FileDescriptor: Descriptor for the file that contains the specified
428      symbol.
429
430    Raises:
431      KeyError: if the file cannot be found in the pool.
432    """
433
434    symbol = _NormalizeFullyQualifiedName(symbol)
435    try:
436      return self._InternalFindFileContainingSymbol(symbol)
437    except KeyError:
438      pass
439
440    try:
441      # Try fallback database. Build and find again if possible.
442      self._FindFileContainingSymbolInDb(symbol)
443      return self._InternalFindFileContainingSymbol(symbol)
444    except KeyError:
445      raise KeyError('Cannot find a file containing %s' % symbol)
446
447  def _InternalFindFileContainingSymbol(self, symbol):
448    """Gets the already built FileDescriptor containing the specified symbol.
449
450    Args:
451      symbol (str): The name of the symbol to search for.
452
453    Returns:
454      FileDescriptor: Descriptor for the file that contains the specified
455      symbol.
456
457    Raises:
458      KeyError: if the file cannot be found in the pool.
459    """
460    try:
461      return self._descriptors[symbol].file
462    except KeyError:
463      pass
464
465    try:
466      return self._enum_descriptors[symbol].file
467    except KeyError:
468      pass
469
470    try:
471      return self._service_descriptors[symbol].file
472    except KeyError:
473      pass
474
475    try:
476      return self._top_enum_values[symbol].type.file
477    except KeyError:
478      pass
479
480    try:
481      return self._file_desc_by_toplevel_extension[symbol]
482    except KeyError:
483      pass
484
485    # Try fields, enum values and nested extensions inside a message.
486    top_name, _, sub_name = symbol.rpartition('.')
487    try:
488      message = self.FindMessageTypeByName(top_name)
489      assert (sub_name in message.extensions_by_name or
490              sub_name in message.fields_by_name or
491              sub_name in message.enum_values_by_name)
492      return message.file
493    except (KeyError, AssertionError):
494      raise KeyError('Cannot find a file containing %s' % symbol)
495
496  def FindMessageTypeByName(self, full_name):
497    """Loads the named descriptor from the pool.
498
499    Args:
500      full_name (str): The full name of the descriptor to load.
501
502    Returns:
503      Descriptor: The descriptor for the named type.
504
505    Raises:
506      KeyError: if the message cannot be found in the pool.
507    """
508
509    full_name = _NormalizeFullyQualifiedName(full_name)
510    if full_name not in self._descriptors:
511      self._FindFileContainingSymbolInDb(full_name)
512    return self._descriptors[full_name]
513
514  def FindEnumTypeByName(self, full_name):
515    """Loads the named enum descriptor from the pool.
516
517    Args:
518      full_name (str): The full name of the enum descriptor to load.
519
520    Returns:
521      EnumDescriptor: The enum descriptor for the named type.
522
523    Raises:
524      KeyError: if the enum cannot be found in the pool.
525    """
526
527    full_name = _NormalizeFullyQualifiedName(full_name)
528    if full_name not in self._enum_descriptors:
529      self._FindFileContainingSymbolInDb(full_name)
530    return self._enum_descriptors[full_name]
531
532  def FindFieldByName(self, full_name):
533    """Loads the named field descriptor from the pool.
534
535    Args:
536      full_name (str): The full name of the field descriptor to load.
537
538    Returns:
539      FieldDescriptor: The field descriptor for the named field.
540
541    Raises:
542      KeyError: if the field cannot be found in the pool.
543    """
544    full_name = _NormalizeFullyQualifiedName(full_name)
545    message_name, _, field_name = full_name.rpartition('.')
546    message_descriptor = self.FindMessageTypeByName(message_name)
547    return message_descriptor.fields_by_name[field_name]
548
549  def FindOneofByName(self, full_name):
550    """Loads the named oneof descriptor from the pool.
551
552    Args:
553      full_name (str): The full name of the oneof descriptor to load.
554
555    Returns:
556      OneofDescriptor: The oneof descriptor for the named oneof.
557
558    Raises:
559      KeyError: if the oneof cannot be found in the pool.
560    """
561    full_name = _NormalizeFullyQualifiedName(full_name)
562    message_name, _, oneof_name = full_name.rpartition('.')
563    message_descriptor = self.FindMessageTypeByName(message_name)
564    return message_descriptor.oneofs_by_name[oneof_name]
565
566  def FindExtensionByName(self, full_name):
567    """Loads the named extension descriptor from the pool.
568
569    Args:
570      full_name (str): The full name of the extension descriptor to load.
571
572    Returns:
573      FieldDescriptor: The field descriptor for the named extension.
574
575    Raises:
576      KeyError: if the extension cannot be found in the pool.
577    """
578    full_name = _NormalizeFullyQualifiedName(full_name)
579    try:
580      # The proto compiler does not give any link between the FileDescriptor
581      # and top-level extensions unless the FileDescriptorProto is added to
582      # the DescriptorDatabase, but this can impact memory usage.
583      # So we registered these extensions by name explicitly.
584      return self._toplevel_extensions[full_name]
585    except KeyError:
586      pass
587    message_name, _, extension_name = full_name.rpartition('.')
588    try:
589      # Most extensions are nested inside a message.
590      scope = self.FindMessageTypeByName(message_name)
591    except KeyError:
592      # Some extensions are defined at file scope.
593      scope = self._FindFileContainingSymbolInDb(full_name)
594    return scope.extensions_by_name[extension_name]
595
596  def FindExtensionByNumber(self, message_descriptor, number):
597    """Gets the extension of the specified message with the specified number.
598
599    Extensions have to be registered to this pool by calling :func:`Add` or
600    :func:`AddExtensionDescriptor`.
601
602    Args:
603      message_descriptor (Descriptor): descriptor of the extended message.
604      number (int): Number of the extension field.
605
606    Returns:
607      FieldDescriptor: The descriptor for the extension.
608
609    Raises:
610      KeyError: when no extension with the given number is known for the
611        specified message.
612    """
613    try:
614      return self._extensions_by_number[message_descriptor][number]
615    except KeyError:
616      self._TryLoadExtensionFromDB(message_descriptor, number)
617      return self._extensions_by_number[message_descriptor][number]
618
619  def FindAllExtensions(self, message_descriptor):
620    """Gets all the known extensions of a given message.
621
622    Extensions have to be registered to this pool by build related
623    :func:`Add` or :func:`AddExtensionDescriptor`.
624
625    Args:
626      message_descriptor (Descriptor): Descriptor of the extended message.
627
628    Returns:
629      list[FieldDescriptor]: Field descriptors describing the extensions.
630    """
631    # Fallback to descriptor db if FindAllExtensionNumbers is provided.
632    if self._descriptor_db and hasattr(
633        self._descriptor_db, 'FindAllExtensionNumbers'):
634      full_name = message_descriptor.full_name
635      all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
636      for number in all_numbers:
637        if number in self._extensions_by_number[message_descriptor]:
638          continue
639        self._TryLoadExtensionFromDB(message_descriptor, number)
640
641    return list(self._extensions_by_number[message_descriptor].values())
642
643  def _TryLoadExtensionFromDB(self, message_descriptor, number):
644    """Try to Load extensions from descriptor db.
645
646    Args:
647      message_descriptor: descriptor of the extended message.
648      number: the extension number that needs to be loaded.
649    """
650    if not self._descriptor_db:
651      return
652    # Only supported when FindFileContainingExtension is provided.
653    if not hasattr(
654        self._descriptor_db, 'FindFileContainingExtension'):
655      return
656
657    full_name = message_descriptor.full_name
658    file_proto = self._descriptor_db.FindFileContainingExtension(
659        full_name, number)
660
661    if file_proto is None:
662      return
663
664    try:
665      self._ConvertFileProtoToFileDescriptor(file_proto)
666    except:
667      warn_msg = ('Unable to load proto file %s for extension number %d.' %
668                  (file_proto.name, number))
669      warnings.warn(warn_msg, RuntimeWarning)
670
671  def FindServiceByName(self, full_name):
672    """Loads the named service descriptor from the pool.
673
674    Args:
675      full_name (str): The full name of the service descriptor to load.
676
677    Returns:
678      ServiceDescriptor: The service descriptor for the named service.
679
680    Raises:
681      KeyError: if the service cannot be found in the pool.
682    """
683    full_name = _NormalizeFullyQualifiedName(full_name)
684    if full_name not in self._service_descriptors:
685      self._FindFileContainingSymbolInDb(full_name)
686    return self._service_descriptors[full_name]
687
688  def FindMethodByName(self, full_name):
689    """Loads the named service method descriptor from the pool.
690
691    Args:
692      full_name (str): The full name of the method descriptor to load.
693
694    Returns:
695      MethodDescriptor: The method descriptor for the service method.
696
697    Raises:
698      KeyError: if the method cannot be found in the pool.
699    """
700    full_name = _NormalizeFullyQualifiedName(full_name)
701    service_name, _, method_name = full_name.rpartition('.')
702    service_descriptor = self.FindServiceByName(service_name)
703    return service_descriptor.methods_by_name[method_name]
704
705  def _FindFileContainingSymbolInDb(self, symbol):
706    """Finds the file in descriptor DB containing the specified symbol.
707
708    Args:
709      symbol (str): The name of the symbol to search for.
710
711    Returns:
712      FileDescriptor: The file that contains the specified symbol.
713
714    Raises:
715      KeyError: if the file cannot be found in the descriptor database.
716    """
717    try:
718      file_proto = self._internal_db.FindFileContainingSymbol(symbol)
719    except KeyError as error:
720      if self._descriptor_db:
721        file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
722      else:
723        raise error
724    if not file_proto:
725      raise KeyError('Cannot find a file containing %s' % symbol)
726    return self._ConvertFileProtoToFileDescriptor(file_proto)
727
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    This method also has the side effect of loading all the symbols found in
    the file into the appropriate dictionaries in the pool.

    The cache is keyed by ``file_proto.name`` only: if a file with that name
    was already built, the cached descriptor is returned and the contents of
    ``file_proto`` are not re-read. The extension registration at the end
    runs on every call, including cache hits.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.

    Raises:
      KeyError: propagated from FindFileByName() when a dependency cannot be
        found (other helpers called here may raise as well).
      AssertionError: propagated from _AddExtensionDescriptor() when two
        extensions claim the same number on the same message.
    """
    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; its results are
      # used below as FileDescriptors (message_types_by_name /
      # enum_types_by_name). Presumably it covers transitive dependencies;
      # confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      # Resolving each direct dependency by name builds it if it is not in
      # the pool yet.
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indices into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps,
          # pylint: disable=protected-access
          create_key=descriptor._internal_create_key)
      # Symbol table used to resolve type references while building this file.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Build this file's top-level messages (and, recursively, their nested
      # types) and enums.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope, True))

      # File-level extensions: resolve the extended message and the field's
      # type from scope, then index the extension on the file and by full
      # name.
      # NOTE(review): the extendee is resolved with the bare package
      # (file_descriptor.package), whereas the message lookup further below
      # uses a dot-prefixed package; presumably _GetTypeFromScope accepts
      # both forms. Confirm.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)
        self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
            file_descriptor)

      # Second pass over the messages, after all of this file's types are in
      # scope; presumably this is what lets fields reference types declared
      # later in the file (confirm in _SetAllFieldTypes).
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # NOTE(review): overwrites the message_types_by_name entries set above
      # with the descriptors found in scope; presumably the same objects.
      # Confirm before removing either loop.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      # Record the proto in the internal database too, then cache the built
      # descriptor by file name.
      self.Add(file_proto)
      self._file_descriptors[file_proto.name] = file_descriptor

    # Add extensions to the pool. This also runs for cached files;
    # _AddExtensionDescriptor accepts re-adding the same extension object.
    file_desc = self._file_descriptors[file_proto.name]
    for extension in file_desc.extensions_by_name.values():
      self._AddExtensionDescriptor(extension)
    for message_type in file_desc.message_types_by_name.values():
      for extension in message_type.extensions:
        self._AddExtensionDescriptor(extension)

    return file_desc
823
824  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
825                                scope=None, syntax=None):
826    """Adds the proto to the pool in the specified package.
827
828    Args:
829      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
830      package: The package the proto should be located in.
831      file_desc: The file containing this message.
832      scope: Dict mapping short and full symbols to message and enum types.
833      syntax: string indicating syntax of the file ("proto2" or "proto3")
834
835    Returns:
836      The added descriptor.
837    """
838
839    if package:
840      desc_name = '.'.join((package, desc_proto.name))
841    else:
842      desc_name = desc_proto.name
843
844    if file_desc is None:
845      file_name = None
846    else:
847      file_name = file_desc.name
848
849    if scope is None:
850      scope = {}
851
852    nested = [
853        self._ConvertMessageDescriptor(
854            nested, desc_name, file_desc, scope, syntax)
855        for nested in desc_proto.nested_type]
856    enums = [
857        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
858                                    scope, False)
859        for enum in desc_proto.enum_type]
860    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
861              for index, field in enumerate(desc_proto.field)]
862    extensions = [
863        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
864                                  is_extension=True)
865        for index, extension in enumerate(desc_proto.extension)]
866    oneofs = [
867        # pylint: disable=g-complex-comprehension
868        descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
869                                   index, None, [], desc.options,
870                                   # pylint: disable=protected-access
871                                   create_key=descriptor._internal_create_key)
872        for index, desc in enumerate(desc_proto.oneof_decl)]
873    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
874    if extension_ranges:
875      is_extendable = True
876    else:
877      is_extendable = False
878    desc = descriptor.Descriptor(
879        name=desc_proto.name,
880        full_name=desc_name,
881        filename=file_name,
882        containing_type=None,
883        fields=fields,
884        oneofs=oneofs,
885        nested_types=nested,
886        enum_types=enums,
887        extensions=extensions,
888        options=_OptionsOrNone(desc_proto),
889        is_extendable=is_extendable,
890        extension_ranges=extension_ranges,
891        file=file_desc,
892        serialized_start=None,
893        serialized_end=None,
894        syntax=syntax,
895        # pylint: disable=protected-access
896        create_key=descriptor._internal_create_key)
897    for nested in desc.nested_types:
898      nested.containing_type = desc
899    for enum in desc.enum_types:
900      enum.containing_type = desc
901    for field_index, field_desc in enumerate(desc_proto.field):
902      if field_desc.HasField('oneof_index'):
903        oneof_index = field_desc.oneof_index
904        oneofs[oneof_index].fields.append(fields[field_index])
905        fields[field_index].containing_oneof = oneofs[oneof_index]
906
907    scope[_PrefixWithDot(desc_name)] = desc
908    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
909    self._descriptors[desc_name] = desc
910    return desc
911
912  def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
913                             containing_type=None, scope=None, top_level=False):
914    """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
915
916    Args:
917      enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
918      package: Optional package name for the new message EnumDescriptor.
919      file_desc: The file containing the enum descriptor.
920      containing_type: The type containing this enum.
921      scope: Scope containing available types.
922      top_level: If True, the enum is a top level symbol. If False, the enum
923          is defined inside a message.
924
925    Returns:
926      The added descriptor
927    """
928
929    if package:
930      enum_name = '.'.join((package, enum_proto.name))
931    else:
932      enum_name = enum_proto.name
933
934    if file_desc is None:
935      file_name = None
936    else:
937      file_name = file_desc.name
938
939    values = [self._MakeEnumValueDescriptor(value, index)
940              for index, value in enumerate(enum_proto.value)]
941    desc = descriptor.EnumDescriptor(name=enum_proto.name,
942                                     full_name=enum_name,
943                                     filename=file_name,
944                                     file=file_desc,
945                                     values=values,
946                                     containing_type=containing_type,
947                                     options=_OptionsOrNone(enum_proto),
948                                     # pylint: disable=protected-access
949                                     create_key=descriptor._internal_create_key)
950    scope['.%s' % enum_name] = desc
951    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
952    self._enum_descriptors[enum_name] = desc
953
954    # Add top level enum values.
955    if top_level:
956      for value in values:
957        full_name = _NormalizeFullyQualifiedName(
958            '.'.join((package, value.name)))
959        self._CheckConflictRegister(value, full_name, file_name)
960        self._top_enum_values[full_name] = value
961
962    return desc
963
964  def _MakeFieldDescriptor(self, field_proto, message_name, index,
965                           file_desc, is_extension=False):
966    """Creates a field descriptor from a FieldDescriptorProto.
967
968    For message and enum type fields, this method will do a look up
969    in the pool for the appropriate descriptor for that type. If it
970    is unavailable, it will fall back to the _source function to
971    create it. If this type is still unavailable, construction will
972    fail.
973
974    Args:
975      field_proto: The proto describing the field.
976      message_name: The name of the containing message.
977      index: Index of the field
978      file_desc: The file containing the field descriptor.
979      is_extension: Indication that this field is for an extension.
980
981    Returns:
982      An initialized FieldDescriptor object
983    """
984
985    if message_name:
986      full_name = '.'.join((message_name, field_proto.name))
987    else:
988      full_name = field_proto.name
989
990    return descriptor.FieldDescriptor(
991        name=field_proto.name,
992        full_name=full_name,
993        index=index,
994        number=field_proto.number,
995        type=field_proto.type,
996        cpp_type=None,
997        message_type=None,
998        enum_type=None,
999        containing_type=None,
1000        label=field_proto.label,
1001        has_default_value=False,
1002        default_value=None,
1003        is_extension=is_extension,
1004        extension_scope=None,
1005        options=_OptionsOrNone(field_proto),
1006        file=file_desc,
1007        # pylint: disable=protected-access
1008        create_key=descriptor._internal_create_key)
1009
1010  def _SetAllFieldTypes(self, package, desc_proto, scope):
1011    """Sets all the descriptor's fields's types.
1012
1013    This method also sets the containing types on any extensions.
1014
1015    Args:
1016      package: The current package of desc_proto.
1017      desc_proto: The message descriptor to update.
1018      scope: Enclosing scope of available types.
1019    """
1020
1021    package = _PrefixWithDot(package)
1022
1023    main_desc = self._GetTypeFromScope(package, desc_proto.name, scope)
1024
1025    if package == '.':
1026      nested_package = _PrefixWithDot(desc_proto.name)
1027    else:
1028      nested_package = '.'.join([package, desc_proto.name])
1029
1030    for field_proto, field_desc in zip(desc_proto.field, main_desc.fields):
1031      self._SetFieldType(field_proto, field_desc, nested_package, scope)
1032
1033    for extension_proto, extension_desc in (
1034        zip(desc_proto.extension, main_desc.extensions)):
1035      extension_desc.containing_type = self._GetTypeFromScope(
1036          nested_package, extension_proto.extendee, scope)
1037      self._SetFieldType(extension_proto, extension_desc, nested_package, scope)
1038
1039    for nested_type in desc_proto.nested_type:
1040      self._SetAllFieldTypes(nested_package, nested_type, scope)
1041
1042  def _SetFieldType(self, field_proto, field_desc, package, scope):
1043    """Sets the field's type, cpp_type, message_type and enum_type.
1044
1045    Args:
1046      field_proto: Data about the field in proto format.
1047      field_desc: The descriptor to modify.
1048      package: The package the field's container is in.
1049      scope: Enclosing scope of available types.
1050    """
1051    if field_proto.type_name:
1052      desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
1053    else:
1054      desc = None
1055
1056    if not field_proto.HasField('type'):
1057      if isinstance(desc, descriptor.Descriptor):
1058        field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
1059      else:
1060        field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
1061
1062    field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
1063        field_proto.type)
1064
1065    if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
1066        or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
1067      field_desc.message_type = desc
1068
1069    if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1070      field_desc.enum_type = desc
1071
1072    if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
1073      field_desc.has_default_value = False
1074      field_desc.default_value = []
1075    elif field_proto.HasField('default_value'):
1076      field_desc.has_default_value = True
1077      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1078          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1079        field_desc.default_value = float(field_proto.default_value)
1080      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1081        field_desc.default_value = field_proto.default_value
1082      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1083        field_desc.default_value = field_proto.default_value.lower() == 'true'
1084      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1085        field_desc.default_value = field_desc.enum_type.values_by_name[
1086            field_proto.default_value].number
1087      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1088        field_desc.default_value = text_encoding.CUnescape(
1089            field_proto.default_value)
1090      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1091        field_desc.default_value = None
1092      else:
1093        # All other types are of the "int" type.
1094        field_desc.default_value = int(field_proto.default_value)
1095    else:
1096      field_desc.has_default_value = False
1097      if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
1098          field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
1099        field_desc.default_value = 0.0
1100      elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
1101        field_desc.default_value = u''
1102      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
1103        field_desc.default_value = False
1104      elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
1105        field_desc.default_value = field_desc.enum_type.values[0].number
1106      elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
1107        field_desc.default_value = b''
1108      elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
1109        field_desc.default_value = None
1110      else:
1111        # All other types are of the "int" type.
1112        field_desc.default_value = 0
1113
1114    field_desc.type = field_proto.type
1115
1116  def _MakeEnumValueDescriptor(self, value_proto, index):
1117    """Creates a enum value descriptor object from a enum value proto.
1118
1119    Args:
1120      value_proto: The proto describing the enum value.
1121      index: The index of the enum value.
1122
1123    Returns:
1124      An initialized EnumValueDescriptor object.
1125    """
1126
1127    return descriptor.EnumValueDescriptor(
1128        name=value_proto.name,
1129        index=index,
1130        number=value_proto.number,
1131        options=_OptionsOrNone(value_proto),
1132        type=None,
1133        # pylint: disable=protected-access
1134        create_key=descriptor._internal_create_key)
1135
1136  def _MakeServiceDescriptor(self, service_proto, service_index, scope,
1137                             package, file_desc):
1138    """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto.
1139
1140    Args:
1141      service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
1142      service_index: The index of the service in the File.
1143      scope: Dict mapping short and full symbols to message and enum types.
1144      package: Optional package name for the new message EnumDescriptor.
1145      file_desc: The file containing the service descriptor.
1146
1147    Returns:
1148      The added descriptor.
1149    """
1150
1151    if package:
1152      service_name = '.'.join((package, service_proto.name))
1153    else:
1154      service_name = service_proto.name
1155
1156    methods = [self._MakeMethodDescriptor(method_proto, service_name, package,
1157                                          scope, index)
1158               for index, method_proto in enumerate(service_proto.method)]
1159    desc = descriptor.ServiceDescriptor(
1160        name=service_proto.name,
1161        full_name=service_name,
1162        index=service_index,
1163        methods=methods,
1164        options=_OptionsOrNone(service_proto),
1165        file=file_desc,
1166        # pylint: disable=protected-access
1167        create_key=descriptor._internal_create_key)
1168    self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
1169    self._service_descriptors[service_name] = desc
1170    return desc
1171
1172  def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
1173                            index):
1174    """Creates a method descriptor from a MethodDescriptorProto.
1175
1176    Args:
1177      method_proto: The proto describing the method.
1178      service_name: The name of the containing service.
1179      package: Optional package name to look up for types.
1180      scope: Scope containing available types.
1181      index: Index of the method in the service.
1182
1183    Returns:
1184      An initialized MethodDescriptor object.
1185    """
1186    full_name = '.'.join((service_name, method_proto.name))
1187    input_type = self._GetTypeFromScope(
1188        package, method_proto.input_type, scope)
1189    output_type = self._GetTypeFromScope(
1190        package, method_proto.output_type, scope)
1191    return descriptor.MethodDescriptor(
1192        name=method_proto.name,
1193        full_name=full_name,
1194        index=index,
1195        containing_service=None,
1196        input_type=input_type,
1197        output_type=output_type,
1198        options=_OptionsOrNone(method_proto),
1199        # pylint: disable=protected-access
1200        create_key=descriptor._internal_create_key)
1201
1202  def _ExtractSymbols(self, descriptors):
1203    """Pulls out all the symbols from descriptor protos.
1204
1205    Args:
1206      descriptors: The messages to extract descriptors from.
1207    Yields:
1208      A two element tuple of the type name and descriptor object.
1209    """
1210
1211    for desc in descriptors:
1212      yield (_PrefixWithDot(desc.full_name), desc)
1213      for symbol in self._ExtractSymbols(desc.nested_types):
1214        yield symbol
1215      for enum in desc.enum_types:
1216        yield (_PrefixWithDot(enum.full_name), enum)
1217
1218  def _GetDeps(self, dependencies):
1219    """Recursively finds dependencies for file protos.
1220
1221    Args:
1222      dependencies: The names of the files being depended on.
1223
1224    Yields:
1225      Each direct and indirect dependency.
1226    """
1227
1228    for dependency in dependencies:
1229      dep_desc = self.FindFileByName(dependency)
1230      yield dep_desc
1231      for parent_dep in dep_desc.dependencies:
1232        yield parent_dep
1233
1234  def _GetTypeFromScope(self, package, type_name, scope):
1235    """Finds a given type name in the current scope.
1236
1237    Args:
1238      package: The package the proto should be located in.
1239      type_name: The name of the type to be found in the scope.
1240      scope: Dict mapping short and full symbols to message and enum types.
1241
1242    Returns:
1243      The descriptor for the requested type.
1244    """
1245    if type_name not in scope:
1246      components = _PrefixWithDot(package).split('.')
1247      while components:
1248        possible_match = '.'.join(components + [type_name])
1249        if possible_match in scope:
1250          type_name = possible_match
1251          break
1252        else:
1253          components.pop(-1)
1254    return scope[type_name]
1255
1256
1257def _PrefixWithDot(name):
1258  return name if name.startswith('.') else '.%s' % name
1259
1260
if not _USE_C_DESCRIPTORS:
  # Pure-Python mode: the default pool is an ordinary DescriptorPool instance.
  _DEFAULT = DescriptorPool()
else:
  # Reuse the default pool exposed by descriptor._message.
  # TODO(amauryfa): This pool could be constructed from Python code, when we
  # support a flag like 'use_cpp_generated_pool=True'.
  # pylint: disable=protected-access
  _DEFAULT = descriptor._message.default_pool
1268
1269
def Default():
  """Returns the module-level default descriptor pool.

  This is ``descriptor._message.default_pool`` when ``_USE_C_DESCRIPTORS`` is
  set, and a module-level ``DescriptorPool()`` instance otherwise.
  """
  default_pool = _DEFAULT
  return default_pool
1272