• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""Stub library."""
19import six
20
21__author__ = 'rafek@google.com (Rafe Kaplan)'
22
23import sys
24import types
25
26from . import descriptor
27from . import message_types
28from . import messages
29from . import protobuf
30from . import remote
31from . import util
32
33__all__ = [
34    'define_enum',
35    'define_field',
36    'define_file',
37    'define_message',
38    'define_service',
39    'import_file',
40    'import_file_set',
41]
42
43
44# Map variant back to message field classes.
45def _build_variant_map():
46  """Map variants to fields.
47
48  Returns:
49    Dictionary mapping field variant to its associated field type.
50  """
51  result = {}
52  for name in dir(messages):
53    value = getattr(messages, name)
54    if isinstance(value, type) and issubclass(value, messages.Field):
55      for variant in getattr(value, 'VARIANTS', []):
56        result[variant] = value
57  return result
58
59_VARIANT_MAP = _build_variant_map()
60
61_MESSAGE_TYPE_MAP = {
62  message_types.DateTimeMessage.definition_name(): message_types.DateTimeField,
63}
64
65
66def _get_or_define_module(full_name, modules):
67  """Helper method for defining new modules.
68
69  Args:
70    full_name: Fully qualified name of module to create or return.
71    modules: Dictionary of all modules.  Defaults to sys.modules.
72
73  Returns:
74    Named module if found in 'modules', else creates new module and inserts in
75    'modules'.  Will also construct parent modules if necessary.
76  """
77  module = modules.get(full_name)
78  if not module:
79    module = types.ModuleType(full_name)
80    modules[full_name] = module
81
82    split_name = full_name.rsplit('.', 1)
83    if len(split_name) > 1:
84      parent_module_name, sub_module_name = split_name
85      parent_module = _get_or_define_module(parent_module_name, modules)
86      setattr(parent_module, sub_module_name, module)
87
88  return module
89
90
91def define_enum(enum_descriptor, module_name):
92  """Define Enum class from descriptor.
93
94  Args:
95    enum_descriptor: EnumDescriptor to build Enum class from.
96    module_name: Module name to give new descriptor class.
97
98  Returns:
99    New messages.Enum sub-class as described by enum_descriptor.
100  """
101  enum_values = enum_descriptor.values or []
102
103  class_dict = dict((value.name, value.number) for value in enum_values)
104  class_dict['__module__'] = module_name
105  return type(str(enum_descriptor.name), (messages.Enum,), class_dict)
106
107
108def define_field(field_descriptor):
109  """Define Field instance from descriptor.
110
111  Args:
112    field_descriptor: FieldDescriptor class to build field instance from.
113
114  Returns:
115    New field instance as described by enum_descriptor.
116  """
117  field_class = _VARIANT_MAP[field_descriptor.variant]
118  params = {'number': field_descriptor.number,
119            'variant': field_descriptor.variant,
120           }
121
122  if field_descriptor.label == descriptor.FieldDescriptor.Label.REQUIRED:
123    params['required'] = True
124  elif field_descriptor.label == descriptor.FieldDescriptor.Label.REPEATED:
125    params['repeated'] = True
126
127  message_type_field = _MESSAGE_TYPE_MAP.get(field_descriptor.type_name)
128  if message_type_field:
129    return message_type_field(**params)
130  elif field_class in (messages.EnumField, messages.MessageField):
131    return field_class(field_descriptor.type_name, **params)
132  else:
133    if field_descriptor.default_value:
134      value = field_descriptor.default_value
135      try:
136        value = descriptor._DEFAULT_FROM_STRING_MAP[field_class](value)
137      except (TypeError, ValueError, KeyError):
138        pass  # Let the value pass to the constructor.
139      params['default'] = value
140    return field_class(**params)
141
142
143def define_message(message_descriptor, module_name):
144  """Define Message class from descriptor.
145
146  Args:
147    message_descriptor: MessageDescriptor to describe message class from.
148    module_name: Module name to give to new descriptor class.
149
150  Returns:
151    New messages.Message sub-class as described by message_descriptor.
152  """
153  class_dict = {'__module__': module_name}
154
155  for enum in message_descriptor.enum_types or []:
156    enum_instance = define_enum(enum, module_name)
157    class_dict[enum.name] = enum_instance
158
159  # TODO(rafek): support nested messages when supported by descriptor.
160
161  for field in message_descriptor.fields or []:
162    field_instance = define_field(field)
163    class_dict[field.name] = field_instance
164
165  class_name = message_descriptor.name.encode('utf-8')
166  return type(class_name, (messages.Message,), class_dict)
167
168
169def define_service(service_descriptor, module):
170  """Define a new service proxy.
171
172  Args:
173    service_descriptor: ServiceDescriptor class that describes the service.
174    module: Module to add service to.  Request and response types are found
175      relative to this module.
176
177  Returns:
178    Service class proxy capable of communicating with a remote server.
179  """
180  class_dict = {'__module__': module.__name__}
181  class_name = service_descriptor.name.encode('utf-8')
182
183  for method_descriptor in service_descriptor.methods or []:
184    request_definition = messages.find_definition(
185        method_descriptor.request_type, module)
186    response_definition = messages.find_definition(
187        method_descriptor.response_type, module)
188
189    method_name = method_descriptor.name.encode('utf-8')
190    def remote_method(self, request):
191      """Actual service method."""
192      raise NotImplementedError('Method is not implemented')
193    remote_method.__name__ = method_name
194    remote_method_decorator = remote.method(request_definition,
195                                            response_definition)
196
197    class_dict[method_name] = remote_method_decorator(remote_method)
198
199  service_class = type(class_name, (remote.Service,), class_dict)
200  return service_class
201
202
203def define_file(file_descriptor, module=None):
204  """Define module from FileDescriptor.
205
206  Args:
207    file_descriptor: FileDescriptor instance to describe module from.
208    module: Module to add contained objects to.  Module name overrides value
209      in file_descriptor.package.  Definitions are added to existing
210      module if provided.
211
212  Returns:
213    If no module provided, will create a new module with its name set to the
214    file descriptor's package.  If a module is provided, returns the same
215    module.
216  """
217  if module is None:
218    module = types.ModuleType(file_descriptor.package)
219
220  for enum_descriptor in file_descriptor.enum_types or []:
221    enum_class = define_enum(enum_descriptor, module.__name__)
222    setattr(module, enum_descriptor.name, enum_class)
223
224  for message_descriptor in file_descriptor.message_types or []:
225    message_class = define_message(message_descriptor, module.__name__)
226    setattr(module, message_descriptor.name, message_class)
227
228  for service_descriptor in file_descriptor.service_types or []:
229    service_class = define_service(service_descriptor, module)
230    setattr(module, service_descriptor.name, service_class)
231
232  return module
233
234
235@util.positional(1)
236def import_file(file_descriptor, modules=None):
237  """Import FileDescriptor in to module space.
238
239  This is like define_file except that a new module and any required parent
240  modules are created and added to the modules parameter or sys.modules if not
241  provided.
242
243  Args:
244    file_descriptor: FileDescriptor instance to describe module from.
245    modules: Dictionary of modules to update.  Modules and their parents that
246      do not exist will be created.  If an existing module is found that
247      matches file_descriptor.package, that module is updated with the
248      FileDescriptor contents.
249
250  Returns:
251    Module found in modules, else a new module.
252  """
253  if not file_descriptor.package:
254    raise ValueError('File descriptor must have package name')
255
256  if modules is None:
257    modules = sys.modules
258
259  module = _get_or_define_module(file_descriptor.package.encode('utf-8'),
260                                 modules)
261
262  return define_file(file_descriptor, module)
263
264
265@util.positional(1)
266def import_file_set(file_set, modules=None, _open=open):
267  """Import FileSet in to module space.
268
269  Args:
270    file_set: If string, open file and read serialized FileSet.  Otherwise,
271      a FileSet instance to import definitions from.
272    modules: Dictionary of modules to update.  Modules and their parents that
273      do not exist will be created.  If an existing module is found that
274      matches file_descriptor.package, that module is updated with the
275      FileDescriptor contents.
276    _open: Used for dependency injection during tests.
277  """
278  if isinstance(file_set, six.string_types):
279    encoded_file = _open(file_set, 'rb')
280    try:
281      encoded_file_set = encoded_file.read()
282    finally:
283      encoded_file.close()
284
285    file_set = protobuf.decode_message(descriptor.FileSet, encoded_file_set)
286
287  for file_descriptor in file_set.files:
288    # Do not reload built in protorpc classes.
289    if not file_descriptor.package.startswith('protorpc.'):
290      import_file(file_descriptor, modules=modules)
291