#!/usr/bin/env python
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Stub library."""
import six

__author__ = 'rafek@google.com (Rafe Kaplan)'

import sys
import types

from . import descriptor
from . import message_types
from . import messages
from . import protobuf
from . import remote
from . import util

__all__ = [
    'define_enum',
    'define_field',
    'define_file',
    'define_message',
    'define_service',
    'import_file',
    'import_file_set',
]


def _to_native_str(name):
    """Coerce a name to the interpreter's native ``str`` type.

    Class, function and module names must be native strings: byte strings
    on Python 2 and text strings on Python 3.  Descriptor names may arrive
    as either text or bytes, so they are normalized here before being
    handed to ``type()``, ``types.ModuleType`` or ``__name__``.

    Args:
      name: Text or byte string to normalize.

    Returns:
      'name' as a native ``str``, using UTF-8 for any conversion needed.
    """
    if isinstance(name, str):
        return name
    if isinstance(name, bytes):
        # Only reachable on Python 3; on Python 2 bytes is str.
        return name.decode('utf-8')
    # Python 2 unicode object.
    return name.encode('utf-8')


# Map variant back to message field classes.
def _build_variant_map():
    """Map variants to fields.

    Returns:
      Dictionary mapping field variant to its associated field type.
    """
    result = {}
    for name in dir(messages):
        value = getattr(messages, name)
        if isinstance(value, type) and issubclass(value, messages.Field):
            for variant in getattr(value, 'VARIANTS', []):
                result[variant] = value
    return result


_VARIANT_MAP = _build_variant_map()

# Message definition names that are represented by a dedicated field class
# rather than a generic MessageField.
_MESSAGE_TYPE_MAP = {
    message_types.DateTimeMessage.definition_name():
        message_types.DateTimeField,
}


def _get_or_define_module(full_name, modules):
    """Helper method for defining new modules.

    Args:
      full_name: Fully qualified name of module to create or return.
      modules: Dictionary of all modules, keyed by fully qualified name.
        New modules are inserted in to this dictionary.

    Returns:
      Named module if found in 'modules', else creates new module and
      inserts in 'modules'.  Will also construct parent modules if
      necessary.
    """
    module = modules.get(full_name)
    if module is None:
        module = types.ModuleType(full_name)
        modules[full_name] = module

        # Attach the new module to its (possibly newly created) parent so
        # that attribute access such as 'parent.child' works.
        split_name = full_name.rsplit('.', 1)
        if len(split_name) > 1:
            parent_module_name, sub_module_name = split_name
            parent_module = _get_or_define_module(parent_module_name,
                                                  modules)
            setattr(parent_module, sub_module_name, module)

    return module


def define_enum(enum_descriptor, module_name):
    """Define Enum class from descriptor.

    Args:
      enum_descriptor: EnumDescriptor to build Enum class from.
      module_name: Module name to give new descriptor class.

    Returns:
      New messages.Enum sub-class as described by enum_descriptor.
    """
    enum_values = enum_descriptor.values or []

    class_dict = {value.name: value.number for value in enum_values}
    class_dict['__module__'] = module_name
    class_name = _to_native_str(enum_descriptor.name)
    return type(class_name, (messages.Enum,), class_dict)


def define_field(field_descriptor):
    """Define Field instance from descriptor.

    Args:
      field_descriptor: FieldDescriptor class to build field instance from.

    Returns:
      New field instance as described by field_descriptor.

    Raises:
      KeyError: If the descriptor's variant has no associated field class.
    """
    field_class = _VARIANT_MAP[field_descriptor.variant]
    params = {
        'number': field_descriptor.number,
        'variant': field_descriptor.variant,
    }

    label = field_descriptor.label
    if label == descriptor.FieldDescriptor.Label.REQUIRED:
        params['required'] = True
    elif label == descriptor.FieldDescriptor.Label.REPEATED:
        params['repeated'] = True

    # Well-known message types get their dedicated field class.
    message_type_field = _MESSAGE_TYPE_MAP.get(field_descriptor.type_name)
    if message_type_field:
        return message_type_field(**params)

    # Enum and message fields take the referenced type name first.
    if field_class in (messages.EnumField, messages.MessageField):
        return field_class(field_descriptor.type_name, **params)

    # NOTE(review): an empty-string default_value is treated as "no
    # default" by this truthiness check; confirm that is intended.
    if field_descriptor.default_value:
        value = field_descriptor.default_value
        try:
            value = descriptor._DEFAULT_FROM_STRING_MAP[field_class](value)
        except (TypeError, ValueError, KeyError):
            pass  # Let the value pass to the constructor.
        params['default'] = value
    return field_class(**params)


def define_message(message_descriptor, module_name):
    """Define Message class from descriptor.

    Args:
      message_descriptor: MessageDescriptor to describe message class from.
      module_name: Module name to give to new descriptor class.

    Returns:
      New messages.Message sub-class as described by message_descriptor.
    """
    class_dict = {'__module__': module_name}

    for enum in message_descriptor.enum_types or []:
        enum_instance = define_enum(enum, module_name)
        class_dict[enum.name] = enum_instance

    # TODO(rafek): support nested messages when supported by descriptor.

    for field in message_descriptor.fields or []:
        field_instance = define_field(field)
        class_dict[field.name] = field_instance

    # type() requires a native str name on both Python 2 and Python 3.
    class_name = _to_native_str(message_descriptor.name)
    return type(class_name, (messages.Message,), class_dict)


def define_service(service_descriptor, module):
    """Define a new service proxy.

    Args:
      service_descriptor: ServiceDescriptor class that describes the
        service.
      module: Module to add service to.  Request and response types are
        found relative to this module.

    Returns:
      Service class proxy capable of communicating with a remote server.
    """
    class_dict = {'__module__': module.__name__}
    # type() and __name__ require native str names on Python 2 and 3.
    class_name = _to_native_str(service_descriptor.name)

    for method_descriptor in service_descriptor.methods or []:
        request_definition = messages.find_definition(
            method_descriptor.request_type, module)
        response_definition = messages.find_definition(
            method_descriptor.response_type, module)

        method_name = _to_native_str(method_descriptor.name)

        # A fresh function object is needed per method because each one
        # is renamed and decorated independently.
        def remote_method(self, request):
            """Actual service method."""
            raise NotImplementedError('Method is not implemented')
        remote_method.__name__ = method_name
        remote_method_decorator = remote.method(request_definition,
                                                response_definition)

        class_dict[method_name] = remote_method_decorator(remote_method)

    service_class = type(class_name, (remote.Service,), class_dict)
    return service_class


def define_file(file_descriptor, module=None):
    """Define module from FileDescriptor.

    Args:
      file_descriptor: FileDescriptor instance to describe module from.
      module: Module to add contained objects to.  Module name overrides
        value in file_descriptor.package.  Definitions are added to
        existing module if provided.

    Returns:
      If no module provided, will create a new module with its name set to
      the file descriptor's package.  If a module is provided, returns the
      same module.
    """
    if module is None:
        module = types.ModuleType(file_descriptor.package)

    for enum_descriptor in file_descriptor.enum_types or []:
        enum_class = define_enum(enum_descriptor, module.__name__)
        setattr(module, enum_descriptor.name, enum_class)

    for message_descriptor in file_descriptor.message_types or []:
        message_class = define_message(message_descriptor, module.__name__)
        setattr(module, message_descriptor.name, message_class)

    # Services are defined last because their request and response types
    # are looked up relative to 'module'.
    for service_descriptor in file_descriptor.service_types or []:
        service_class = define_service(service_descriptor, module)
        setattr(module, service_descriptor.name, service_class)

    return module


@util.positional(1)
def import_file(file_descriptor, modules=None):
    """Import FileDescriptor in to module space.

    This is like define_file except that a new module and any required
    parent modules are created and added to the modules parameter or
    sys.modules if not provided.

    Args:
      file_descriptor: FileDescriptor instance to describe module from.
      modules: Dictionary of modules to update.  Modules and their parents
        that do not exist will be created.  If an existing module is found
        that matches file_descriptor.package, that module is updated with
        the FileDescriptor contents.

    Returns:
      Module found in modules, else a new module.

    Raises:
      ValueError: If file_descriptor has no package name.
    """
    if not file_descriptor.package:
        raise ValueError('File descriptor must have package name')

    if modules is None:
        modules = sys.modules

    # Module names must be native str so that they can be split on '.' and
    # passed to types.ModuleType on both Python 2 and Python 3.
    module = _get_or_define_module(
        _to_native_str(file_descriptor.package), modules)

    return define_file(file_descriptor, module)


@util.positional(1)
def import_file_set(file_set, modules=None, _open=open):
    """Import FileSet in to module space.

    Args:
      file_set: If string, open file and read serialized FileSet.
        Otherwise, a FileSet instance to import definitions from.
      modules: Dictionary of modules to update.  Modules and their parents
        that do not exist will be created.  If an existing module is found
        that matches file_descriptor.package, that module is updated with
        the FileDescriptor contents.
      _open: Used for dependency injection during tests.

    Raises:
      ValueError: If a contained file descriptor has no package name.
    """
    if isinstance(file_set, six.string_types):
        # Explicit try/finally rather than 'with' so that injected _open
        # replacements only need to provide read() and close().
        encoded_file = _open(file_set, 'rb')
        try:
            encoded_file_set = encoded_file.read()
        finally:
            encoded_file.close()

        file_set = protobuf.decode_message(descriptor.FileSet,
                                           encoded_file_set)

    for file_descriptor in file_set.files or []:
        # Do not reload built in protorpc classes.  A missing package is
        # passed through so that import_file reports it as a ValueError.
        package = file_descriptor.package or ''
        if not package.startswith('protorpc.'):
            import_file(file_descriptor, modules=modules)