• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3#
4# Use of this source code is governed by a BSD-style
5# license that can be found in the LICENSE file or at
6# https://developers.google.com/open-source/licenses/bsd
7
8"""Provides a factory class for generating dynamic messages.
9
The easiest way to use this module is when you have the FileDescriptorProto
messages that define the messages you want to create; pass them to GetMessages():

  message_classes = message_factory.GetMessages(iterable_of_file_protos)
  my_proto_instance = message_classes['some.proto.package.MessageName']()
15"""
16
17__author__ = 'matthewtoia@google.com (Matt Toia)'
18
19import warnings
20
21from google.protobuf import descriptor_pool
22from google.protobuf import message
23from google.protobuf.internal import api_implementation
24
25if api_implementation.Type() == 'python':
26  from google.protobuf.internal import python_message as message_impl
27else:
28  from google.protobuf.pyext import cpp_message as message_impl  # pylint: disable=g-import-not-at-top
29
30
# The type of all Message classes: the class factory (called with a name,
# bases and namespace dict) used to build every message class in this module.
# It is GeneratedProtocolMessageType from whichever backend module
# (python_message or cpp_message) was imported as message_impl above.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
33
34
def GetMessageClass(descriptor):
  """Returns the proto2 message class for a descriptor, building it if needed.

  If the descriptor already carries a truthy ``_concrete_class`` attribute
  (presumably set when its class was first built), that exact class object is
  returned, so repeated calls for the same descriptor yield the same class.
  Otherwise a new class is built from the descriptor.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.
  """
  cached_class = getattr(descriptor, '_concrete_class', None)
  if not cached_class:
    return _InternalCreateMessageClass(descriptor)
  return cached_class
51
52
def GetMessageClassesForFiles(files, pool):
  """Returns the message classes for the messages defined in the named files.

  Each file is looked up in ``pool``, which must be able to resolve the file
  and its dependencies. Classes are obtained through GetMessageClass(), so a
  class that was already built is reused.

  Args:
    files: The file names to extract messages from.
    pool: The descriptor pool to find the files including the dependent files.

  Returns:
    A dictionary mapping the full name of every message listed in each file's
    ``message_types_by_name`` to its message class.

  Raises:
    ValueError: With a non-pure-Python backend, if a file's extension is not
      the one the pool returns for that containing type and field number.
  """
  classes_by_name = {}
  for file_name in files:
    file_descriptor = pool.FindFileByName(file_name)
    classes_by_name.update(
        (message_descriptor.full_name, GetMessageClass(message_descriptor))
        for message_descriptor in file_descriptor.message_types_by_name.values()
    )

    # Extension FieldDescriptors are created by the descriptor pool; make sure
    # the classes involved in each of this file's extensions exist as well.
    for extension in file_descriptor.extensions_by_name.values():
      # Only the side effect matters: the extended message's class must exist.
      GetMessageClass(extension.containing_type)
      if api_implementation.Type() != 'python':
        # TODO: Remove this check here. Duplicate extension
        # register check should be in descriptor_pool.
        registered = pool.FindExtensionByNumber(
            extension.containing_type, extension.number
        )
        if registered is not extension:
          raise ValueError('Double registration of Extensions')
      # Recursively load protos for the extension field so that it can be
      # fully represented, matching the behavior for regular fields.
      if extension.message_type:
        GetMessageClass(extension.message_type)
  return classes_by_name
96
97
def _InternalCreateMessageClass(descriptor):
  """Builds a new proto2 message class based on the passed in descriptor.

  This always creates a fresh class; callers should normally use
  GetMessageClass() so an already-built class is reused. After building the
  class, it also obtains (via GetMessageClass()) the classes of every
  message-typed field and, for each extension in ``DESCRIPTOR.extensions``,
  the extended message's class and, if message-typed, the extension's value
  class.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.

  Raises:
    ValueError: With a non-pure-Python backend, if an extension is not the one
      its file's pool returns for that containing type and field number.
  """
  result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
      descriptor.name,
      (message.Message,),
      {
          'DESCRIPTOR': descriptor,
          # If module not set, it wrongly points to message_factory module.
          '__module__': None,
      },
  )

  # Ensure classes exist for all message-typed fields.
  for field in descriptor.fields:
    if field.message_type:
      GetMessageClass(field.message_type)

  for extension in result_class.DESCRIPTOR.extensions:
    # The returned class is not needed here; the call only ensures that the
    # extended message's class exists. (The old code bound it to an unused
    # local.)
    GetMessageClass(extension.containing_type)
    if api_implementation.Type() != 'python':
      # TODO: Remove this check here. Duplicate extension
      # register check should be in descriptor_pool.
      pool = extension.containing_type.file.pool
      if extension is not pool.FindExtensionByNumber(
          extension.containing_type, extension.number
      ):
        raise ValueError('Double registration of Extensions')
    if extension.message_type:
      GetMessageClass(extension.message_type)
  return result_class
134
135
# Deprecated: use the module-level GetMessageClass() or
# GetMessageClassesForFiles() functions instead of this class.
class MessageFactory(object):
  """Factory for creating Proto2 messages from descriptors in a pool.

  Deprecated. Apart from the constructor, each method emits a warning and then
  forwards to a module-level function of this file.
  """

  def __init__(self, pool=None):
    """Initializes a new factory.

    Args:
      pool: The descriptor pool used by GetMessages(). When it is not given
        (or is falsy), a new descriptor_pool.DescriptorPool() is created.
    """
    if not pool:
      pool = descriptor_pool.DescriptorPool()
    self.pool = pool

  def GetPrototype(self, descriptor):
    """Deprecated alias of GetMessageClass().

    Emits a warning, then returns GetMessageClass(descriptor), so a descriptor
    whose class was already built yields that same class.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    deprecation_message = (
        'MessageFactory class is deprecated. Please use '
        'GetMessageClass() instead of MessageFactory.GetPrototype. '
        'MessageFactory class will be removed after 2024.'
    )
    warnings.warn(deprecation_message, stacklevel=2)
    return GetMessageClass(descriptor)

  def CreatePrototype(self, descriptor):
    """Builds a brand-new proto2 message class for the passed in descriptor.

    Do not call this directly: it always creates a new class instead of
    reusing an existing one. Use GetMessageClass() instead.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    misuse_message = (
        'Directly call CreatePrototype is wrong. Please use '
        'GetMessageClass() method instead. Directly use '
        'CreatePrototype will raise error after July 2023.'
    )
    warnings.warn(misuse_message, stacklevel=2)
    return _InternalCreateMessageClass(descriptor)

  def GetMessages(self, files):
    """Deprecated alias of GetMessageClassesForFiles() using this factory's pool.

    The files and their dependencies must be resolvable from ``self.pool``.

    Args:
      files: The file names to extract messages from.

    Returns:
      The dictionary returned by GetMessageClassesForFiles(files, self.pool),
      mapping message full names to message classes.
    """
    deprecation_message = (
        'MessageFactory class is deprecated. Please use '
        'GetMessageClassesForFiles() instead of '
        'MessageFactory.GetMessages(). MessageFactory class '
        'will be removed after 2024.'
    )
    warnings.warn(deprecation_message, stacklevel=2)
    return GetMessageClassesForFiles(files, self.pool)
207
208
def GetMessages(file_protos, pool=None):
  """Builds a dictionary of all the messages available in a set of files.

  Args:
    file_protos: Iterable of FileDescriptorProto to build messages out of. Any
      iterable is accepted, including one-shot iterators such as generators.
    pool: The descriptor pool to add the file protos to. When it is not given
      (or is falsy), a new descriptor_pool.DescriptorPool() is used.

  Returns:
    A dictionary mapping message full names to message classes for the
    messages defined in file_protos, as produced by GetMessageClassesForFiles().
  """
  # file_protos is traversed twice below: once to index it by name, and once to
  # collect the names to look up. Materialize it first so that a one-shot
  # iterator (e.g. a generator) is not exhausted by the first pass. Previously
  # that made this function silently return an empty dict.
  file_protos = list(file_protos)
  des_pool = pool or descriptor_pool.DescriptorPool()
  file_by_name = {file_proto.name: file_proto for file_proto in file_protos}

  # The cpp implementation of the protocol buffer library requires to add the
  # message in topological order of the dependency graph, so every dependency
  # that is part of this batch is added before the file that needs it.
  def _AddFile(file_proto):
    for dependency in file_proto.dependency:
      if dependency in file_by_name:
        # Remove from elements to be visited, in order to cut cycles.
        _AddFile(file_by_name.pop(dependency))
    des_pool.Add(file_proto)

  while file_by_name:
    _AddFile(file_by_name.popitem()[1])
  return GetMessageClassesForFiles(
      [file_proto.name for file_proto in file_protos], des_pool
  )
238