• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""A database of Python protocol buffer generated symbols.
9
10SymbolDatabase is the MessageFactory for messages generated at compile time,
11and makes it easy to create new instances of a registered type, given only the
12type's protocol buffer symbol name.
13
14Example usage::
15
16  db = symbol_database.SymbolDatabase()
17
18  # Register symbols of interest, from one or multiple files.
19  db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR)
20  db.RegisterMessage(my_proto_pb2.MyMessage)
21  db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR)
22
23  # The database can be used as a MessageFactory, to generate types based on
24  # their name:
25  types = db.GetMessages(['my_proto.proto'])
26  my_message_instance = types['MyMessage']()
27
28  # The database's underlying descriptor pool can be queried, so it's not
29  # necessary to know a type's filename to be able to generate it:
30  filename = db.pool.FindFileContainingSymbol('MyMessage')
31  my_message_instance = db.GetMessages([filename])['MyMessage']()
32
33  # This functionality is also provided directly via a convenience method:
34  my_message_instance = db.GetSymbol('MyMessage')()
35"""
36
37import warnings
38
39from google.protobuf.internal import api_implementation
40from google.protobuf import descriptor_pool
41from google.protobuf import message_factory
42
43
class SymbolDatabase():
  """A database of Python generated symbols.

  Maps message descriptors to the Python classes registered for them and
  exposes lookups by symbol name or by file name through ``self.pool``.
  """

  # Local cache of registered classes, keyed by message descriptor.
  # NOTE: this is a class attribute and is only ever mutated in place, so the
  # same dict is shared by every SymbolDatabase instance.
  _classes = {}

  def __init__(self, pool=None):
    """Initializes a new SymbolDatabase.

    Args:
      pool: The descriptor pool to use for lookups. When omitted (or falsy), a
        fresh ``descriptor_pool.DescriptorPool()`` is created.
    """
    self.pool = pool or descriptor_pool.DescriptorPool()

  def GetPrototype(self, descriptor):
    """Deprecated: emits a warning, then returns the class for a descriptor.

    Args:
      descriptor: The message descriptor to look up.

    Returns:
      The result of ``message_factory.GetMessageClass(descriptor)``.
    """
    warnings.warn('SymbolDatabase.GetPrototype() is deprecated. Please '
                  'use message_factory.GetMessageClass() instead. '
                  'SymbolDatabase.GetPrototype() will be removed soon.')
    return message_factory.GetMessageClass(descriptor)

  def CreatePrototype(self, descriptor):
    """Deprecated: emits a warning, then builds a class for a descriptor.

    Args:
      descriptor: The message descriptor to build a class for.

    Returns:
      The result of ``message_factory._InternalCreateMessageClass(descriptor)``.
    """
    warnings.warn('Directly call CreatePrototype() is wrong. Please use '
                  'message_factory.GetMessageClass() instead. '
                  'SymbolDatabase.CreatePrototype() will be removed soon.')
    return message_factory._InternalCreateMessageClass(descriptor)

  # NOTE: an earlier, deprecation-warning variant of GetMessages() used to be
  # defined at this point. It was shadowed by the GetMessages() definition at
  # the end of the class and therefore could never run, so it was removed.

  def RegisterMessage(self, message):
    """Registers the given message type in the local database.

    Calls to GetSymbol() and GetMessages() will return messages registered here.

    Args:
      message: A :class:`google.protobuf.message.Message` subclass (or
        instance); its descriptor will be registered.

    Returns:
      The provided message.
    """

    desc = message.DESCRIPTOR
    self._classes[desc] = message
    self.RegisterMessageDescriptor(desc)
    return message

  def RegisterMessageDescriptor(self, message_descriptor):
    """Registers the given message descriptor in the local database.

    The descriptor is only added to ``self.pool`` when
    ``api_implementation.Type()`` is ``'python'``; otherwise this is a no-op.

    Args:
      message_descriptor (Descriptor): the message descriptor to add.
    """
    if api_implementation.Type() == 'python':
      # pylint: disable=protected-access
      self.pool._AddDescriptor(message_descriptor)

  def RegisterEnumDescriptor(self, enum_descriptor):
    """Registers the given enum descriptor in the local database.

    The descriptor is only added to ``self.pool`` when
    ``api_implementation.Type()`` is ``'python'``.

    Args:
      enum_descriptor (EnumDescriptor): The enum descriptor to register.

    Returns:
      EnumDescriptor: The provided descriptor.
    """
    if api_implementation.Type() == 'python':
      # pylint: disable=protected-access
      self.pool._AddEnumDescriptor(enum_descriptor)
    return enum_descriptor

  def RegisterServiceDescriptor(self, service_descriptor):
    """Registers the given service descriptor in the local database.

    The descriptor is only added to ``self.pool`` when
    ``api_implementation.Type()`` is ``'python'``; otherwise this is a no-op.

    Args:
      service_descriptor (ServiceDescriptor): the service descriptor to
        register.
    """
    if api_implementation.Type() == 'python':
      # pylint: disable=protected-access
      self.pool._AddServiceDescriptor(service_descriptor)

  def RegisterFileDescriptor(self, file_descriptor):
    """Registers the given file descriptor in the local database.

    The descriptor is only added to ``self.pool`` when
    ``api_implementation.Type()`` is ``'python'``; otherwise this is a no-op.

    Args:
      file_descriptor (FileDescriptor): The file descriptor to register.
    """
    if api_implementation.Type() == 'python':
      # pylint: disable=protected-access
      self.pool._InternalAddFileDescriptor(file_descriptor)

  def GetSymbol(self, symbol):
    """Tries to find a symbol in the local database.

    Currently, this method only returns registered message classes; however,
    it may be extended in future to support other symbol types.

    Args:
      symbol (str): a protocol buffer symbol.

    Returns:
      A Python class corresponding to the symbol.

    Raises:
      KeyError: if the symbol could not be found.
    """

    # Resolve the name to a descriptor through the pool, then map the
    # descriptor to whatever was stored for it by RegisterMessage().
    return self._classes[self.pool.FindMessageTypeByName(symbol)]

  def GetMessages(self, files):
    """Gets all registered messages from a specified file.

    Only messages already created and registered will be returned; (this is the
    case for imported _pb2 modules)
    But unlike MessageFactory, this version also returns already defined nested
    messages, but does not register any message extensions.

    Args:
      files (list[str]): The file names to extract messages from.

    Returns:
      A dictionary mapping proto names to the message classes.

    Raises:
      KeyError: if a file could not be found.
    """
    # TODO: Fix the differences with MessageFactory.

    def _GetAllMessages(desc):
      """Yields a message Descriptor and, recursively, all nested ones."""
      yield desc
      for msg_desc in desc.nested_types:
        for nested_desc in _GetAllMessages(msg_desc):
          yield nested_desc

    result = {}
    for file_name in files:
      file_desc = self.pool.FindFileByName(file_name)
      for msg_desc in file_desc.message_types_by_name.values():
        for desc in _GetAllMessages(msg_desc):
          try:
            result[desc.full_name] = self._classes[desc]
          except KeyError:
            # This descriptor has no registered class, skip it.
            pass
    return result
190
191
# Default database, created once at import time on top of
# descriptor_pool.Default(); handed out by Default() below.
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())
193
194
def Default():
  """Gives access to the shared, module-level SymbolDatabase.

  Returns:
    The ``_DEFAULT`` SymbolDatabase instance created when this module was
    imported; every call returns that same object.
  """
  return _DEFAULT
198