1# Protocol Buffers - Google's data interchange format 2# Copyright 2008 Google Inc. All rights reserved. 3# https://developers.google.com/protocol-buffers/ 4# 5# Redistribution and use in source and binary forms, with or without 6# modification, are permitted provided that the following conditions are 7# met: 8# 9# * Redistributions of source code must retain the above copyright 10# notice, this list of conditions and the following disclaimer. 11# * Redistributions in binary form must reproduce the above 12# copyright notice, this list of conditions and the following disclaimer 13# in the documentation and/or other materials provided with the 14# distribution. 15# * Neither the name of Google Inc. nor the names of its 16# contributors may be used to endorse or promote products derived from 17# this software without specific prior written permission. 18# 19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 31"""Provides DescriptorPool to use as a container for proto2 descriptors. 32 33The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain 34a collection of protocol buffer descriptors for use when dynamically creating 35message types at runtime.
36 37For most applications protocol buffers should be used via modules generated by 38the protocol buffer compiler tool. This should only be used when the type of 39protocol buffers used in an application or library cannot be predetermined. 40 41Below is a straightforward example on how to use this class: 42 43 pool = DescriptorPool() 44 file_descriptor_protos = [ ... ] 45 for file_descriptor_proto in file_descriptor_protos: 46 pool.Add(file_descriptor_proto) 47 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') 48 49The message descriptor can be used in conjunction with the message_factory 50module in order to create a protocol buffer class that can be encoded and 51decoded. 52 53If you want to get a Python class for the specified proto, use the 54helper functions inside google.protobuf.message_factory 55directly instead of this class. 56""" 57 58__author__ = 'matthewtoia@google.com (Matt Toia)' 59 60import collections 61import warnings 62 63from google.protobuf import descriptor 64from google.protobuf import descriptor_database 65from google.protobuf import text_encoding 66 67 68_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS # pylint: disable=protected-access 69 70 71def _NormalizeFullyQualifiedName(name): 72 """Remove leading period from fully-qualified type name. 73 74 Due to b/13860351 in descriptor_database.py, types in the root namespace are 75 generated with a leading period. This function removes that prefix. 76 77 Args: 78 name: A str, the fully-qualified symbol name. 79 80 Returns: 81 A str, the normalized fully-qualified symbol name. 
82 """ 83 return name.lstrip('.') 84 85 86def _OptionsOrNone(descriptor_proto): 87 """Returns the value of the field `options`, or None if it is not set.""" 88 if descriptor_proto.HasField('options'): 89 return descriptor_proto.options 90 else: 91 return None 92 93 94def _IsMessageSetExtension(field): 95 return (field.is_extension and 96 field.containing_type.has_options and 97 field.containing_type.GetOptions().message_set_wire_format and 98 field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 99 field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL) 100 101 102class DescriptorPool(object): 103 """A collection of protobufs dynamically constructed by descriptor protos.""" 104 105 if _USE_C_DESCRIPTORS: 106 107 def __new__(cls, descriptor_db=None): 108 # pylint: disable=protected-access 109 return descriptor._message.DescriptorPool(descriptor_db) 110 111 def __init__(self, descriptor_db=None): 112 """Initializes a Pool of proto buffs. 113 114 The descriptor_db argument to the constructor is provided to allow 115 specialized file descriptor proto lookup code to be triggered on demand. An 116 example would be an implementation which will read and compile a file 117 specified in a call to FindFileByName() and not require the call to Add() 118 at all. Results from this database will be cached internally here as well. 119 120 Args: 121 descriptor_db: A secondary source of file descriptors. 122 """ 123 124 self._internal_db = descriptor_database.DescriptorDatabase() 125 self._descriptor_db = descriptor_db 126 self._descriptors = {} 127 self._enum_descriptors = {} 128 self._service_descriptors = {} 129 self._file_descriptors = {} 130 self._toplevel_extensions = {} 131 # TODO(jieluo): Remove _file_desc_by_toplevel_extension after 132 # maybe year 2020 for compatibility issue (with 3.4.1 only). 
133 self._file_desc_by_toplevel_extension = {} 134 self._top_enum_values = {} 135 # We store extensions in two two-level mappings: The first key is the 136 # descriptor of the message being extended, the second key is the extension 137 # full name or its tag number. 138 self._extensions_by_name = collections.defaultdict(dict) 139 self._extensions_by_number = collections.defaultdict(dict) 140 141 def _CheckConflictRegister(self, desc, desc_name, file_name): 142 """Check if the descriptor name conflicts with another of the same name. 143 144 Args: 145 desc: Descriptor of a message, enum, service, extension or enum value. 146 desc_name: the full name of desc. 147 file_name: The file name of descriptor. 148 """ 149 for register, descriptor_type in [ 150 (self._descriptors, descriptor.Descriptor), 151 (self._enum_descriptors, descriptor.EnumDescriptor), 152 (self._service_descriptors, descriptor.ServiceDescriptor), 153 (self._toplevel_extensions, descriptor.FieldDescriptor), 154 (self._top_enum_values, descriptor.EnumValueDescriptor)]: 155 if desc_name in register: 156 old_desc = register[desc_name] 157 if isinstance(old_desc, descriptor.EnumValueDescriptor): 158 old_file = old_desc.type.file.name 159 else: 160 old_file = old_desc.file.name 161 162 if not isinstance(desc, descriptor_type) or ( 163 old_file != file_name): 164 error_msg = ('Conflict register for file "' + file_name + 165 '": ' + desc_name + 166 ' is already defined in file "' + 167 old_file + '". Please fix the conflict by adding ' 168 'package name on the proto file, or use different ' 169 'name for the duplication.') 170 if isinstance(desc, descriptor.EnumValueDescriptor): 171 error_msg += ('\nNote: enum values appear as ' 172 'siblings of the enum type instead of ' 173 'children of it.') 174 175 raise TypeError(error_msg) 176 177 return 178 179 def Add(self, file_desc_proto): 180 """Adds the FileDescriptorProto and its types to this pool. 
181 182 Args: 183 file_desc_proto: The FileDescriptorProto to add. 184 """ 185 186 self._internal_db.Add(file_desc_proto) 187 188 def AddSerializedFile(self, serialized_file_desc_proto): 189 """Adds the FileDescriptorProto and its types to this pool. 190 191 Args: 192 serialized_file_desc_proto: A bytes string, serialization of the 193 FileDescriptorProto to add. 194 """ 195 196 # pylint: disable=g-import-not-at-top 197 from google.protobuf import descriptor_pb2 198 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( 199 serialized_file_desc_proto) 200 self.Add(file_desc_proto) 201 202 def AddDescriptor(self, desc): 203 """Adds a Descriptor to the pool, non-recursively. 204 205 If the Descriptor contains nested messages or enums, the caller must 206 explicitly register them. This method also registers the FileDescriptor 207 associated with the message. 208 209 Args: 210 desc: A Descriptor. 211 """ 212 if not isinstance(desc, descriptor.Descriptor): 213 raise TypeError('Expected instance of descriptor.Descriptor.') 214 215 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 216 217 self._descriptors[desc.full_name] = desc 218 self._AddFileDescriptor(desc.file) 219 220 def AddEnumDescriptor(self, enum_desc): 221 """Adds an EnumDescriptor to the pool. 222 223 This method also registers the FileDescriptor associated with the enum. 224 225 Args: 226 enum_desc: An EnumDescriptor. 227 """ 228 229 if not isinstance(enum_desc, descriptor.EnumDescriptor): 230 raise TypeError('Expected instance of descriptor.EnumDescriptor.') 231 232 file_name = enum_desc.file.name 233 self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) 234 self._enum_descriptors[enum_desc.full_name] = enum_desc 235 236 # Top enum values need to be indexed. 237 # Count the number of dots to see whether the enum is toplevel or nested 238 # in a message. We cannot use enum_desc.containing_type at this stage. 
239 if enum_desc.file.package: 240 top_level = (enum_desc.full_name.count('.') 241 - enum_desc.file.package.count('.') == 1) 242 else: 243 top_level = enum_desc.full_name.count('.') == 0 244 if top_level: 245 file_name = enum_desc.file.name 246 package = enum_desc.file.package 247 for enum_value in enum_desc.values: 248 full_name = _NormalizeFullyQualifiedName( 249 '.'.join((package, enum_value.name))) 250 self._CheckConflictRegister(enum_value, full_name, file_name) 251 self._top_enum_values[full_name] = enum_value 252 self._AddFileDescriptor(enum_desc.file) 253 254 def AddServiceDescriptor(self, service_desc): 255 """Adds a ServiceDescriptor to the pool. 256 257 Args: 258 service_desc: A ServiceDescriptor. 259 """ 260 261 if not isinstance(service_desc, descriptor.ServiceDescriptor): 262 raise TypeError('Expected instance of descriptor.ServiceDescriptor.') 263 264 self._CheckConflictRegister(service_desc, service_desc.full_name, 265 service_desc.file.name) 266 self._service_descriptors[service_desc.full_name] = service_desc 267 268 def AddExtensionDescriptor(self, extension): 269 """Adds a FieldDescriptor describing an extension to the pool. 270 271 Args: 272 extension: A FieldDescriptor. 273 274 Raises: 275 AssertionError: when another extension with the same number extends the 276 same message. 277 TypeError: when the specified extension is not a 278 descriptor.FieldDescriptor. 279 """ 280 if not (isinstance(extension, descriptor.FieldDescriptor) and 281 extension.is_extension): 282 raise TypeError('Expected an extension descriptor.') 283 284 if extension.extension_scope is None: 285 self._toplevel_extensions[extension.full_name] = extension 286 287 try: 288 existing_desc = self._extensions_by_number[ 289 extension.containing_type][extension.number] 290 except KeyError: 291 pass 292 else: 293 if extension is not existing_desc: 294 raise AssertionError( 295 'Extensions "%s" and "%s" both try to extend message type "%s" ' 296 'with field number %d.' 
% 297 (extension.full_name, existing_desc.full_name, 298 extension.containing_type.full_name, extension.number)) 299 300 self._extensions_by_number[extension.containing_type][ 301 extension.number] = extension 302 self._extensions_by_name[extension.containing_type][ 303 extension.full_name] = extension 304 305 # Also register MessageSet extensions with the type name. 306 if _IsMessageSetExtension(extension): 307 self._extensions_by_name[extension.containing_type][ 308 extension.message_type.full_name] = extension 309 310 def AddFileDescriptor(self, file_desc): 311 """Adds a FileDescriptor to the pool, non-recursively. 312 313 If the FileDescriptor contains messages or enums, the caller must explicitly 314 register them. 315 316 Args: 317 file_desc: A FileDescriptor. 318 """ 319 320 self._AddFileDescriptor(file_desc) 321 # TODO(jieluo): This is a temporary solution for FieldDescriptor.file. 322 # FieldDescriptor.file is added in code gen. Remove this solution after 323 # maybe 2020 for compatibility reason (with 3.4.1 only). 324 for extension in file_desc.extensions_by_name.values(): 325 self._file_desc_by_toplevel_extension[ 326 extension.full_name] = file_desc 327 328 def _AddFileDescriptor(self, file_desc): 329 """Adds a FileDescriptor to the pool, non-recursively. 330 331 If the FileDescriptor contains messages or enums, the caller must explicitly 332 register them. 333 334 Args: 335 file_desc: A FileDescriptor. 336 """ 337 338 if not isinstance(file_desc, descriptor.FileDescriptor): 339 raise TypeError('Expected instance of descriptor.FileDescriptor.') 340 self._file_descriptors[file_desc.name] = file_desc 341 342 def FindFileByName(self, file_name): 343 """Gets a FileDescriptor by file name. 344 345 Args: 346 file_name: The path to the file to get a descriptor for. 347 348 Returns: 349 A FileDescriptor for the named file. 350 351 Raises: 352 KeyError: if the file cannot be found in the pool. 
353 """ 354 355 try: 356 return self._file_descriptors[file_name] 357 except KeyError: 358 pass 359 360 try: 361 file_proto = self._internal_db.FindFileByName(file_name) 362 except KeyError as error: 363 if self._descriptor_db: 364 file_proto = self._descriptor_db.FindFileByName(file_name) 365 else: 366 raise error 367 if not file_proto: 368 raise KeyError('Cannot find a file named %s' % file_name) 369 return self._ConvertFileProtoToFileDescriptor(file_proto) 370 371 def FindFileContainingSymbol(self, symbol): 372 """Gets the FileDescriptor for the file containing the specified symbol. 373 374 Args: 375 symbol: The name of the symbol to search for. 376 377 Returns: 378 A FileDescriptor that contains the specified symbol. 379 380 Raises: 381 KeyError: if the file cannot be found in the pool. 382 """ 383 384 symbol = _NormalizeFullyQualifiedName(symbol) 385 try: 386 return self._InternalFindFileContainingSymbol(symbol) 387 except KeyError: 388 pass 389 390 try: 391 # Try fallback database. Build and find again if possible. 392 self._FindFileContainingSymbolInDb(symbol) 393 return self._InternalFindFileContainingSymbol(symbol) 394 except KeyError: 395 raise KeyError('Cannot find a file containing %s' % symbol) 396 397 def _InternalFindFileContainingSymbol(self, symbol): 398 """Gets the already built FileDescriptor containing the specified symbol. 399 400 Args: 401 symbol: The name of the symbol to search for. 402 403 Returns: 404 A FileDescriptor that contains the specified symbol. 405 406 Raises: 407 KeyError: if the file cannot be found in the pool. 
408 """ 409 try: 410 return self._descriptors[symbol].file 411 except KeyError: 412 pass 413 414 try: 415 return self._enum_descriptors[symbol].file 416 except KeyError: 417 pass 418 419 try: 420 return self._service_descriptors[symbol].file 421 except KeyError: 422 pass 423 424 try: 425 return self._top_enum_values[symbol].type.file 426 except KeyError: 427 pass 428 429 try: 430 return self._file_desc_by_toplevel_extension[symbol] 431 except KeyError: 432 pass 433 434 # Try fields, enum values and nested extensions inside a message. 435 top_name, _, sub_name = symbol.rpartition('.') 436 try: 437 message = self.FindMessageTypeByName(top_name) 438 assert (sub_name in message.extensions_by_name or 439 sub_name in message.fields_by_name or 440 sub_name in message.enum_values_by_name) 441 return message.file 442 except (KeyError, AssertionError): 443 raise KeyError('Cannot find a file containing %s' % symbol) 444 445 def FindMessageTypeByName(self, full_name): 446 """Loads the named descriptor from the pool. 447 448 Args: 449 full_name: The full name of the descriptor to load. 450 451 Returns: 452 The descriptor for the named type. 453 454 Raises: 455 KeyError: if the message cannot be found in the pool. 456 """ 457 458 full_name = _NormalizeFullyQualifiedName(full_name) 459 if full_name not in self._descriptors: 460 self._FindFileContainingSymbolInDb(full_name) 461 return self._descriptors[full_name] 462 463 def FindEnumTypeByName(self, full_name): 464 """Loads the named enum descriptor from the pool. 465 466 Args: 467 full_name: The full name of the enum descriptor to load. 468 469 Returns: 470 The enum descriptor for the named type. 471 472 Raises: 473 KeyError: if the enum cannot be found in the pool. 
474 """ 475 476 full_name = _NormalizeFullyQualifiedName(full_name) 477 if full_name not in self._enum_descriptors: 478 self._FindFileContainingSymbolInDb(full_name) 479 return self._enum_descriptors[full_name] 480 481 def FindFieldByName(self, full_name): 482 """Loads the named field descriptor from the pool. 483 484 Args: 485 full_name: The full name of the field descriptor to load. 486 487 Returns: 488 The field descriptor for the named field. 489 490 Raises: 491 KeyError: if the field cannot be found in the pool. 492 """ 493 full_name = _NormalizeFullyQualifiedName(full_name) 494 message_name, _, field_name = full_name.rpartition('.') 495 message_descriptor = self.FindMessageTypeByName(message_name) 496 return message_descriptor.fields_by_name[field_name] 497 498 def FindOneofByName(self, full_name): 499 """Loads the named oneof descriptor from the pool. 500 501 Args: 502 full_name: The full name of the oneof descriptor to load. 503 504 Returns: 505 The oneof descriptor for the named oneof. 506 507 Raises: 508 KeyError: if the oneof cannot be found in the pool. 509 """ 510 full_name = _NormalizeFullyQualifiedName(full_name) 511 message_name, _, oneof_name = full_name.rpartition('.') 512 message_descriptor = self.FindMessageTypeByName(message_name) 513 return message_descriptor.oneofs_by_name[oneof_name] 514 515 def FindExtensionByName(self, full_name): 516 """Loads the named extension descriptor from the pool. 517 518 Args: 519 full_name: The full name of the extension descriptor to load. 520 521 Returns: 522 A FieldDescriptor, describing the named extension. 523 524 Raises: 525 KeyError: if the extension cannot be found in the pool. 526 """ 527 full_name = _NormalizeFullyQualifiedName(full_name) 528 try: 529 # The proto compiler does not give any link between the FileDescriptor 530 # and top-level extensions unless the FileDescriptorProto is added to 531 # the DescriptorDatabase, but this can impact memory usage. 
532 # So we registered these extensions by name explicitly. 533 return self._toplevel_extensions[full_name] 534 except KeyError: 535 pass 536 message_name, _, extension_name = full_name.rpartition('.') 537 try: 538 # Most extensions are nested inside a message. 539 scope = self.FindMessageTypeByName(message_name) 540 except KeyError: 541 # Some extensions are defined at file scope. 542 scope = self._FindFileContainingSymbolInDb(full_name) 543 return scope.extensions_by_name[extension_name] 544 545 def FindExtensionByNumber(self, message_descriptor, number): 546 """Gets the extension of the specified message with the specified number. 547 548 Extensions have to be registered to this pool by calling 549 AddExtensionDescriptor. 550 551 Args: 552 message_descriptor: descriptor of the extended message. 553 number: integer, number of the extension field. 554 555 Returns: 556 A FieldDescriptor describing the extension. 557 558 Raises: 559 KeyError: when no extension with the given number is known for the 560 specified message. 561 """ 562 try: 563 return self._extensions_by_number[message_descriptor][number] 564 except KeyError: 565 self._TryLoadExtensionFromDB(message_descriptor, number) 566 return self._extensions_by_number[message_descriptor][number] 567 568 def FindAllExtensions(self, message_descriptor): 569 """Gets all the known extension of a given message. 570 571 Extensions have to be registered to this pool by calling 572 AddExtensionDescriptor. 573 574 Args: 575 message_descriptor: descriptor of the extended message. 576 577 Returns: 578 A list of FieldDescriptor describing the extensions. 579 """ 580 # Fallback to descriptor db if FindAllExtensionNumbers is provided. 
581 if self._descriptor_db and hasattr( 582 self._descriptor_db, 'FindAllExtensionNumbers'): 583 full_name = message_descriptor.full_name 584 all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) 585 for number in all_numbers: 586 if number in self._extensions_by_number[message_descriptor]: 587 continue 588 self._TryLoadExtensionFromDB(message_descriptor, number) 589 590 return list(self._extensions_by_number[message_descriptor].values()) 591 592 def _TryLoadExtensionFromDB(self, message_descriptor, number): 593 """Try to Load extensions from decriptor db. 594 595 Args: 596 message_descriptor: descriptor of the extended message. 597 number: the extension number that needs to be loaded. 598 """ 599 if not self._descriptor_db: 600 return 601 # Only supported when FindFileContainingExtension is provided. 602 if not hasattr( 603 self._descriptor_db, 'FindFileContainingExtension'): 604 return 605 606 full_name = message_descriptor.full_name 607 file_proto = self._descriptor_db.FindFileContainingExtension( 608 full_name, number) 609 610 if file_proto is None: 611 return 612 613 try: 614 file_desc = self._ConvertFileProtoToFileDescriptor(file_proto) 615 for extension in file_desc.extensions_by_name.values(): 616 self._extensions_by_number[extension.containing_type][ 617 extension.number] = extension 618 self._extensions_by_name[extension.containing_type][ 619 extension.full_name] = extension 620 for message_type in file_desc.message_types_by_name.values(): 621 for extension in message_type.extensions: 622 self._extensions_by_number[extension.containing_type][ 623 extension.number] = extension 624 self._extensions_by_name[extension.containing_type][ 625 extension.full_name] = extension 626 except: 627 warn_msg = ('Unable to load proto file %s for extension number %d.' % 628 (file_proto.name, number)) 629 warnings.warn(warn_msg, RuntimeWarning) 630 631 def FindServiceByName(self, full_name): 632 """Loads the named service descriptor from the pool. 
633 634 Args: 635 full_name: The full name of the service descriptor to load. 636 637 Returns: 638 The service descriptor for the named service. 639 640 Raises: 641 KeyError: if the service cannot be found in the pool. 642 """ 643 full_name = _NormalizeFullyQualifiedName(full_name) 644 if full_name not in self._service_descriptors: 645 self._FindFileContainingSymbolInDb(full_name) 646 return self._service_descriptors[full_name] 647 648 def FindMethodByName(self, full_name): 649 """Loads the named service method descriptor from the pool. 650 651 Args: 652 full_name: The full name of the method descriptor to load. 653 654 Returns: 655 The method descriptor for the service method. 656 657 Raises: 658 KeyError: if the method cannot be found in the pool. 659 """ 660 full_name = _NormalizeFullyQualifiedName(full_name) 661 service_name, _, method_name = full_name.rpartition('.') 662 service_descriptor = self.FindServiceByName(service_name) 663 return service_descriptor.methods_by_name[method_name] 664 665 def _FindFileContainingSymbolInDb(self, symbol): 666 """Finds the file in descriptor DB containing the specified symbol. 667 668 Args: 669 symbol: The name of the symbol to search for. 670 671 Returns: 672 A FileDescriptor that contains the specified symbol. 673 674 Raises: 675 KeyError: if the file cannot be found in the descriptor database. 676 """ 677 try: 678 file_proto = self._internal_db.FindFileContainingSymbol(symbol) 679 except KeyError as error: 680 if self._descriptor_db: 681 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) 682 else: 683 raise error 684 if not file_proto: 685 raise KeyError('Cannot find a file containing %s' % symbol) 686 return self._ConvertFileProtoToFileDescriptor(file_proto) 687 688 def _ConvertFileProtoToFileDescriptor(self, file_proto): 689 """Creates a FileDescriptor from a proto or returns a cached copy. 
690 691 This method also has the side effect of loading all the symbols found in 692 the file into the appropriate dictionaries in the pool. 693 694 Args: 695 file_proto: The proto to convert. 696 697 Returns: 698 A FileDescriptor matching the passed in proto. 699 """ 700 if file_proto.name not in self._file_descriptors: 701 built_deps = list(self._GetDeps(file_proto.dependency)) 702 direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] 703 public_deps = [direct_deps[i] for i in file_proto.public_dependency] 704 705 file_descriptor = descriptor.FileDescriptor( 706 pool=self, 707 name=file_proto.name, 708 package=file_proto.package, 709 syntax=file_proto.syntax, 710 options=_OptionsOrNone(file_proto), 711 serialized_pb=file_proto.SerializeToString(), 712 dependencies=direct_deps, 713 public_dependencies=public_deps) 714 scope = {} 715 716 # This loop extracts all the message and enum types from all the 717 # dependencies of the file_proto. This is necessary to create the 718 # scope of available message types when defining the passed in 719 # file proto. 
720 for dependency in built_deps: 721 scope.update(self._ExtractSymbols( 722 dependency.message_types_by_name.values())) 723 scope.update((_PrefixWithDot(enum.full_name), enum) 724 for enum in dependency.enum_types_by_name.values()) 725 726 for message_type in file_proto.message_type: 727 message_desc = self._ConvertMessageDescriptor( 728 message_type, file_proto.package, file_descriptor, scope, 729 file_proto.syntax) 730 file_descriptor.message_types_by_name[message_desc.name] = ( 731 message_desc) 732 733 for enum_type in file_proto.enum_type: 734 file_descriptor.enum_types_by_name[enum_type.name] = ( 735 self._ConvertEnumDescriptor(enum_type, file_proto.package, 736 file_descriptor, None, scope, True)) 737 738 for index, extension_proto in enumerate(file_proto.extension): 739 extension_desc = self._MakeFieldDescriptor( 740 extension_proto, file_proto.package, index, file_descriptor, 741 is_extension=True) 742 extension_desc.containing_type = self._GetTypeFromScope( 743 file_descriptor.package, extension_proto.extendee, scope) 744 self._SetFieldType(extension_proto, extension_desc, 745 file_descriptor.package, scope) 746 file_descriptor.extensions_by_name[extension_desc.name] = ( 747 extension_desc) 748 self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( 749 file_descriptor) 750 751 for desc_proto in file_proto.message_type: 752 self._SetAllFieldTypes(file_proto.package, desc_proto, scope) 753 754 if file_proto.package: 755 desc_proto_prefix = _PrefixWithDot(file_proto.package) 756 else: 757 desc_proto_prefix = '' 758 759 for desc_proto in file_proto.message_type: 760 desc = self._GetTypeFromScope( 761 desc_proto_prefix, desc_proto.name, scope) 762 file_descriptor.message_types_by_name[desc_proto.name] = desc 763 764 for index, service_proto in enumerate(file_proto.service): 765 file_descriptor.services_by_name[service_proto.name] = ( 766 self._MakeServiceDescriptor(service_proto, index, scope, 767 file_proto.package, file_descriptor)) 768 769 
self.Add(file_proto) 770 self._file_descriptors[file_proto.name] = file_descriptor 771 772 return self._file_descriptors[file_proto.name] 773 774 def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, 775 scope=None, syntax=None): 776 """Adds the proto to the pool in the specified package. 777 778 Args: 779 desc_proto: The descriptor_pb2.DescriptorProto protobuf message. 780 package: The package the proto should be located in. 781 file_desc: The file containing this message. 782 scope: Dict mapping short and full symbols to message and enum types. 783 syntax: string indicating syntax of the file ("proto2" or "proto3") 784 785 Returns: 786 The added descriptor. 787 """ 788 789 if package: 790 desc_name = '.'.join((package, desc_proto.name)) 791 else: 792 desc_name = desc_proto.name 793 794 if file_desc is None: 795 file_name = None 796 else: 797 file_name = file_desc.name 798 799 if scope is None: 800 scope = {} 801 802 nested = [ 803 self._ConvertMessageDescriptor( 804 nested, desc_name, file_desc, scope, syntax) 805 for nested in desc_proto.nested_type] 806 enums = [ 807 self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, 808 scope, False) 809 for enum in desc_proto.enum_type] 810 fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) 811 for index, field in enumerate(desc_proto.field)] 812 extensions = [ 813 self._MakeFieldDescriptor(extension, desc_name, index, file_desc, 814 is_extension=True) 815 for index, extension in enumerate(desc_proto.extension)] 816 oneofs = [ 817 descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), 818 index, None, [], desc.options) 819 for index, desc in enumerate(desc_proto.oneof_decl)] 820 extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] 821 if extension_ranges: 822 is_extendable = True 823 else: 824 is_extendable = False 825 desc = descriptor.Descriptor( 826 name=desc_proto.name, 827 full_name=desc_name, 828 filename=file_name, 829 
containing_type=None, 830 fields=fields, 831 oneofs=oneofs, 832 nested_types=nested, 833 enum_types=enums, 834 extensions=extensions, 835 options=_OptionsOrNone(desc_proto), 836 is_extendable=is_extendable, 837 extension_ranges=extension_ranges, 838 file=file_desc, 839 serialized_start=None, 840 serialized_end=None, 841 syntax=syntax) 842 for nested in desc.nested_types: 843 nested.containing_type = desc 844 for enum in desc.enum_types: 845 enum.containing_type = desc 846 for field_index, field_desc in enumerate(desc_proto.field): 847 if field_desc.HasField('oneof_index'): 848 oneof_index = field_desc.oneof_index 849 oneofs[oneof_index].fields.append(fields[field_index]) 850 fields[field_index].containing_oneof = oneofs[oneof_index] 851 852 scope[_PrefixWithDot(desc_name)] = desc 853 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 854 self._descriptors[desc_name] = desc 855 return desc 856 857 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, 858 containing_type=None, scope=None, top_level=False): 859 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. 860 861 Args: 862 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. 863 package: Optional package name for the new message EnumDescriptor. 864 file_desc: The file containing the enum descriptor. 865 containing_type: The type containing this enum. 866 scope: Scope containing available types. 867 top_level: If True, the enum is a top level symbol. If False, the enum 868 is defined inside a message. 
869 870 Returns: 871 The added descriptor 872 """ 873 874 if package: 875 enum_name = '.'.join((package, enum_proto.name)) 876 else: 877 enum_name = enum_proto.name 878 879 if file_desc is None: 880 file_name = None 881 else: 882 file_name = file_desc.name 883 884 values = [self._MakeEnumValueDescriptor(value, index) 885 for index, value in enumerate(enum_proto.value)] 886 desc = descriptor.EnumDescriptor(name=enum_proto.name, 887 full_name=enum_name, 888 filename=file_name, 889 file=file_desc, 890 values=values, 891 containing_type=containing_type, 892 options=_OptionsOrNone(enum_proto)) 893 scope['.%s' % enum_name] = desc 894 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 895 self._enum_descriptors[enum_name] = desc 896 897 # Add top level enum values. 898 if top_level: 899 for value in values: 900 full_name = _NormalizeFullyQualifiedName( 901 '.'.join((package, value.name))) 902 self._CheckConflictRegister(value, full_name, file_name) 903 self._top_enum_values[full_name] = value 904 905 return desc 906 907 def _MakeFieldDescriptor(self, field_proto, message_name, index, 908 file_desc, is_extension=False): 909 """Creates a field descriptor from a FieldDescriptorProto. 910 911 For message and enum type fields, this method will do a look up 912 in the pool for the appropriate descriptor for that type. If it 913 is unavailable, it will fall back to the _source function to 914 create it. If this type is still unavailable, construction will 915 fail. 916 917 Args: 918 field_proto: The proto describing the field. 919 message_name: The name of the containing message. 920 index: Index of the field 921 file_desc: The file containing the field descriptor. 922 is_extension: Indication that this field is for an extension. 
923 924 Returns: 925 An initialized FieldDescriptor object 926 """ 927 928 if message_name: 929 full_name = '.'.join((message_name, field_proto.name)) 930 else: 931 full_name = field_proto.name 932 933 return descriptor.FieldDescriptor( 934 name=field_proto.name, 935 full_name=full_name, 936 index=index, 937 number=field_proto.number, 938 type=field_proto.type, 939 cpp_type=None, 940 message_type=None, 941 enum_type=None, 942 containing_type=None, 943 label=field_proto.label, 944 has_default_value=False, 945 default_value=None, 946 is_extension=is_extension, 947 extension_scope=None, 948 options=_OptionsOrNone(field_proto), 949 file=file_desc) 950 951 def _SetAllFieldTypes(self, package, desc_proto, scope): 952 """Sets all the descriptor's fields's types. 953 954 This method also sets the containing types on any extensions. 955 956 Args: 957 package: The current package of desc_proto. 958 desc_proto: The message descriptor to update. 959 scope: Enclosing scope of available types. 960 """ 961 962 package = _PrefixWithDot(package) 963 964 main_desc = self._GetTypeFromScope(package, desc_proto.name, scope) 965 966 if package == '.': 967 nested_package = _PrefixWithDot(desc_proto.name) 968 else: 969 nested_package = '.'.join([package, desc_proto.name]) 970 971 for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): 972 self._SetFieldType(field_proto, field_desc, nested_package, scope) 973 974 for extension_proto, extension_desc in ( 975 zip(desc_proto.extension, main_desc.extensions)): 976 extension_desc.containing_type = self._GetTypeFromScope( 977 nested_package, extension_proto.extendee, scope) 978 self._SetFieldType(extension_proto, extension_desc, nested_package, scope) 979 980 for nested_type in desc_proto.nested_type: 981 self._SetAllFieldTypes(nested_package, nested_type, scope) 982 983 def _SetFieldType(self, field_proto, field_desc, package, scope): 984 """Sets the field's type, cpp_type, message_type and enum_type. 
985 986 Args: 987 field_proto: Data about the field in proto format. 988 field_desc: The descriptor to modiy. 989 package: The package the field's container is in. 990 scope: Enclosing scope of available types. 991 """ 992 if field_proto.type_name: 993 desc = self._GetTypeFromScope(package, field_proto.type_name, scope) 994 else: 995 desc = None 996 997 if not field_proto.HasField('type'): 998 if isinstance(desc, descriptor.Descriptor): 999 field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE 1000 else: 1001 field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM 1002 1003 field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( 1004 field_proto.type) 1005 1006 if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE 1007 or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): 1008 field_desc.message_type = desc 1009 1010 if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1011 field_desc.enum_type = desc 1012 1013 if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: 1014 field_desc.has_default_value = False 1015 field_desc.default_value = [] 1016 elif field_proto.HasField('default_value'): 1017 field_desc.has_default_value = True 1018 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1019 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1020 field_desc.default_value = float(field_proto.default_value) 1021 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1022 field_desc.default_value = field_proto.default_value 1023 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1024 field_desc.default_value = field_proto.default_value.lower() == 'true' 1025 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1026 field_desc.default_value = field_desc.enum_type.values_by_name[ 1027 field_proto.default_value].number 1028 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1029 field_desc.default_value = text_encoding.CUnescape( 1030 
field_proto.default_value) 1031 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1032 field_desc.default_value = None 1033 else: 1034 # All other types are of the "int" type. 1035 field_desc.default_value = int(field_proto.default_value) 1036 else: 1037 field_desc.has_default_value = False 1038 if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or 1039 field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): 1040 field_desc.default_value = 0.0 1041 elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: 1042 field_desc.default_value = u'' 1043 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: 1044 field_desc.default_value = False 1045 elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: 1046 field_desc.default_value = field_desc.enum_type.values[0].number 1047 elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: 1048 field_desc.default_value = b'' 1049 elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: 1050 field_desc.default_value = None 1051 else: 1052 # All other types are of the "int" type. 1053 field_desc.default_value = 0 1054 1055 field_desc.type = field_proto.type 1056 1057 def _MakeEnumValueDescriptor(self, value_proto, index): 1058 """Creates a enum value descriptor object from a enum value proto. 1059 1060 Args: 1061 value_proto: The proto describing the enum value. 1062 index: The index of the enum value. 1063 1064 Returns: 1065 An initialized EnumValueDescriptor object. 1066 """ 1067 1068 return descriptor.EnumValueDescriptor( 1069 name=value_proto.name, 1070 index=index, 1071 number=value_proto.number, 1072 options=_OptionsOrNone(value_proto), 1073 type=None) 1074 1075 def _MakeServiceDescriptor(self, service_proto, service_index, scope, 1076 package, file_desc): 1077 """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. 1078 1079 Args: 1080 service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. 
1081 service_index: The index of the service in the File. 1082 scope: Dict mapping short and full symbols to message and enum types. 1083 package: Optional package name for the new message EnumDescriptor. 1084 file_desc: The file containing the service descriptor. 1085 1086 Returns: 1087 The added descriptor. 1088 """ 1089 1090 if package: 1091 service_name = '.'.join((package, service_proto.name)) 1092 else: 1093 service_name = service_proto.name 1094 1095 methods = [self._MakeMethodDescriptor(method_proto, service_name, package, 1096 scope, index) 1097 for index, method_proto in enumerate(service_proto.method)] 1098 desc = descriptor.ServiceDescriptor(name=service_proto.name, 1099 full_name=service_name, 1100 index=service_index, 1101 methods=methods, 1102 options=_OptionsOrNone(service_proto), 1103 file=file_desc) 1104 self._CheckConflictRegister(desc, desc.full_name, desc.file.name) 1105 self._service_descriptors[service_name] = desc 1106 return desc 1107 1108 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, 1109 index): 1110 """Creates a method descriptor from a MethodDescriptorProto. 1111 1112 Args: 1113 method_proto: The proto describing the method. 1114 service_name: The name of the containing service. 1115 package: Optional package name to look up for types. 1116 scope: Scope containing available types. 1117 index: Index of the method in the service. 1118 1119 Returns: 1120 An initialized MethodDescriptor object. 
1121 """ 1122 full_name = '.'.join((service_name, method_proto.name)) 1123 input_type = self._GetTypeFromScope( 1124 package, method_proto.input_type, scope) 1125 output_type = self._GetTypeFromScope( 1126 package, method_proto.output_type, scope) 1127 return descriptor.MethodDescriptor(name=method_proto.name, 1128 full_name=full_name, 1129 index=index, 1130 containing_service=None, 1131 input_type=input_type, 1132 output_type=output_type, 1133 options=_OptionsOrNone(method_proto)) 1134 1135 def _ExtractSymbols(self, descriptors): 1136 """Pulls out all the symbols from descriptor protos. 1137 1138 Args: 1139 descriptors: The messages to extract descriptors from. 1140 Yields: 1141 A two element tuple of the type name and descriptor object. 1142 """ 1143 1144 for desc in descriptors: 1145 yield (_PrefixWithDot(desc.full_name), desc) 1146 for symbol in self._ExtractSymbols(desc.nested_types): 1147 yield symbol 1148 for enum in desc.enum_types: 1149 yield (_PrefixWithDot(enum.full_name), enum) 1150 1151 def _GetDeps(self, dependencies): 1152 """Recursively finds dependencies for file protos. 1153 1154 Args: 1155 dependencies: The names of the files being depended on. 1156 1157 Yields: 1158 Each direct and indirect dependency. 1159 """ 1160 1161 for dependency in dependencies: 1162 dep_desc = self.FindFileByName(dependency) 1163 yield dep_desc 1164 for parent_dep in dep_desc.dependencies: 1165 yield parent_dep 1166 1167 def _GetTypeFromScope(self, package, type_name, scope): 1168 """Finds a given type name in the current scope. 1169 1170 Args: 1171 package: The package the proto should be located in. 1172 type_name: The name of the type to be found in the scope. 1173 scope: Dict mapping short and full symbols to message and enum types. 1174 1175 Returns: 1176 The descriptor for the requested type. 
1177 """ 1178 if type_name not in scope: 1179 components = _PrefixWithDot(package).split('.') 1180 while components: 1181 possible_match = '.'.join(components + [type_name]) 1182 if possible_match in scope: 1183 type_name = possible_match 1184 break 1185 else: 1186 components.pop(-1) 1187 return scope[type_name] 1188 1189 1190def _PrefixWithDot(name): 1191 return name if name.startswith('.') else '.%s' % name 1192 1193 1194if _USE_C_DESCRIPTORS: 1195 # TODO(amauryfa): This pool could be constructed from Python code, when we 1196 # support a flag like 'use_cpp_generated_pool=True'. 1197 # pylint: disable=protected-access 1198 _DEFAULT = descriptor._message.default_pool 1199else: 1200 _DEFAULT = DescriptorPool() 1201 1202 1203def Default(): 1204 return _DEFAULT 1205