1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# 4# Use of this source code is governed by a BSD-style 5# license that can be found in the LICENSE file or at 6# https://developers.google.com/open-source/licenses/bsd 7 8"""Provides DescriptorPool to use as a container for proto2 descriptors. 9 10The DescriptorPool is used in conjection with a DescriptorDatabase to maintain 11a collection of protocol buffer descriptors for use when dynamically creating 12message types at runtime. 13 14For most applications protocol buffers should be used via modules generated by 15the protocol buffer compiler tool. This should only be used when the type of 16protocol buffers used in an application or library cannot be predetermined. 17 18Below is a straightforward example on how to use this class:: 19 20 pool = DescriptorPool() 21 file_descriptor_protos = [ ... ] 22 for file_descriptor_proto in file_descriptor_protos: 23 pool.Add(file_descriptor_proto) 24 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') 25 26The message descriptor can be used in conjunction with the message_factory 27module in order to create a protocol buffer class that can be encoded and 28decoded. 29 30If you want to get a Python class for the specified proto, use the 31helper functions inside google.protobuf.message_factory 32directly instead of this class. 33""" 34 35__author__ = 'matthewtoia@google.com (Matt Toia)' 36 37import collections 38import threading 39import warnings 40 41from google.protobuf import descriptor 42from google.protobuf import descriptor_database 43from google.protobuf import text_encoding 44from google.protobuf.internal import python_edition_defaults 45from google.protobuf.internal import python_message 46 47_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access 48 49 50def _NormalizeFullyQualifiedName(name): 51 """Remove leading period from fully-qualified type name. 
52 53 Due to b/13860351 in descriptor_database.py, types in the root namespace are 54 generated with a leading period. This function removes that prefix. 55 56 Args: 57 name (str): The fully-qualified symbol name. 58 59 Returns: 60 str: The normalized fully-qualified symbol name. 61 """ 62 return name.lstrip('.') 63 64 65def _OptionsOrNone(descriptor_proto): 66 """Returns the value of the field `options`, or None if it is not set.""" 67 if descriptor_proto.HasField('options'): 68 return descriptor_proto.options 69 else: 70 return None 71 72 73def _IsMessageSetExtension(field): 74 return (field.is_extension and 75 field.containing_type.has_options and 76 field.containing_type.GetOptions().message_set_wire_format and 77 field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 78 field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) 79 80_edition_defaults_lock = threading.Lock() 81 82 83class DescriptorPool(object): 84 """A collection of protobufs dynamically constructed by descriptor protos.""" 85 86 if _USE_C_DESCRIPTORS: 87 88 def __new__(cls, descriptor_db=None): 89 # pylint: disable=protected-access 90 return descriptor._message.DescriptorPool(descriptor_db) 91 92 def __init__( 93 self, descriptor_db=None, use_deprecated_legacy_json_field_conflicts=False 94 ): 95 """Initializes a Pool of proto buffs. 96 97 The descriptor_db argument to the constructor is provided to allow 98 specialized file descriptor proto lookup code to be triggered on demand. An 99 example would be an implementation which will read and compile a file 100 specified in a call to FindFileByName() and not require the call to Add() 101 at all. Results from this database will be cached internally here as well. 102 103 Args: 104 descriptor_db: A secondary source of file descriptors. 105 use_deprecated_legacy_json_field_conflicts: Unused, for compatibility with 106 C++. 
def _CheckConflictRegister(self, desc, desc_name, file_name):
  """Verifies that registering `desc` under `desc_name` causes no conflict.

  Each of the pool's name registers is consulted. An existing entry with the
  same name is tolerated only when `desc` has the type that register holds
  and the entry comes from the same file.

  Args:
    desc: Descriptor of a message, enum, service, extension or enum value.
    desc_name (str): the full name of desc.
    file_name (str): The file name of descriptor.

  Raises:
    TypeError: if `desc_name` is already registered by a descriptor of
      another kind or from another file.
  """
  registers = (
      (self._descriptors, descriptor.Descriptor),
      (self._enum_descriptors, descriptor.EnumDescriptor),
      (self._service_descriptors, descriptor.ServiceDescriptor),
      (self._toplevel_extensions, descriptor.FieldDescriptor),
      (self._top_enum_values, descriptor.EnumValueDescriptor),
  )
  for registry, expected_type in registers:
    if desc_name not in registry:
      continue

    existing = registry[desc_name]
    # Enum values reach their file through the enum type that owns them.
    if isinstance(existing, descriptor.EnumValueDescriptor):
      existing_file = existing.type.file.name
    else:
      existing_file = existing.file.name

    conflicting = (not isinstance(desc, expected_type)
                   or existing_file != file_name)
    if not conflicting:
      # Same kind of descriptor from the same file: nothing left to check.
      return

    message = ('Conflict register for file "' + file_name +
               '": ' + desc_name +
               ' is already defined in file "' +
               existing_file + '". Please fix the conflict by adding '
               'package name on the proto file, or use different '
               'name for the duplication.')
    if isinstance(desc, descriptor.EnumValueDescriptor):
      message += ('\nNote: enum values appear as '
                  'siblings of the enum type instead of '
                  'children of it.')
    raise TypeError(message)
def _AddEnumDescriptor(self, enum_desc):
  """Registers an EnumDescriptor, its top-level values and its file.

  Values of an enum declared directly at file scope are also indexed under
  `<package>.<value name>`.

  Args:
    enum_desc: An EnumDescriptor.

  Raises:
    TypeError: if `enum_desc` is not an EnumDescriptor, or if a name it
      introduces conflicts with an already registered symbol.
  """
  if not isinstance(enum_desc, descriptor.EnumDescriptor):
    raise TypeError('Expected instance of descriptor.EnumDescriptor.')

  enum_full_name = enum_desc.full_name
  file_name = enum_desc.file.name
  self._CheckConflictRegister(enum_desc, enum_full_name, file_name)
  self._enum_descriptors[enum_full_name] = enum_desc

  # Top enum values need to be indexed.
  # Count the number of dots to see whether the enum is toplevel or nested
  # in a message. We cannot use enum_desc.containing_type at this stage.
  package = enum_desc.file.package
  dots_in_name = enum_full_name.count('.')
  if package:
    is_top_level = dots_in_name - package.count('.') == 1
  else:
    is_top_level = dots_in_name == 0

  if is_top_level:
    for value_desc in enum_desc.values:
      value_full_name = _NormalizeFullyQualifiedName(
          '.'.join((package, value_desc.name)))
      self._CheckConflictRegister(value_desc, value_full_name, file_name)
      self._top_enum_values[value_full_name] = value_desc

  self._AddFileDescriptor(enum_desc.file)
def _AddExtensionDescriptor(self, extension):
  """Registers an extension FieldDescriptor with the pool.

  The extension is indexed by number and by full name under the message it
  extends. Extensions without an extension scope are also recorded as
  top-level extensions, and MessageSet extensions are additionally indexed
  by the full name of their message type.

  Args:
    extension: A FieldDescriptor.

  Raises:
    AssertionError: when another extension with the same number extends the
      same message.
    TypeError: when the specified extension is not a
      descriptor.FieldDescriptor.
  """
  is_field = isinstance(extension, descriptor.FieldDescriptor)
  if not is_field or not extension.is_extension:
    raise TypeError('Expected an extension descriptor.')

  if extension.extension_scope is None:
    self._CheckConflictRegister(
        extension, extension.full_name, extension.file.name)
    self._toplevel_extensions[extension.full_name] = extension

  extended_type = extension.containing_type
  by_number = self._extensions_by_number[extended_type]
  if extension.number in by_number:
    registered = by_number[extension.number]
    # Re-registering the very same descriptor object is allowed.
    if extension is not registered:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" '
          'with field number %d.' %
          (extension.full_name, registered.full_name,
           extended_type.full_name, extension.number))

  by_number[extension.number] = extension
  by_name = self._extensions_by_name[extended_type]
  by_name[extension.full_name] = extension

  # Also register MessageSet extensions with the type name.
  if _IsMessageSetExtension(extension):
    by_name[extension.message_type.full_name] = extension

  # When the extended message already has a concrete class, attach the
  # field helpers for the new extension to it.
  if hasattr(extended_type, '_concrete_class'):
    python_message._AttachFieldHelpers(
        extended_type._concrete_class, extension)
def FindFileByName(self, file_name):
  """Returns the FileDescriptor registered or buildable for `file_name`.

  Already built descriptors are returned directly. Otherwise the file proto
  is looked up in the internal database, then in the fallback descriptor
  database when one is configured, and converted into a FileDescriptor.

  Args:
    file_name (str): The path to the file to get a descriptor for.

  Returns:
    FileDescriptor: The descriptor for the named file.

  Raises:
    KeyError: if the file cannot be found in the pool.
  """
  if file_name in self._file_descriptors:
    return self._file_descriptors[file_name]

  try:
    proto = self._internal_db.FindFileByName(file_name)
  except KeyError as lookup_error:
    fallback_db = self._descriptor_db
    if not fallback_db:
      raise lookup_error
    proto = fallback_db.FindFileByName(file_name)

  if not proto:
    raise KeyError('Cannot find a file named %s' % file_name)
  return self._ConvertFileProtoToFileDescriptor(proto)
390 self._FindFileContainingSymbolInDb(symbol) 391 return self._InternalFindFileContainingSymbol(symbol) 392 except KeyError: 393 raise KeyError('Cannot find a file containing %s' % symbol) 394 395 def _InternalFindFileContainingSymbol(self, symbol): 396 """Gets the already built FileDescriptor containing the specified symbol. 397 398 Args: 399 symbol (str): The name of the symbol to search for. 400 401 Returns: 402 FileDescriptor: Descriptor for the file that contains the specified 403 symbol. 404 405 Raises: 406 KeyError: if the file cannot be found in the pool. 407 """ 408 try: 409 return self._descriptors[symbol].file 410 except KeyError: 411 pass 412 413 try: 414 return self._enum_descriptors[symbol].file 415 except KeyError: 416 pass 417 418 try: 419 return self._service_descriptors[symbol].file 420 except KeyError: 421 pass 422 423 try: 424 return self._top_enum_values[symbol].type.file 425 except KeyError: 426 pass 427 428 try: 429 return self._toplevel_extensions[symbol].file 430 except KeyError: 431 pass 432 433 # Try fields, enum values and nested extensions inside a message. 434 top_name, _, sub_name = symbol.rpartition('.') 435 try: 436 message = self.FindMessageTypeByName(top_name) 437 assert (sub_name in message.extensions_by_name or 438 sub_name in message.fields_by_name or 439 sub_name in message.enum_values_by_name) 440 return message.file 441 except (KeyError, AssertionError): 442 raise KeyError('Cannot find a file containing %s' % symbol) 443 444 def FindMessageTypeByName(self, full_name): 445 """Loads the named descriptor from the pool. 446 447 Args: 448 full_name (str): The full name of the descriptor to load. 449 450 Returns: 451 Descriptor: The descriptor for the named type. 452 453 Raises: 454 KeyError: if the message cannot be found in the pool. 
def FindFieldByName(self, full_name):
  """Looks up a field descriptor by `<message full name>.<field name>`.

  Everything before the last period names the containing message, which is
  loaded through FindMessageTypeByName; the remainder names the field.

  Args:
    full_name (str): The full name of the field descriptor to load.

  Returns:
    FieldDescriptor: The field descriptor for the named field.

  Raises:
    KeyError: if the field cannot be found in the pool.
  """
  normalized = _NormalizeFullyQualifiedName(full_name)
  name_parts = normalized.rpartition('.')
  owner_name, short_name = name_parts[0], name_parts[2]
  owner = self.FindMessageTypeByName(owner_name)
  return owner.fields_by_name[short_name]
def FindExtensionByName(self, full_name):
  """Looks up an extension field descriptor by its full name.

  Explicitly registered top-level extensions are consulted first. Otherwise
  the part of the name before the last period is tried as a message name,
  and, failing that, the file containing the symbol is loaded and used as
  the scope in which the extension is looked up.

  Args:
    full_name (str): The full name of the extension descriptor to load.

  Returns:
    FieldDescriptor: The field descriptor for the named extension.

  Raises:
    KeyError: if the extension cannot be found in the pool.
  """
  name = _NormalizeFullyQualifiedName(full_name)

  # The proto compiler does not give any link between the FileDescriptor
  # and top-level extensions unless the FileDescriptorProto is added to
  # the DescriptorDatabase, but this can impact memory usage.
  # So these extensions are registered by name explicitly.
  if name in self._toplevel_extensions:
    return self._toplevel_extensions[name]

  scope_name, _, short_name = name.rpartition('.')
  try:
    # Most extensions are nested inside a message.
    container = self.FindMessageTypeByName(scope_name)
  except KeyError:
    # Some extensions are defined at file scope.
    container = self._FindFileContainingSymbolInDb(name)
  return container.extensions_by_name[short_name]
560 """ 561 try: 562 return self._extensions_by_number[message_descriptor][number] 563 except KeyError: 564 self._TryLoadExtensionFromDB(message_descriptor, number) 565 return self._extensions_by_number[message_descriptor][number] 566 567 def FindAllExtensions(self, message_descriptor): 568 """Gets all the known extensions of a given message. 569 570 Extensions have to be registered to this pool by build related 571 :func:`Add` or :func:`AddExtensionDescriptor`. 572 573 Args: 574 message_descriptor (Descriptor): Descriptor of the extended message. 575 576 Returns: 577 list[FieldDescriptor]: Field descriptors describing the extensions. 578 """ 579 # Fallback to descriptor db if FindAllExtensionNumbers is provided. 580 if self._descriptor_db and hasattr( 581 self._descriptor_db, 'FindAllExtensionNumbers'): 582 full_name = message_descriptor.full_name 583 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) 584 for number in all_numbers: 585 if number in self._extensions_by_number[message_descriptor]: 586 continue 587 self._TryLoadExtensionFromDB(message_descriptor, number) 588 589 return list(self._extensions_by_number[message_descriptor].values()) 590 591 def _TryLoadExtensionFromDB(self, message_descriptor, number): 592 """Try to Load extensions from descriptor db. 593 594 Args: 595 message_descriptor: descriptor of the extended message. 596 number: the extension number that needs to be loaded. 597 """ 598 if not self._descriptor_db: 599 return 600 # Only supported when FindFileContainingExtension is provided. 601 if not hasattr( 602 self._descriptor_db, 'FindFileContainingExtension'): 603 return 604 605 full_name = message_descriptor.full_name 606 file_proto = self._descriptor_db.FindFileContainingExtension( 607 full_name, number) 608 609 if file_proto is None: 610 return 611 612 try: 613 self._ConvertFileProtoToFileDescriptor(file_proto) 614 except: 615 warn_msg = ('Unable to load proto file %s for extension number %d.' 
% 616 (file_proto.name, number)) 617 warnings.warn(warn_msg, RuntimeWarning) 618 619 def FindServiceByName(self, full_name): 620 """Loads the named service descriptor from the pool. 621 622 Args: 623 full_name (str): The full name of the service descriptor to load. 624 625 Returns: 626 ServiceDescriptor: The service descriptor for the named service. 627 628 Raises: 629 KeyError: if the service cannot be found in the pool. 630 """ 631 full_name = _NormalizeFullyQualifiedName(full_name) 632 if full_name not in self._service_descriptors: 633 self._FindFileContainingSymbolInDb(full_name) 634 return self._service_descriptors[full_name] 635 636 def FindMethodByName(self, full_name): 637 """Loads the named service method descriptor from the pool. 638 639 Args: 640 full_name (str): The full name of the method descriptor to load. 641 642 Returns: 643 MethodDescriptor: The method descriptor for the service method. 644 645 Raises: 646 KeyError: if the method cannot be found in the pool. 647 """ 648 full_name = _NormalizeFullyQualifiedName(full_name) 649 service_name, _, method_name = full_name.rpartition('.') 650 service_descriptor = self.FindServiceByName(service_name) 651 return service_descriptor.methods_by_name[method_name] 652 653 def SetFeatureSetDefaults(self, defaults): 654 """Sets the default feature mappings used during the build. 655 656 Args: 657 defaults: a FeatureSetDefaults message containing the new mappings. 658 """ 659 if self._edition_defaults is not None: 660 raise ValueError( 661 "Feature set defaults can't be changed once the pool has started" 662 ' building!' 
def _CreateDefaultFeatures(self, edition):
  """Builds the default FeatureSet that applies to `edition`.

  On first use the pool's edition defaults are parsed from the serialized
  defaults held by the pool; that initialization runs under the module-level
  edition defaults lock.

  Args:
    edition: the edition to generate defaults for.

  Returns:
    A FeatureSet message holding the fixed features of the matching defaults
    entry with its overridable features merged in.

  Raises:
    TypeError: if `edition` is outside the supported edition range, or no
      defaults entry applies to it.
  """
  # pylint: disable=g-import-not-at-top
  from google.protobuf import descriptor_pb2

  edition_name = descriptor_pb2.Edition.Name

  with _edition_defaults_lock:
    if not self._edition_defaults:
      self._edition_defaults = descriptor_pb2.FeatureSetDefaults()
      self._edition_defaults.ParseFromString(
          self._serialized_edition_defaults)

  supported = self._edition_defaults
  if edition < supported.minimum_edition:
    raise TypeError(
        'Edition %s is earlier than the minimum supported edition %s!'
        % (edition_name(edition), edition_name(supported.minimum_edition)))
  if edition > supported.maximum_edition:
    raise TypeError(
        'Edition %s is later than the maximum supported edition %s!'
        % (edition_name(edition), edition_name(supported.maximum_edition)))

  # Walk the entries in order and keep the last one whose edition does not
  # exceed the requested edition.
  match = None
  for entry in supported.defaults:
    if entry.edition > edition:
      break
    match = entry
  if match is None:
    raise TypeError(
        'No valid default found for edition %s!' % edition_name(edition))

  features = descriptor_pb2.FeatureSet()
  features.CopyFrom(match.fixed_features)
  features.MergeFrom(match.overridable_features)
  return features
def _ConvertFileProtoToFileDescriptor(self, file_proto):
  """Creates a FileDescriptor from a proto or returns a cached copy.

  This method also has the side effect of loading all the symbols found in
  the file into the appropriate dictionaries in the pool. The file is built
  only when its name is not yet in `self._file_descriptors`; the extensions
  of the resulting descriptor are passed to `_AddExtensionDescriptor` on
  every call.

  Args:
    file_proto: The proto to convert.

  Returns:
    A FileDescriptor matching the passed in proto.
  """
  if file_proto.name not in self._file_descriptors:
    # Resolve dependencies first; FindFileByName builds them on demand.
    built_deps = list(self._GetDeps(file_proto.dependency))
    direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
    # public_dependency holds indexes into the dependency list.
    public_deps = [direct_deps[i] for i in file_proto.public_dependency]

    # pylint: disable=g-import-not-at-top
    from google.protobuf import descriptor_pb2

    file_descriptor = descriptor.FileDescriptor(
        pool=self,
        name=file_proto.name,
        package=file_proto.package,
        syntax=file_proto.syntax,
        edition=descriptor_pb2.Edition.Name(file_proto.edition),
        options=_OptionsOrNone(file_proto),
        serialized_pb=file_proto.SerializeToString(),
        dependencies=direct_deps,
        public_dependencies=public_deps,
        # pylint: disable=protected-access
        create_key=descriptor._internal_create_key,
    )
    # Maps dot-prefixed full names to the descriptors visible to this file.
    scope = {}

    # This loop extracts all the message and enum types from all the
    # dependencies of the file_proto. This is necessary to create the
    # scope of available message types when defining the passed in
    # file proto.
    for dependency in built_deps:
      scope.update(self._ExtractSymbols(
          dependency.message_types_by_name.values()))
      scope.update((_PrefixWithDot(enum.full_name), enum)
                   for enum in dependency.enum_types_by_name.values())

    # Build the messages declared at file scope.
    for message_type in file_proto.message_type:
      message_desc = self._ConvertMessageDescriptor(
          message_type, file_proto.package, file_descriptor, scope,
          file_proto.syntax)
      file_descriptor.message_types_by_name[message_desc.name] = (
          message_desc)

    # Build the enums declared at file scope (top_level=True).
    for enum_type in file_proto.enum_type:
      file_descriptor.enum_types_by_name[enum_type.name] = (
          self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                      file_descriptor, None, scope, True))

    # Build the extensions declared at file scope and link each one to the
    # message it extends and to its own field type.
    for index, extension_proto in enumerate(file_proto.extension):
      extension_desc = self._MakeFieldDescriptor(
          extension_proto, file_proto.package, index, file_descriptor,
          is_extension=True)
      extension_desc.containing_type = self._GetTypeFromScope(
          file_descriptor.package, extension_proto.extendee, scope)
      self._SetFieldType(extension_proto, extension_desc,
                         file_descriptor.package, scope)
      file_descriptor.extensions_by_name[extension_desc.name] = (
          extension_desc)

    # Field types are set only after every type of the file is in scope.
    for desc_proto in file_proto.message_type:
      self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

    if file_proto.package:
      desc_proto_prefix = _PrefixWithDot(file_proto.package)
    else:
      desc_proto_prefix = ''

    # Re-register the file-level messages, this time taken from the scope.
    for desc_proto in file_proto.message_type:
      desc = self._GetTypeFromScope(
          desc_proto_prefix, desc_proto.name, scope)
      file_descriptor.message_types_by_name[desc_proto.name] = desc

    for index, service_proto in enumerate(file_proto.service):
      file_descriptor.services_by_name[service_proto.name] = (
          self._MakeServiceDescriptor(service_proto, index, scope,
                                      file_proto.package, file_descriptor))

    self._file_descriptors[file_proto.name] = file_descriptor

  # Add extensions to the pool
  def AddExtensionForNested(message_type):
    # Depth-first: nested types are handled before the message itself.
    for nested in message_type.nested_types:
      AddExtensionForNested(nested)
    for extension in message_type.extensions:
      self._AddExtensionDescriptor(extension)

  file_desc = self._file_descriptors[file_proto.name]
  for extension in file_desc.extensions_by_name.values():
    self._AddExtensionDescriptor(extension)
  for message_type in file_desc.message_types_by_name.values():
    AddExtensionForNested(message_type)

  return file_desc
900 """ 901 902 if package: 903 desc_name = '.'.join((package, desc_proto.name)) 904 else: 905 desc_name = desc_proto.name 906 907 if file_desc is None: 908 file_name = None 909 else: 910 file_name = file_desc.name 911 912 if scope is None: 913 scope = {} 914 915 nested = [ 916 self._ConvertMessageDescriptor( 917 nested, desc_name, file_desc, scope, syntax) 918 for nested in desc_proto.nested_type] 919 enums = [ 920 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, 921 scope, False) 922 for enum in desc_proto.enum_type] 923 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) 924 for index, field in enumerate(desc_proto.field)] 925 extensions = [ 926 self._MakeFieldDescriptor(extension, desc_name, index, file_desc, 927 is_extension=True) 928 for index, extension in enumerate(desc_proto.extension)] 929 oneofs = [ 930 # pylint: disable=g-complex-comprehension 931 descriptor.OneofDescriptor( 932 desc.name, 933 '.'.join((desc_name, desc.name)), 934 index, 935 None, 936 [], 937 _OptionsOrNone(desc), 938 # pylint: disable=protected-access 939 create_key=descriptor._internal_create_key) 940 for index, desc in enumerate(desc_proto.oneof_decl) 941 ] 942 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] 943 if extension_ranges: 944 is_extendable = True 945 else: 946 is_extendable = False 947 desc = descriptor.Descriptor( 948 name=desc_proto.name, 949 full_name=desc_name, 950 filename=file_name, 951 containing_type=None, 952 fields=fields, 953 oneofs=oneofs, 954 nested_types=nested, 955 enum_types=enums, 956 extensions=extensions, 957 options=_OptionsOrNone(desc_proto), 958 is_extendable=is_extendable, 959 extension_ranges=extension_ranges, 960 file=file_desc, 961 serialized_start=None, 962 serialized_end=None, 963 is_map_entry=desc_proto.options.map_entry, 964 # pylint: disable=protected-access 965 create_key=descriptor._internal_create_key, 966 ) 967 for nested in desc.nested_types: 968 nested.containing_type = desc 969 
for enum in desc.enum_types: 970 enum.containing_type = desc 971 for field_index, field_desc in enumerate(desc_proto.field): 972 if field_desc.HasField('oneof_index'): 973 oneof_index = field_desc.oneof_index 974 oneofs[oneof_index].fields.append(fields[field_index]) 975 fields[field_index].containing_oneof = oneofs[oneof_index] 976 977 scope[_PrefixWithDot(desc_name)] = desc 978 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 979 self._descriptors[desc_name] = desc 980 return desc 981 982 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, 983 containing_type=None, scope=None, top_level=False): 984 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. 985 986 Args: 987 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. 988 package: Optional package name for the new message EnumDescriptor. 989 file_desc: The file containing the enum descriptor. 990 containing_type: The type containing this enum. 991 scope: Scope containing available types. 992 top_level: If True, the enum is a top level symbol. If False, the enum 993 is defined inside a message. 
994 995 Returns: 996 The added descriptor 997 """ 998 999 if package: 1000 enum_name = '.'.join((package, enum_proto.name)) 1001 else: 1002 enum_name = enum_proto.name 1003 1004 if file_desc is None: 1005 file_name = None 1006 else: 1007 file_name = file_desc.name 1008 1009 values = [self._MakeEnumValueDescriptor(value, index) 1010 for index, value in enumerate(enum_proto.value)] 1011 desc = descriptor.EnumDescriptor(name=enum_proto.name, 1012 full_name=enum_name, 1013 filename=file_name, 1014 file=file_desc, 1015 values=values, 1016 containing_type=containing_type, 1017 options=_OptionsOrNone(enum_proto), 1018 # pylint: disable=protected-access 1019 create_key=descriptor._internal_create_key) 1020 scope['.%s' % enum_name] = desc 1021 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 1022 self._enum_descriptors[enum_name] = desc 1023 1024 # Add top level enum values. 1025 if top_level: 1026 for value in values: 1027 full_name = _NormalizeFullyQualifiedName( 1028 '.'.join((package, value.name))) 1029 self._CheckConflictRegister(value, full_name, file_name) 1030 self._top_enum_values[full_name] = value 1031 1032 return desc 1033 1034 def _MakeFieldDescriptor(self, field_proto, message_name, index, 1035 file_desc, is_extension=False): 1036 """Creates a field descriptor from a FieldDescriptorProto. 1037 1038 For message and enum type fields, this method will do a look up 1039 in the pool for the appropriate descriptor for that type. If it 1040 is unavailable, it will fall back to the _source function to 1041 create it. If this type is still unavailable, construction will 1042 fail. 1043 1044 Args: 1045 field_proto: The proto describing the field. 1046 message_name: The name of the containing message. 1047 index: Index of the field 1048 file_desc: The file containing the field descriptor. 1049 is_extension: Indication that this field is for an extension. 
1050 1051 Returns: 1052 An initialized FieldDescriptor object 1053 """ 1054 1055 if message_name: 1056 full_name = '.'.join((message_name, field_proto.name)) 1057 else: 1058 full_name = field_proto.name 1059 1060 if field_proto.json_name: 1061 json_name = field_proto.json_name 1062 else: 1063 json_name = None 1064 1065 return descriptor.FieldDescriptor( 1066 name=field_proto.name, 1067 full_name=full_name, 1068 index=index, 1069 number=field_proto.number, 1070 type=field_proto.type, 1071 cpp_type=None, 1072 message_type=None, 1073 enum_type=None, 1074 containing_type=None, 1075 label=field_proto.label, 1076 has_default_value=False, 1077 default_value=None, 1078 is_extension=is_extension, 1079 extension_scope=None, 1080 options=_OptionsOrNone(field_proto), 1081 json_name=json_name, 1082 file=file_desc, 1083 # pylint: disable=protected-access 1084 create_key=descriptor._internal_create_key) 1085 1086 def _SetAllFieldTypes(self, package, desc_proto, scope): 1087 """Sets all the descriptor's fields's types. 1088 1089 This method also sets the containing types on any extensions. 1090 1091 Args: 1092 package: The current package of desc_proto. 1093 desc_proto: The message descriptor to update. 1094 scope: Enclosing scope of available types. 
1095 """ 1096 1097 package = _PrefixWithDot(package) 1098 1099 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) 1100 1101 if package == '.': 1102 nested_package = _PrefixWithDot(desc_proto.name) 1103 else: 1104 nested_package = '.'.join([package, desc_proto.name]) 1105 1106 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): 1107 self._SetFieldType(field_proto, field_desc, nested_package, scope) 1108 1109 for extension_proto, extension_desc in ( 1110 zip(desc_proto.extension, main_desc.extensions)): 1111 extension_desc.containing_type = self._GetTypeFromScope( 1112 nested_package, extension_proto.extendee, scope) 1113 self._SetFieldType(extension_proto, extension_desc, nested_package, scope) 1114 1115 for nested_type in desc_proto.nested_type: 1116 self._SetAllFieldTypes(nested_package, nested_type, scope) 1117 1118 def _SetFieldType(self, field_proto, field_desc, package, scope): 1119 """Sets the field's type, cpp_type, message_type and enum_type. 1120 1121 Args: 1122 field_proto: Data about the field in proto format. 1123 field_desc: The descriptor to modify. 1124 package: The package the field's container is in. 1125 scope: Enclosing scope of available types. 
1126 """ 1127 if field_proto.type_name: 1128 desc = self._GetTypeFromScope(package, field_proto.type_name, scope) 1129 else: 1130 desc = None 1131 1132 if not field_proto.HasField('type'): 1133 if isinstance(desc, descriptor.Descriptor): 1134 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE 1135 else: 1136 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM 1137 1138 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( 1139 field_proto.type) 1140 1141 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE 1142 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): 1143 field_desc.message_type = desc 1144 1145 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1146 field_desc.enum_type = desc 1147 1148 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1149 field_desc.has_default_value = False 1150 field_desc.default_value = [] 1151 elif field_proto.HasField('default_value'): 1152 field_desc.has_default_value = True 1153 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1154 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1155 field_desc.default_value = float(field_proto.default_value) 1156 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1157 field_desc.default_value = field_proto.default_value 1158 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1159 field_desc.default_value = field_proto.default_value.lower() == 'true' 1160 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1161 field_desc.default_value = field_desc.enum_type.values_by_name[ 1162 field_proto.default_value].number 1163 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1164 field_desc.default_value = text_encoding.CUnescape( 1165 field_proto.default_value) 1166 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1167 field_desc.default_value = None 1168 else: 1169 # All other types are of the "int" type. 
1170 field_desc.default_value = int(field_proto.default_value) 1171 else: 1172 field_desc.has_default_value = False 1173 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1174 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1175 field_desc.default_value = 0.0 1176 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1177 field_desc.default_value = u'' 1178 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1179 field_desc.default_value = False 1180 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1181 field_desc.default_value = field_desc.enum_type.values[0].number 1182 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1183 field_desc.default_value = b'' 1184 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1185 field_desc.default_value = None 1186 elif field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP: 1187 field_desc.default_value = None 1188 else: 1189 # All other types are of the "int" type. 1190 field_desc.default_value = 0 1191 1192 field_desc.type = field_proto.type 1193 1194 def _MakeEnumValueDescriptor(self, value_proto, index): 1195 """Creates a enum value descriptor object from a enum value proto. 1196 1197 Args: 1198 value_proto: The proto describing the enum value. 1199 index: The index of the enum value. 1200 1201 Returns: 1202 An initialized EnumValueDescriptor object. 1203 """ 1204 1205 return descriptor.EnumValueDescriptor( 1206 name=value_proto.name, 1207 index=index, 1208 number=value_proto.number, 1209 options=_OptionsOrNone(value_proto), 1210 type=None, 1211 # pylint: disable=protected-access 1212 create_key=descriptor._internal_create_key) 1213 1214 def _MakeServiceDescriptor(self, service_proto, service_index, scope, 1215 package, file_desc): 1216 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. 1217 1218 Args: 1219 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. 
1220 service_index: The index of the service in the File. 1221 scope: Dict mapping short and full symbols to message and enum types. 1222 package: Optional package name for the new message EnumDescriptor. 1223 file_desc: The file containing the service descriptor. 1224 1225 Returns: 1226 The added descriptor. 1227 """ 1228 1229 if package: 1230 service_name = '.'.join((package, service_proto.name)) 1231 else: 1232 service_name = service_proto.name 1233 1234 methods = [self._MakeMethodDescriptor(method_proto, service_name, package, 1235 scope, index) 1236 for index, method_proto in enumerate(service_proto.method)] 1237 desc = descriptor.ServiceDescriptor( 1238 name=service_proto.name, 1239 full_name=service_name, 1240 index=service_index, 1241 methods=methods, 1242 options=_OptionsOrNone(service_proto), 1243 file=file_desc, 1244 # pylint: disable=protected-access 1245 create_key=descriptor._internal_create_key) 1246 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 1247 self._service_descriptors[service_name] = desc 1248 return desc 1249 1250 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, 1251 index): 1252 """Creates a method descriptor from a MethodDescriptorProto. 1253 1254 Args: 1255 method_proto: The proto describing the method. 1256 service_name: The name of the containing service. 1257 package: Optional package name to look up for types. 1258 scope: Scope containing available types. 1259 index: Index of the method in the service. 1260 1261 Returns: 1262 An initialized MethodDescriptor object. 
1263 """ 1264 full_name = '.'.join((service_name, method_proto.name)) 1265 input_type = self._GetTypeFromScope( 1266 package, method_proto.input_type, scope) 1267 output_type = self._GetTypeFromScope( 1268 package, method_proto.output_type, scope) 1269 return descriptor.MethodDescriptor( 1270 name=method_proto.name, 1271 full_name=full_name, 1272 index=index, 1273 containing_service=None, 1274 input_type=input_type, 1275 output_type=output_type, 1276 client_streaming=method_proto.client_streaming, 1277 server_streaming=method_proto.server_streaming, 1278 options=_OptionsOrNone(method_proto), 1279 # pylint: disable=protected-access 1280 create_key=descriptor._internal_create_key) 1281 1282 def _ExtractSymbols(self, descriptors): 1283 """Pulls out all the symbols from descriptor protos. 1284 1285 Args: 1286 descriptors: The messages to extract descriptors from. 1287 Yields: 1288 A two element tuple of the type name and descriptor object. 1289 """ 1290 1291 for desc in descriptors: 1292 yield (_PrefixWithDot(desc.full_name), desc) 1293 for symbol in self._ExtractSymbols(desc.nested_types): 1294 yield symbol 1295 for enum in desc.enum_types: 1296 yield (_PrefixWithDot(enum.full_name), enum) 1297 1298 def _GetDeps(self, dependencies, visited=None): 1299 """Recursively finds dependencies for file protos. 1300 1301 Args: 1302 dependencies: The names of the files being depended on. 1303 visited: The names of files already found. 1304 1305 Yields: 1306 Each direct and indirect dependency. 1307 """ 1308 1309 visited = visited or set() 1310 for dependency in dependencies: 1311 if dependency not in visited: 1312 visited.add(dependency) 1313 dep_desc = self.FindFileByName(dependency) 1314 yield dep_desc 1315 public_files = [d.name for d in dep_desc.public_dependencies] 1316 yield from self._GetDeps(public_files, visited) 1317 1318 def _GetTypeFromScope(self, package, type_name, scope): 1319 """Finds a given type name in the current scope. 
1320 1321 Args: 1322 package: The package the proto should be located in. 1323 type_name: The name of the type to be found in the scope. 1324 scope: Dict mapping short and full symbols to message and enum types. 1325 1326 Returns: 1327 The descriptor for the requested type. 1328 """ 1329 if type_name not in scope: 1330 components = _PrefixWithDot(package).split('.') 1331 while components: 1332 possible_match = '.'.join(components + [type_name]) 1333 if possible_match in scope: 1334 type_name = possible_match 1335 break 1336 else: 1337 components.pop(-1) 1338 return scope[type_name] 1339 1340 1341def _PrefixWithDot(name): 1342 return name if name.startswith('.') else '.%s' % name 1343 1344 1345if _USE_C_DESCRIPTORS: 1346 # TODO: This pool could be constructed from Python code, when we 1347 # support a flag like 'use_cpp_generated_pool=True'. 1348 # pylint: disable=protected-access 1349 _DEFAULT = descriptor._message.default_pool 1350else: 1351 _DEFAULT = DescriptorPool() 1352 1353 1354def Default(): 1355 return _DEFAULT 1356