1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31"""Provides DescriptorPool to use as a container for proto2 descriptors. 32 33The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain 34a collection of protocol buffer descriptors for use when dynamically creating 35message types at runtime.
36 37For most applications protocol buffers should be used via modules generated by 38the protocol buffer compiler tool. This should only be used when the type of 39protocol buffers used in an application or library cannot be predetermined. 40 41Below is a straightforward example on how to use this class:: 42 43 pool = DescriptorPool() 44 file_descriptor_protos = [ ... ] 45 for file_descriptor_proto in file_descriptor_protos: 46 pool.Add(file_descriptor_proto) 47 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') 48 49The message descriptor can be used in conjunction with the message_factory 50module in order to create a protocol buffer class that can be encoded and 51decoded. 52 53If you want to get a Python class for the specified proto, use the 54helper functions inside google.protobuf.message_factory 55directly instead of this class. 56""" 57 58__author__ = 'matthewtoia@google.com (Matt Toia)' 59 60import collections 61import warnings 62 63from google.protobuf import descriptor 64from google.protobuf import descriptor_database 65from google.protobuf import text_encoding 66 67 68_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access 69 70 71def _Deprecated(func): 72 """Mark functions as deprecated.""" 73 74 def NewFunc(*args, **kwargs): 75 warnings.warn( 76 'Call to deprecated function %s(). Note: Do add unlinked descriptors ' 77 'to descriptor_pool is wrong. Use Add() or AddSerializedFile() ' 78 'instead.' % func.__name__, 79 category=DeprecationWarning) 80 return func(*args, **kwargs) 81 NewFunc.__name__ = func.__name__ 82 NewFunc.__doc__ = func.__doc__ 83 NewFunc.__dict__.update(func.__dict__) 84 return NewFunc 85 86 87def _NormalizeFullyQualifiedName(name): 88 """Remove leading period from fully-qualified type name. 89 90 Due to b/13860351 in descriptor_database.py, types in the root namespace are 91 generated with a leading period. This function removes that prefix. 
92 93 Args: 94 name (str): The fully-qualified symbol name. 95 96 Returns: 97 str: The normalized fully-qualified symbol name. 98 """ 99 return name.lstrip('.') 100 101 102def _OptionsOrNone(descriptor_proto): 103 """Returns the value of the field `options`, or None if it is not set.""" 104 if descriptor_proto.HasField('options'): 105 return descriptor_proto.options 106 else: 107 return None 108 109 110def _IsMessageSetExtension(field): 111 return (field.is_extension and 112 field.containing_type.has_options and 113 field.containing_type.GetOptions().message_set_wire_format and 114 field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 115 field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) 116 117 118class DescriptorPool(object): 119 """A collection of protobufs dynamically constructed by descriptor protos.""" 120 121 if _USE_C_DESCRIPTORS: 122 123 def __new__(cls, descriptor_db=None): 124 # pylint: disable=protected-access 125 return descriptor._message.DescriptorPool(descriptor_db) 126 127 def __init__(self, descriptor_db=None): 128 """Initializes a Pool of proto buffs. 129 130 The descriptor_db argument to the constructor is provided to allow 131 specialized file descriptor proto lookup code to be triggered on demand. An 132 example would be an implementation which will read and compile a file 133 specified in a call to FindFileByName() and not require the call to Add() 134 at all. Results from this database will be cached internally here as well. 135 136 Args: 137 descriptor_db: A secondary source of file descriptors. 138 """ 139 140 self._internal_db = descriptor_database.DescriptorDatabase() 141 self._descriptor_db = descriptor_db 142 self._descriptors = {} 143 self._enum_descriptors = {} 144 self._service_descriptors = {} 145 self._file_descriptors = {} 146 self._toplevel_extensions = {} 147 # TODO(jieluo): Remove _file_desc_by_toplevel_extension after 148 # maybe year 2020 for compatibility issue (with 3.4.1 only). 
149 self._file_desc_by_toplevel_extension = {} 150 self._top_enum_values = {} 151 # We store extensions in two two-level mappings: The first key is the 152 # descriptor of the message being extended, the second key is the extension 153 # full name or its tag number. 154 self._extensions_by_name = collections.defaultdict(dict) 155 self._extensions_by_number = collections.defaultdict(dict) 156 157 def _CheckConflictRegister(self, desc, desc_name, file_name): 158 """Check if the descriptor name conflicts with another of the same name. 159 160 Args: 161 desc: Descriptor of a message, enum, service, extension or enum value. 162 desc_name (str): the full name of desc. 163 file_name (str): The file name of descriptor. 164 """ 165 for register, descriptor_type in [ 166 (self._descriptors, descriptor.Descriptor), 167 (self._enum_descriptors, descriptor.EnumDescriptor), 168 (self._service_descriptors, descriptor.ServiceDescriptor), 169 (self._toplevel_extensions, descriptor.FieldDescriptor), 170 (self._top_enum_values, descriptor.EnumValueDescriptor)]: 171 if desc_name in register: 172 old_desc = register[desc_name] 173 if isinstance(old_desc, descriptor.EnumValueDescriptor): 174 old_file = old_desc.type.file.name 175 else: 176 old_file = old_desc.file.name 177 178 if not isinstance(desc, descriptor_type) or ( 179 old_file != file_name): 180 error_msg = ('Conflict register for file "' + file_name + 181 '": ' + desc_name + 182 ' is already defined in file "' + 183 old_file + '". Please fix the conflict by adding ' 184 'package name on the proto file, or use different ' 185 'name for the duplication.') 186 if isinstance(desc, descriptor.EnumValueDescriptor): 187 error_msg += ('\nNote: enum values appear as ' 188 'siblings of the enum type instead of ' 189 'children of it.') 190 191 raise TypeError(error_msg) 192 193 return 194 195 def Add(self, file_desc_proto): 196 """Adds the FileDescriptorProto and its types to this pool. 
197 198 Args: 199 file_desc_proto (FileDescriptorProto): The file descriptor to add. 200 """ 201 202 self._internal_db.Add(file_desc_proto) 203 204 def AddSerializedFile(self, serialized_file_desc_proto): 205 """Adds the FileDescriptorProto and its types to this pool. 206 207 Args: 208 serialized_file_desc_proto (bytes): A bytes string, serialization of the 209 :class:`FileDescriptorProto` to add. 210 """ 211 212 # pylint: disable=g-import-not-at-top 213 from google.protobuf import descriptor_pb2 214 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( 215 serialized_file_desc_proto) 216 self.Add(file_desc_proto) 217 218 # Add Descriptor to descriptor pool is dreprecated. Please use Add() 219 # or AddSerializedFile() to add a FileDescriptorProto instead. 220 @_Deprecated 221 def AddDescriptor(self, desc): 222 self._AddDescriptor(desc) 223 224 # Never call this method. It is for internal usage only. 225 def _AddDescriptor(self, desc): 226 """Adds a Descriptor to the pool, non-recursively. 227 228 If the Descriptor contains nested messages or enums, the caller must 229 explicitly register them. This method also registers the FileDescriptor 230 associated with the message. 231 232 Args: 233 desc: A Descriptor. 234 """ 235 if not isinstance(desc, descriptor.Descriptor): 236 raise TypeError('Expected instance of descriptor.Descriptor.') 237 238 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 239 240 self._descriptors[desc.full_name] = desc 241 self._AddFileDescriptor(desc.file) 242 243 # Add EnumDescriptor to descriptor pool is dreprecated. Please use Add() 244 # or AddSerializedFile() to add a FileDescriptorProto instead. 245 @_Deprecated 246 def AddEnumDescriptor(self, enum_desc): 247 self._AddEnumDescriptor(enum_desc) 248 249 # Never call this method. It is for internal usage only. 250 def _AddEnumDescriptor(self, enum_desc): 251 """Adds an EnumDescriptor to the pool. 
252 253 This method also registers the FileDescriptor associated with the enum. 254 255 Args: 256 enum_desc: An EnumDescriptor. 257 """ 258 259 if not isinstance(enum_desc, descriptor.EnumDescriptor): 260 raise TypeError('Expected instance of descriptor.EnumDescriptor.') 261 262 file_name = enum_desc.file.name 263 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) 264 self._enum_descriptors[enum_desc.full_name] = enum_desc 265 266 # Top enum values need to be indexed. 267 # Count the number of dots to see whether the enum is toplevel or nested 268 # in a message. We cannot use enum_desc.containing_type at this stage. 269 if enum_desc.file.package: 270 top_level = (enum_desc.full_name.count('.') 271 - enum_desc.file.package.count('.') == 1) 272 else: 273 top_level = enum_desc.full_name.count('.') == 0 274 if top_level: 275 file_name = enum_desc.file.name 276 package = enum_desc.file.package 277 for enum_value in enum_desc.values: 278 full_name = _NormalizeFullyQualifiedName( 279 '.'.join((package, enum_value.name))) 280 self._CheckConflictRegister(enum_value, full_name, file_name) 281 self._top_enum_values[full_name] = enum_value 282 self._AddFileDescriptor(enum_desc.file) 283 284 # Add ServiceDescriptor to descriptor pool is dreprecated. Please use Add() 285 # or AddSerializedFile() to add a FileDescriptorProto instead. 286 @_Deprecated 287 def AddServiceDescriptor(self, service_desc): 288 self._AddServiceDescriptor(service_desc) 289 290 # Never call this method. It is for internal usage only. 291 def _AddServiceDescriptor(self, service_desc): 292 """Adds a ServiceDescriptor to the pool. 293 294 Args: 295 service_desc: A ServiceDescriptor. 
296 """ 297 298 if not isinstance(service_desc, descriptor.ServiceDescriptor): 299 raise TypeError('Expected instance of descriptor.ServiceDescriptor.') 300 301 self._CheckConflictRegister(service_desc, service_desc.full_name, 302 service_desc.file.name) 303 self._service_descriptors[service_desc.full_name] = service_desc 304 305 # Add ExtensionDescriptor to descriptor pool is dreprecated. Please use Add() 306 # or AddSerializedFile() to add a FileDescriptorProto instead. 307 @_Deprecated 308 def AddExtensionDescriptor(self, extension): 309 self._AddExtensionDescriptor(extension) 310 311 # Never call this method. It is for internal usage only. 312 def _AddExtensionDescriptor(self, extension): 313 """Adds a FieldDescriptor describing an extension to the pool. 314 315 Args: 316 extension: A FieldDescriptor. 317 318 Raises: 319 AssertionError: when another extension with the same number extends the 320 same message. 321 TypeError: when the specified extension is not a 322 descriptor.FieldDescriptor. 323 """ 324 if not (isinstance(extension, descriptor.FieldDescriptor) and 325 extension.is_extension): 326 raise TypeError('Expected an extension descriptor.') 327 328 if extension.extension_scope is None: 329 self._toplevel_extensions[extension.full_name] = extension 330 331 try: 332 existing_desc = self._extensions_by_number[ 333 extension.containing_type][extension.number] 334 except KeyError: 335 pass 336 else: 337 if extension is not existing_desc: 338 raise AssertionError( 339 'Extensions "%s" and "%s" both try to extend message type "%s" ' 340 'with field number %d.' % 341 (extension.full_name, existing_desc.full_name, 342 extension.containing_type.full_name, extension.number)) 343 344 self._extensions_by_number[extension.containing_type][ 345 extension.number] = extension 346 self._extensions_by_name[extension.containing_type][ 347 extension.full_name] = extension 348 349 # Also register MessageSet extensions with the type name. 
350 if _IsMessageSetExtension(extension): 351 self._extensions_by_name[extension.containing_type][ 352 extension.message_type.full_name] = extension 353 354 @_Deprecated 355 def AddFileDescriptor(self, file_desc): 356 self._InternalAddFileDescriptor(file_desc) 357 358 # Never call this method. It is for internal usage only. 359 def _InternalAddFileDescriptor(self, file_desc): 360 """Adds a FileDescriptor to the pool, non-recursively. 361 362 If the FileDescriptor contains messages or enums, the caller must explicitly 363 register them. 364 365 Args: 366 file_desc: A FileDescriptor. 367 """ 368 369 self._AddFileDescriptor(file_desc) 370 # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. 371 # FieldDescriptor.file is added in code gen. Remove this solution after 372 # maybe 2020 for compatibility reason (with 3.4.1 only). 373 for extension in file_desc.extensions_by_name.values(): 374 self._file_desc_by_toplevel_extension[ 375 extension.full_name] = file_desc 376 377 def _AddFileDescriptor(self, file_desc): 378 """Adds a FileDescriptor to the pool, non-recursively. 379 380 If the FileDescriptor contains messages or enums, the caller must explicitly 381 register them. 382 383 Args: 384 file_desc: A FileDescriptor. 385 """ 386 387 if not isinstance(file_desc, descriptor.FileDescriptor): 388 raise TypeError('Expected instance of descriptor.FileDescriptor.') 389 self._file_descriptors[file_desc.name] = file_desc 390 391 def FindFileByName(self, file_name): 392 """Gets a FileDescriptor by file name. 393 394 Args: 395 file_name (str): The path to the file to get a descriptor for. 396 397 Returns: 398 FileDescriptor: The descriptor for the named file. 399 400 Raises: 401 KeyError: if the file cannot be found in the pool. 
402 """ 403 404 try: 405 return self._file_descriptors[file_name] 406 except KeyError: 407 pass 408 409 try: 410 file_proto = self._internal_db.FindFileByName(file_name) 411 except KeyError as error: 412 if self._descriptor_db: 413 file_proto = self._descriptor_db.FindFileByName(file_name) 414 else: 415 raise error 416 if not file_proto: 417 raise KeyError('Cannot find a file named %s' % file_name) 418 return self._ConvertFileProtoToFileDescriptor(file_proto) 419 420 def FindFileContainingSymbol(self, symbol): 421 """Gets the FileDescriptor for the file containing the specified symbol. 422 423 Args: 424 symbol (str): The name of the symbol to search for. 425 426 Returns: 427 FileDescriptor: Descriptor for the file that contains the specified 428 symbol. 429 430 Raises: 431 KeyError: if the file cannot be found in the pool. 432 """ 433 434 symbol = _NormalizeFullyQualifiedName(symbol) 435 try: 436 return self._InternalFindFileContainingSymbol(symbol) 437 except KeyError: 438 pass 439 440 try: 441 # Try fallback database. Build and find again if possible. 442 self._FindFileContainingSymbolInDb(symbol) 443 return self._InternalFindFileContainingSymbol(symbol) 444 except KeyError: 445 raise KeyError('Cannot find a file containing %s' % symbol) 446 447 def _InternalFindFileContainingSymbol(self, symbol): 448 """Gets the already built FileDescriptor containing the specified symbol. 449 450 Args: 451 symbol (str): The name of the symbol to search for. 452 453 Returns: 454 FileDescriptor: Descriptor for the file that contains the specified 455 symbol. 456 457 Raises: 458 KeyError: if the file cannot be found in the pool. 
459 """ 460 try: 461 return self._descriptors[symbol].file 462 except KeyError: 463 pass 464 465 try: 466 return self._enum_descriptors[symbol].file 467 except KeyError: 468 pass 469 470 try: 471 return self._service_descriptors[symbol].file 472 except KeyError: 473 pass 474 475 try: 476 return self._top_enum_values[symbol].type.file 477 except KeyError: 478 pass 479 480 try: 481 return self._file_desc_by_toplevel_extension[symbol] 482 except KeyError: 483 pass 484 485 # Try fields, enum values and nested extensions inside a message. 486 top_name, _, sub_name = symbol.rpartition('.') 487 try: 488 message = self.FindMessageTypeByName(top_name) 489 assert (sub_name in message.extensions_by_name or 490 sub_name in message.fields_by_name or 491 sub_name in message.enum_values_by_name) 492 return message.file 493 except (KeyError, AssertionError): 494 raise KeyError('Cannot find a file containing %s' % symbol) 495 496 def FindMessageTypeByName(self, full_name): 497 """Loads the named descriptor from the pool. 498 499 Args: 500 full_name (str): The full name of the descriptor to load. 501 502 Returns: 503 Descriptor: The descriptor for the named type. 504 505 Raises: 506 KeyError: if the message cannot be found in the pool. 507 """ 508 509 full_name = _NormalizeFullyQualifiedName(full_name) 510 if full_name not in self._descriptors: 511 self._FindFileContainingSymbolInDb(full_name) 512 return self._descriptors[full_name] 513 514 def FindEnumTypeByName(self, full_name): 515 """Loads the named enum descriptor from the pool. 516 517 Args: 518 full_name (str): The full name of the enum descriptor to load. 519 520 Returns: 521 EnumDescriptor: The enum descriptor for the named type. 522 523 Raises: 524 KeyError: if the enum cannot be found in the pool. 
525 """ 526 527 full_name = _NormalizeFullyQualifiedName(full_name) 528 if full_name not in self._enum_descriptors: 529 self._FindFileContainingSymbolInDb(full_name) 530 return self._enum_descriptors[full_name] 531 532 def FindFieldByName(self, full_name): 533 """Loads the named field descriptor from the pool. 534 535 Args: 536 full_name (str): The full name of the field descriptor to load. 537 538 Returns: 539 FieldDescriptor: The field descriptor for the named field. 540 541 Raises: 542 KeyError: if the field cannot be found in the pool. 543 """ 544 full_name = _NormalizeFullyQualifiedName(full_name) 545 message_name, _, field_name = full_name.rpartition('.') 546 message_descriptor = self.FindMessageTypeByName(message_name) 547 return message_descriptor.fields_by_name[field_name] 548 549 def FindOneofByName(self, full_name): 550 """Loads the named oneof descriptor from the pool. 551 552 Args: 553 full_name (str): The full name of the oneof descriptor to load. 554 555 Returns: 556 OneofDescriptor: The oneof descriptor for the named oneof. 557 558 Raises: 559 KeyError: if the oneof cannot be found in the pool. 560 """ 561 full_name = _NormalizeFullyQualifiedName(full_name) 562 message_name, _, oneof_name = full_name.rpartition('.') 563 message_descriptor = self.FindMessageTypeByName(message_name) 564 return message_descriptor.oneofs_by_name[oneof_name] 565 566 def FindExtensionByName(self, full_name): 567 """Loads the named extension descriptor from the pool. 568 569 Args: 570 full_name (str): The full name of the extension descriptor to load. 571 572 Returns: 573 FieldDescriptor: The field descriptor for the named extension. 574 575 Raises: 576 KeyError: if the extension cannot be found in the pool. 
577 """ 578 full_name = _NormalizeFullyQualifiedName(full_name) 579 try: 580 # The proto compiler does not give any link between the FileDescriptor 581 # and top-level extensions unless the FileDescriptorProto is added to 582 # the DescriptorDatabase, but this can impact memory usage. 583 # So we registered these extensions by name explicitly. 584 return self._toplevel_extensions[full_name] 585 except KeyError: 586 pass 587 message_name, _, extension_name = full_name.rpartition('.') 588 try: 589 # Most extensions are nested inside a message. 590 scope = self.FindMessageTypeByName(message_name) 591 except KeyError: 592 # Some extensions are defined at file scope. 593 scope = self._FindFileContainingSymbolInDb(full_name) 594 return scope.extensions_by_name[extension_name] 595 596 def FindExtensionByNumber(self, message_descriptor, number): 597 """Gets the extension of the specified message with the specified number. 598 599 Extensions have to be registered to this pool by calling :func:`Add` or 600 :func:`AddExtensionDescriptor`. 601 602 Args: 603 message_descriptor (Descriptor): descriptor of the extended message. 604 number (int): Number of the extension field. 605 606 Returns: 607 FieldDescriptor: The descriptor for the extension. 608 609 Raises: 610 KeyError: when no extension with the given number is known for the 611 specified message. 612 """ 613 try: 614 return self._extensions_by_number[message_descriptor][number] 615 except KeyError: 616 self._TryLoadExtensionFromDB(message_descriptor, number) 617 return self._extensions_by_number[message_descriptor][number] 618 619 def FindAllExtensions(self, message_descriptor): 620 """Gets all the known extensions of a given message. 621 622 Extensions have to be registered to this pool by build related 623 :func:`Add` or :func:`AddExtensionDescriptor`. 624 625 Args: 626 message_descriptor (Descriptor): Descriptor of the extended message. 
627 628 Returns: 629 list[FieldDescriptor]: Field descriptors describing the extensions. 630 """ 631 # Fallback to descriptor db if FindAllExtensionNumbers is provided. 632 if self._descriptor_db and hasattr( 633 self._descriptor_db, 'FindAllExtensionNumbers'): 634 full_name = message_descriptor.full_name 635 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) 636 for number in all_numbers: 637 if number in self._extensions_by_number[message_descriptor]: 638 continue 639 self._TryLoadExtensionFromDB(message_descriptor, number) 640 641 return list(self._extensions_by_number[message_descriptor].values()) 642 643 def _TryLoadExtensionFromDB(self, message_descriptor, number): 644 """Try to Load extensions from descriptor db. 645 646 Args: 647 message_descriptor: descriptor of the extended message. 648 number: the extension number that needs to be loaded. 649 """ 650 if not self._descriptor_db: 651 return 652 # Only supported when FindFileContainingExtension is provided. 653 if not hasattr( 654 self._descriptor_db, 'FindFileContainingExtension'): 655 return 656 657 full_name = message_descriptor.full_name 658 file_proto = self._descriptor_db.FindFileContainingExtension( 659 full_name, number) 660 661 if file_proto is None: 662 return 663 664 try: 665 self._ConvertFileProtoToFileDescriptor(file_proto) 666 except: 667 warn_msg = ('Unable to load proto file %s for extension number %d.' % 668 (file_proto.name, number)) 669 warnings.warn(warn_msg, RuntimeWarning) 670 671 def FindServiceByName(self, full_name): 672 """Loads the named service descriptor from the pool. 673 674 Args: 675 full_name (str): The full name of the service descriptor to load. 676 677 Returns: 678 ServiceDescriptor: The service descriptor for the named service. 679 680 Raises: 681 KeyError: if the service cannot be found in the pool. 
682 """ 683 full_name = _NormalizeFullyQualifiedName(full_name) 684 if full_name not in self._service_descriptors: 685 self._FindFileContainingSymbolInDb(full_name) 686 return self._service_descriptors[full_name] 687 688 def FindMethodByName(self, full_name): 689 """Loads the named service method descriptor from the pool. 690 691 Args: 692 full_name (str): The full name of the method descriptor to load. 693 694 Returns: 695 MethodDescriptor: The method descriptor for the service method. 696 697 Raises: 698 KeyError: if the method cannot be found in the pool. 699 """ 700 full_name = _NormalizeFullyQualifiedName(full_name) 701 service_name, _, method_name = full_name.rpartition('.') 702 service_descriptor = self.FindServiceByName(service_name) 703 return service_descriptor.methods_by_name[method_name] 704 705 def _FindFileContainingSymbolInDb(self, symbol): 706 """Finds the file in descriptor DB containing the specified symbol. 707 708 Args: 709 symbol (str): The name of the symbol to search for. 710 711 Returns: 712 FileDescriptor: The file that contains the specified symbol. 713 714 Raises: 715 KeyError: if the file cannot be found in the descriptor database. 716 """ 717 try: 718 file_proto = self._internal_db.FindFileContainingSymbol(symbol) 719 except KeyError as error: 720 if self._descriptor_db: 721 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) 722 else: 723 raise error 724 if not file_proto: 725 raise KeyError('Cannot find a file containing %s' % symbol) 726 return self._ConvertFileProtoToFileDescriptor(file_proto) 727 728 def _ConvertFileProtoToFileDescriptor(self, file_proto): 729 """Creates a FileDescriptor from a proto or returns a cached copy. 730 731 This method also has the side effect of loading all the symbols found in 732 the file into the appropriate dictionaries in the pool. 733 734 Args: 735 file_proto: The proto to convert. 736 737 Returns: 738 A FileDescriptor matching the passed in proto. 
739 """ 740 if file_proto.name not in self._file_descriptors: 741 built_deps = list(self._GetDeps(file_proto.dependency)) 742 direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] 743 public_deps = [direct_deps[i] for i in file_proto.public_dependency] 744 745 file_descriptor = descriptor.FileDescriptor( 746 pool=self, 747 name=file_proto.name, 748 package=file_proto.package, 749 syntax=file_proto.syntax, 750 options=_OptionsOrNone(file_proto), 751 serialized_pb=file_proto.SerializeToString(), 752 dependencies=direct_deps, 753 public_dependencies=public_deps, 754 # pylint: disable=protected-access 755 create_key=descriptor._internal_create_key) 756 scope = {} 757 758 # This loop extracts all the message and enum types from all the 759 # dependencies of the file_proto. This is necessary to create the 760 # scope of available message types when defining the passed in 761 # file proto. 762 for dependency in built_deps: 763 scope.update(self._ExtractSymbols( 764 dependency.message_types_by_name.values())) 765 scope.update((_PrefixWithDot(enum.full_name), enum) 766 for enum in dependency.enum_types_by_name.values()) 767 768 for message_type in file_proto.message_type: 769 message_desc = self._ConvertMessageDescriptor( 770 message_type, file_proto.package, file_descriptor, scope, 771 file_proto.syntax) 772 file_descriptor.message_types_by_name[message_desc.name] = ( 773 message_desc) 774 775 for enum_type in file_proto.enum_type: 776 file_descriptor.enum_types_by_name[enum_type.name] = ( 777 self._ConvertEnumDescriptor(enum_type, file_proto.package, 778 file_descriptor, None, scope, True)) 779 780 for index, extension_proto in enumerate(file_proto.extension): 781 extension_desc = self._MakeFieldDescriptor( 782 extension_proto, file_proto.package, index, file_descriptor, 783 is_extension=True) 784 extension_desc.containing_type = self._GetTypeFromScope( 785 file_descriptor.package, extension_proto.extendee, scope) 786 self._SetFieldType(extension_proto, 
extension_desc, 787 file_descriptor.package, scope) 788 file_descriptor.extensions_by_name[extension_desc.name] = ( 789 extension_desc) 790 self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( 791 file_descriptor) 792 793 for desc_proto in file_proto.message_type: 794 self._SetAllFieldTypes(file_proto.package, desc_proto, scope) 795 796 if file_proto.package: 797 desc_proto_prefix = _PrefixWithDot(file_proto.package) 798 else: 799 desc_proto_prefix = '' 800 801 for desc_proto in file_proto.message_type: 802 desc = self._GetTypeFromScope( 803 desc_proto_prefix, desc_proto.name, scope) 804 file_descriptor.message_types_by_name[desc_proto.name] = desc 805 806 for index, service_proto in enumerate(file_proto.service): 807 file_descriptor.services_by_name[service_proto.name] = ( 808 self._MakeServiceDescriptor(service_proto, index, scope, 809 file_proto.package, file_descriptor)) 810 811 self.Add(file_proto) 812 self._file_descriptors[file_proto.name] = file_descriptor 813 814 # Add extensions to the pool 815 file_desc = self._file_descriptors[file_proto.name] 816 for extension in file_desc.extensions_by_name.values(): 817 self._AddExtensionDescriptor(extension) 818 for message_type in file_desc.message_types_by_name.values(): 819 for extension in message_type.extensions: 820 self._AddExtensionDescriptor(extension) 821 822 return file_desc 823 824 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, 825 scope=None, syntax=None): 826 """Adds the proto to the pool in the specified package. 827 828 Args: 829 desc_proto: The descriptor_pb2.DescriptorProto protobuf message. 830 package: The package the proto should be located in. 831 file_desc: The file containing this message. 832 scope: Dict mapping short and full symbols to message and enum types. 833 syntax: string indicating syntax of the file ("proto2" or "proto3") 834 835 Returns: 836 The added descriptor. 
837 """ 838 839 if package: 840 desc_name = '.'.join((package, desc_proto.name)) 841 else: 842 desc_name = desc_proto.name 843 844 if file_desc is None: 845 file_name = None 846 else: 847 file_name = file_desc.name 848 849 if scope is None: 850 scope = {} 851 852 nested = [ 853 self._ConvertMessageDescriptor( 854 nested, desc_name, file_desc, scope, syntax) 855 for nested in desc_proto.nested_type] 856 enums = [ 857 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, 858 scope, False) 859 for enum in desc_proto.enum_type] 860 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) 861 for index, field in enumerate(desc_proto.field)] 862 extensions = [ 863 self._MakeFieldDescriptor(extension, desc_name, index, file_desc, 864 is_extension=True) 865 for index, extension in enumerate(desc_proto.extension)] 866 oneofs = [ 867 # pylint: disable=g-complex-comprehension 868 descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), 869 index, None, [], desc.options, 870 # pylint: disable=protected-access 871 create_key=descriptor._internal_create_key) 872 for index, desc in enumerate(desc_proto.oneof_decl)] 873 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] 874 if extension_ranges: 875 is_extendable = True 876 else: 877 is_extendable = False 878 desc = descriptor.Descriptor( 879 name=desc_proto.name, 880 full_name=desc_name, 881 filename=file_name, 882 containing_type=None, 883 fields=fields, 884 oneofs=oneofs, 885 nested_types=nested, 886 enum_types=enums, 887 extensions=extensions, 888 options=_OptionsOrNone(desc_proto), 889 is_extendable=is_extendable, 890 extension_ranges=extension_ranges, 891 file=file_desc, 892 serialized_start=None, 893 serialized_end=None, 894 syntax=syntax, 895 # pylint: disable=protected-access 896 create_key=descriptor._internal_create_key) 897 for nested in desc.nested_types: 898 nested.containing_type = desc 899 for enum in desc.enum_types: 900 enum.containing_type = desc 901 
for field_index, field_desc in enumerate(desc_proto.field): 902 if field_desc.HasField('oneof_index'): 903 oneof_index = field_desc.oneof_index 904 oneofs[oneof_index].fields.append(fields[field_index]) 905 fields[field_index].containing_oneof = oneofs[oneof_index] 906 907 scope[_PrefixWithDot(desc_name)] = desc 908 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 909 self._descriptors[desc_name] = desc 910 return desc 911 912 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, 913 containing_type=None, scope=None, top_level=False): 914 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. 915 916 Args: 917 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. 918 package: Optional package name for the new message EnumDescriptor. 919 file_desc: The file containing the enum descriptor. 920 containing_type: The type containing this enum. 921 scope: Scope containing available types. 922 top_level: If True, the enum is a top level symbol. If False, the enum 923 is defined inside a message. 924 925 Returns: 926 The added descriptor 927 """ 928 929 if package: 930 enum_name = '.'.join((package, enum_proto.name)) 931 else: 932 enum_name = enum_proto.name 933 934 if file_desc is None: 935 file_name = None 936 else: 937 file_name = file_desc.name 938 939 values = [self._MakeEnumValueDescriptor(value, index) 940 for index, value in enumerate(enum_proto.value)] 941 desc = descriptor.EnumDescriptor(name=enum_proto.name, 942 full_name=enum_name, 943 filename=file_name, 944 file=file_desc, 945 values=values, 946 containing_type=containing_type, 947 options=_OptionsOrNone(enum_proto), 948 # pylint: disable=protected-access 949 create_key=descriptor._internal_create_key) 950 scope['.%s' % enum_name] = desc 951 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 952 self._enum_descriptors[enum_name] = desc 953 954 # Add top level enum values. 
955 if top_level: 956 for value in values: 957 full_name = _NormalizeFullyQualifiedName( 958 '.'.join((package, value.name))) 959 self._CheckConflictRegister(value, full_name, file_name) 960 self._top_enum_values[full_name] = value 961 962 return desc 963 964 def _MakeFieldDescriptor(self, field_proto, message_name, index, 965 file_desc, is_extension=False): 966 """Creates a field descriptor from a FieldDescriptorProto. 967 968 For message and enum type fields, this method will do a look up 969 in the pool for the appropriate descriptor for that type. If it 970 is unavailable, it will fall back to the _source function to 971 create it. If this type is still unavailable, construction will 972 fail. 973 974 Args: 975 field_proto: The proto describing the field. 976 message_name: The name of the containing message. 977 index: Index of the field 978 file_desc: The file containing the field descriptor. 979 is_extension: Indication that this field is for an extension. 980 981 Returns: 982 An initialized FieldDescriptor object 983 """ 984 985 if message_name: 986 full_name = '.'.join((message_name, field_proto.name)) 987 else: 988 full_name = field_proto.name 989 990 return descriptor.FieldDescriptor( 991 name=field_proto.name, 992 full_name=full_name, 993 index=index, 994 number=field_proto.number, 995 type=field_proto.type, 996 cpp_type=None, 997 message_type=None, 998 enum_type=None, 999 containing_type=None, 1000 label=field_proto.label, 1001 has_default_value=False, 1002 default_value=None, 1003 is_extension=is_extension, 1004 extension_scope=None, 1005 options=_OptionsOrNone(field_proto), 1006 file=file_desc, 1007 # pylint: disable=protected-access 1008 create_key=descriptor._internal_create_key) 1009 1010 def _SetAllFieldTypes(self, package, desc_proto, scope): 1011 """Sets all the descriptor's fields's types. 1012 1013 This method also sets the containing types on any extensions. 1014 1015 Args: 1016 package: The current package of desc_proto. 
1017 desc_proto: The message descriptor to update. 1018 scope: Enclosing scope of available types. 1019 """ 1020 1021 package = _PrefixWithDot(package) 1022 1023 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) 1024 1025 if package == '.': 1026 nested_package = _PrefixWithDot(desc_proto.name) 1027 else: 1028 nested_package = '.'.join([package, desc_proto.name]) 1029 1030 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): 1031 self._SetFieldType(field_proto, field_desc, nested_package, scope) 1032 1033 for extension_proto, extension_desc in ( 1034 zip(desc_proto.extension, main_desc.extensions)): 1035 extension_desc.containing_type = self._GetTypeFromScope( 1036 nested_package, extension_proto.extendee, scope) 1037 self._SetFieldType(extension_proto, extension_desc, nested_package, scope) 1038 1039 for nested_type in desc_proto.nested_type: 1040 self._SetAllFieldTypes(nested_package, nested_type, scope) 1041 1042 def _SetFieldType(self, field_proto, field_desc, package, scope): 1043 """Sets the field's type, cpp_type, message_type and enum_type. 1044 1045 Args: 1046 field_proto: Data about the field in proto format. 1047 field_desc: The descriptor to modify. 1048 package: The package the field's container is in. 1049 scope: Enclosing scope of available types. 
1050 """ 1051 if field_proto.type_name: 1052 desc = self._GetTypeFromScope(package, field_proto.type_name, scope) 1053 else: 1054 desc = None 1055 1056 if not field_proto.HasField('type'): 1057 if isinstance(desc, descriptor.Descriptor): 1058 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE 1059 else: 1060 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM 1061 1062 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( 1063 field_proto.type) 1064 1065 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE 1066 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): 1067 field_desc.message_type = desc 1068 1069 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1070 field_desc.enum_type = desc 1071 1072 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1073 field_desc.has_default_value = False 1074 field_desc.default_value = [] 1075 elif field_proto.HasField('default_value'): 1076 field_desc.has_default_value = True 1077 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1078 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1079 field_desc.default_value = float(field_proto.default_value) 1080 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1081 field_desc.default_value = field_proto.default_value 1082 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1083 field_desc.default_value = field_proto.default_value.lower() == 'true' 1084 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1085 field_desc.default_value = field_desc.enum_type.values_by_name[ 1086 field_proto.default_value].number 1087 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1088 field_desc.default_value = text_encoding.CUnescape( 1089 field_proto.default_value) 1090 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1091 field_desc.default_value = None 1092 else: 1093 # All other types are of the "int" type. 
1094 field_desc.default_value = int(field_proto.default_value) 1095 else: 1096 field_desc.has_default_value = False 1097 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1098 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1099 field_desc.default_value = 0.0 1100 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1101 field_desc.default_value = u'' 1102 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1103 field_desc.default_value = False 1104 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1105 field_desc.default_value = field_desc.enum_type.values[0].number 1106 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1107 field_desc.default_value = b'' 1108 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1109 field_desc.default_value = None 1110 else: 1111 # All other types are of the "int" type. 1112 field_desc.default_value = 0 1113 1114 field_desc.type = field_proto.type 1115 1116 def _MakeEnumValueDescriptor(self, value_proto, index): 1117 """Creates a enum value descriptor object from a enum value proto. 1118 1119 Args: 1120 value_proto: The proto describing the enum value. 1121 index: The index of the enum value. 1122 1123 Returns: 1124 An initialized EnumValueDescriptor object. 1125 """ 1126 1127 return descriptor.EnumValueDescriptor( 1128 name=value_proto.name, 1129 index=index, 1130 number=value_proto.number, 1131 options=_OptionsOrNone(value_proto), 1132 type=None, 1133 # pylint: disable=protected-access 1134 create_key=descriptor._internal_create_key) 1135 1136 def _MakeServiceDescriptor(self, service_proto, service_index, scope, 1137 package, file_desc): 1138 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. 1139 1140 Args: 1141 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. 1142 service_index: The index of the service in the File. 1143 scope: Dict mapping short and full symbols to message and enum types. 
1144 package: Optional package name for the new message EnumDescriptor. 1145 file_desc: The file containing the service descriptor. 1146 1147 Returns: 1148 The added descriptor. 1149 """ 1150 1151 if package: 1152 service_name = '.'.join((package, service_proto.name)) 1153 else: 1154 service_name = service_proto.name 1155 1156 methods = [self._MakeMethodDescriptor(method_proto, service_name, package, 1157 scope, index) 1158 for index, method_proto in enumerate(service_proto.method)] 1159 desc = descriptor.ServiceDescriptor( 1160 name=service_proto.name, 1161 full_name=service_name, 1162 index=service_index, 1163 methods=methods, 1164 options=_OptionsOrNone(service_proto), 1165 file=file_desc, 1166 # pylint: disable=protected-access 1167 create_key=descriptor._internal_create_key) 1168 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 1169 self._service_descriptors[service_name] = desc 1170 return desc 1171 1172 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, 1173 index): 1174 """Creates a method descriptor from a MethodDescriptorProto. 1175 1176 Args: 1177 method_proto: The proto describing the method. 1178 service_name: The name of the containing service. 1179 package: Optional package name to look up for types. 1180 scope: Scope containing available types. 1181 index: Index of the method in the service. 1182 1183 Returns: 1184 An initialized MethodDescriptor object. 
1185 """ 1186 full_name = '.'.join((service_name, method_proto.name)) 1187 input_type = self._GetTypeFromScope( 1188 package, method_proto.input_type, scope) 1189 output_type = self._GetTypeFromScope( 1190 package, method_proto.output_type, scope) 1191 return descriptor.MethodDescriptor( 1192 name=method_proto.name, 1193 full_name=full_name, 1194 index=index, 1195 containing_service=None, 1196 input_type=input_type, 1197 output_type=output_type, 1198 options=_OptionsOrNone(method_proto), 1199 # pylint: disable=protected-access 1200 create_key=descriptor._internal_create_key) 1201 1202 def _ExtractSymbols(self, descriptors): 1203 """Pulls out all the symbols from descriptor protos. 1204 1205 Args: 1206 descriptors: The messages to extract descriptors from. 1207 Yields: 1208 A two element tuple of the type name and descriptor object. 1209 """ 1210 1211 for desc in descriptors: 1212 yield (_PrefixWithDot(desc.full_name), desc) 1213 for symbol in self._ExtractSymbols(desc.nested_types): 1214 yield symbol 1215 for enum in desc.enum_types: 1216 yield (_PrefixWithDot(enum.full_name), enum) 1217 1218 def _GetDeps(self, dependencies): 1219 """Recursively finds dependencies for file protos. 1220 1221 Args: 1222 dependencies: The names of the files being depended on. 1223 1224 Yields: 1225 Each direct and indirect dependency. 1226 """ 1227 1228 for dependency in dependencies: 1229 dep_desc = self.FindFileByName(dependency) 1230 yield dep_desc 1231 for parent_dep in dep_desc.dependencies: 1232 yield parent_dep 1233 1234 def _GetTypeFromScope(self, package, type_name, scope): 1235 """Finds a given type name in the current scope. 1236 1237 Args: 1238 package: The package the proto should be located in. 1239 type_name: The name of the type to be found in the scope. 1240 scope: Dict mapping short and full symbols to message and enum types. 1241 1242 Returns: 1243 The descriptor for the requested type. 
1244 """ 1245 if type_name not in scope: 1246 components = _PrefixWithDot(package).split('.') 1247 while components: 1248 possible_match = '.'.join(components + [type_name]) 1249 if possible_match in scope: 1250 type_name = possible_match 1251 break 1252 else: 1253 components.pop(-1) 1254 return scope[type_name] 1255 1256 1257def _PrefixWithDot(name): 1258 return name if name.startswith('.') else '.%s' % name 1259 1260 1261if _USE_C_DESCRIPTORS: 1262 # TODO(amauryfa): This pool could be constructed from Python code, when we 1263 # support a flag like 'use_cpp_generated_pool=True'. 1264 # pylint: disable=protected-access 1265 _DEFAULT = descriptor._message.default_pool 1266else: 1267 _DEFAULT = DescriptorPool() 1268 1269 1270def Default(): 1271 return _DEFAULT 1272