# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.  All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""Provides a factory class for generating dynamic messages.

If you have access to the FileDescriptor protos containing the messages you
want to create, the easiest way to use this module is:

message_classes = message_factory.GetMessages(iterable_of_file_descriptors)
my_proto_instance = message_classes['some.proto.package.MessageName']()
"""

__author__ = 'matthewtoia@google.com (Matt Toia)'

import warnings

from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf.internal import api_implementation

if api_implementation.Type() == 'python':
  from google.protobuf.internal import python_message as message_impl
else:
  from google.protobuf.pyext import cpp_message as message_impl  # pylint: disable=g-import-not-at-top


# The type of all Message classes.
_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType


def GetMessageClass(descriptor):
  """Obtains a proto2 message class based on the passed in descriptor.

  If the descriptor already carries a `_concrete_class` attribute, that class
  is returned as-is; otherwise a new class is built for the descriptor.
  Passing a descriptor with a fully qualified name matching a previous
  invocation will therefore cause the same class to be returned.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.
  """
  # NOTE(review): `_concrete_class` is presumably populated by the message
  # metaclass when a class is created for the descriptor; it is not set
  # anywhere in this module.
  concrete_class = getattr(descriptor, '_concrete_class', None)
  if concrete_class:
    return concrete_class
  return _InternalCreateMessageClass(descriptor)


def GetMessageClassesForFiles(files, pool):
  """Gets all the messages from specified files.

  This will find and resolve dependencies, failing if the descriptor
  pool cannot satisfy them.

  Args:
    files: The file names to extract messages from.
    pool: The descriptor pool to find the files including the dependent files.

  Returns:
    A dictionary mapping proto names to the message classes.

  Raises:
    ValueError: If, on a non pure-Python implementation, an extension declared
      in one of the files is not the same object the pool has registered for
      that containing type and field number.
  """
  result = {}
  for file_name in files:
    file_desc = pool.FindFileByName(file_name)
    for desc in file_desc.message_types_by_name.values():
      result[desc.full_name] = GetMessageClass(desc)

    # While the extension FieldDescriptors are created by the descriptor pool,
    # the python classes created in the factory need them to be registered
    # explicitly, which is done below.
    #
    # The call to RegisterExtension will specifically check if the
    # extension was already registered on the object and either
    # ignore the registration if the original was the same, or raise
    # an error if they were different.
    #
    # NOTE(review): no RegisterExtension call is made in this loop; the
    # comment above appears to predate the current code. Confirm and update.

    for extension in file_desc.extensions_by_name.values():
      # Called for its side effect: make sure the extended message's class
      # exists. The returned class itself is not needed here.
      GetMessageClass(extension.containing_type)
      if api_implementation.Type() != 'python':
        # TODO: Remove this check here. Duplicate extension
        # register check should be in descriptor_pool.
        if extension is not pool.FindExtensionByNumber(
            extension.containing_type, extension.number
        ):
          raise ValueError('Double registration of Extensions')
      # Recursively load protos for extension field, in order to be able to
      # fully represent the extension. This matches the behavior for regular
      # fields too.
      if extension.message_type:
        GetMessageClass(extension.message_type)
  return result


def _InternalCreateMessageClass(descriptor):
  """Builds a proto2 message class based on the passed in descriptor.

  Unlike GetMessageClass(), this always constructs a new class. Classes for
  message-typed fields and for the extensions declared on the descriptor are
  resolved through GetMessageClass() as part of the build.

  Args:
    descriptor: The descriptor to build from.

  Returns:
    A class describing the passed in descriptor.

  Raises:
    ValueError: If, on a non pure-Python implementation, an extension declared
      on the descriptor is not the same object its containing type's pool has
      registered for that containing type and field number.
  """
  descriptor_name = descriptor.name
  result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
      descriptor_name,
      (message.Message,),
      {
          'DESCRIPTOR': descriptor,
          # If module not set, it wrongly points to message_factory module.
          '__module__': None,
      },
  )
  # Make sure a class exists for every message-typed field.
  for field in descriptor.fields:
    if field.message_type:
      GetMessageClass(field.message_type)

  for extension in result_class.DESCRIPTOR.extensions:
    # Called for its side effect: make sure the extended message's class
    # exists. The returned class itself is not needed here.
    GetMessageClass(extension.containing_type)
    if api_implementation.Type() != 'python':
      # TODO: Remove this check here. Duplicate extension
      # register check should be in descriptor_pool.
      pool = extension.containing_type.file.pool
      if extension is not pool.FindExtensionByNumber(
          extension.containing_type, extension.number
      ):
        raise ValueError('Double registration of Extensions')
    if extension.message_type:
      GetMessageClass(extension.message_type)
  return result_class


# Deprecated. Please use GetMessageClass() or GetMessageClassesForFiles()
# method above instead.
class MessageFactory(object):
  """Factory for creating Proto2 messages from descriptors in a pool.

  Deprecated: every method emits a warning and delegates to the module-level
  functions GetMessageClass() / GetMessageClassesForFiles().
  """

  def __init__(self, pool=None):
    """Initializes a new factory.

    Args:
      pool: The descriptor pool to use. A new DescriptorPool is created when
        none is given.
    """
    self.pool = pool or descriptor_pool.DescriptorPool()

  def GetPrototype(self, descriptor):
    """Obtains a proto2 message class based on the passed in descriptor.

    Passing a descriptor with a fully qualified name matching a previous
    invocation will cause the same class to be returned.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    warnings.warn(
        'MessageFactory class is deprecated. Please use '
        'GetMessageClass() instead of MessageFactory.GetPrototype. '
        'MessageFactory class will be removed after 2024.',
        stacklevel=2,
    )
    return GetMessageClass(descriptor)

  def CreatePrototype(self, descriptor):
    """Builds a proto2 message class based on the passed in descriptor.

    Don't call this function directly, it always creates a new class. Call
    GetMessageClass() instead.

    Args:
      descriptor: The descriptor to build from.

    Returns:
      A class describing the passed in descriptor.
    """
    warnings.warn(
        'Directly call CreatePrototype is wrong. Please use '
        'GetMessageClass() method instead. Directly use '
        'CreatePrototype will raise error after July 2023.',
        stacklevel=2,
    )
    return _InternalCreateMessageClass(descriptor)

  def GetMessages(self, files):
    """Gets all the messages from a specified file.

    This will find and resolve dependencies, failing if the descriptor
    pool cannot satisfy them.

    Args:
      files: The file names to extract messages from.

    Returns:
      A dictionary mapping proto names to the message classes. This will include
      any dependent messages as well as any messages defined in the same file as
      a specified message.
    """
    warnings.warn(
        'MessageFactory class is deprecated. Please use '
        'GetMessageClassesForFiles() instead of '
        'MessageFactory.GetMessages(). MessageFactory class '
        'will be removed after 2024.',
        stacklevel=2,
    )
    return GetMessageClassesForFiles(files, self.pool)


def GetMessages(file_protos, pool=None):
  """Builds a dictionary of all the messages available in a set of files.

  Args:
    file_protos: Iterable of FileDescriptorProto to build messages out of. It
      may be a one-shot iterable such as a generator; it is consumed exactly
      once.
    pool: The descriptor pool to add the file protos. A new DescriptorPool is
      created when none is given.

  Returns:
    A dictionary mapping proto names to the message classes. This will include
    any dependent messages as well as any messages defined in the same file as
    a specified message.
  """
  # The cpp implementation of the protocol buffer library requires to add the
  # message in topological order of the dependency graph.
  des_pool = pool or descriptor_pool.DescriptorPool()
  # Materialize once: the protos are needed both to index the files by name
  # and, after that index has been drained, to list the requested file names.
  # Iterating a generator twice would silently yield an empty result.
  file_protos = list(file_protos)
  file_by_name = {file_proto.name: file_proto for file_proto in file_protos}

  def _AddFile(file_proto):
    """Adds file_proto to the pool after its not-yet-added dependencies."""
    for dependency in file_proto.dependency:
      if dependency in file_by_name:
        # Remove from elements to be visited, in order to cut cycles.
        _AddFile(file_by_name.pop(dependency))
    des_pool.Add(file_proto)

  while file_by_name:
    _AddFile(file_by_name.popitem()[1])
  return GetMessageClassesForFiles(
      [file_proto.name for file_proto in file_protos], des_pool
  )