• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Message registry for apitools."""
18
19import collections
20import contextlib
21import json
22
23import six
24
25from apitools.base.protorpclite import descriptor
26from apitools.base.protorpclite import messages
27from apitools.gen import extended_descriptor
28from apitools.gen import util
29
# Pairs the name of a type with the messages variant used for it.
TypeInfo = collections.namedtuple('TypeInfo', ['type_name', 'variant'])
31
32
33class MessageRegistry(object):
34
35    """Registry for message types.
36
37    This closely mirrors a messages.FileDescriptor, but adds additional
38    attributes (such as message and field descriptions) and some extra
39    code for validation and cycle detection.
40    """
41
42    # Type information from these two maps comes from here:
43    #  https://developers.google.com/discovery/v1/type-format
44    PRIMITIVE_TYPE_INFO_MAP = {
45        'string': TypeInfo(type_name='string',
46                           variant=messages.StringField.DEFAULT_VARIANT),
47        'integer': TypeInfo(type_name='integer',
48                            variant=messages.IntegerField.DEFAULT_VARIANT),
49        'boolean': TypeInfo(type_name='boolean',
50                            variant=messages.BooleanField.DEFAULT_VARIANT),
51        'number': TypeInfo(type_name='number',
52                           variant=messages.FloatField.DEFAULT_VARIANT),
53        'any': TypeInfo(type_name='extra_types.JsonValue',
54                        variant=messages.Variant.MESSAGE),
55    }
56
57    PRIMITIVE_FORMAT_MAP = {
58        'int32': TypeInfo(type_name='integer',
59                          variant=messages.Variant.INT32),
60        'uint32': TypeInfo(type_name='integer',
61                           variant=messages.Variant.UINT32),
62        'int64': TypeInfo(type_name='string',
63                          variant=messages.Variant.INT64),
64        'uint64': TypeInfo(type_name='string',
65                           variant=messages.Variant.UINT64),
66        'double': TypeInfo(type_name='number',
67                           variant=messages.Variant.DOUBLE),
68        'float': TypeInfo(type_name='number',
69                          variant=messages.Variant.FLOAT),
70        'byte': TypeInfo(type_name='byte',
71                         variant=messages.BytesField.DEFAULT_VARIANT),
72        'date': TypeInfo(type_name='extra_types.DateField',
73                         variant=messages.Variant.STRING),
74        'date-time': TypeInfo(
75            type_name=('apitools.base.protorpclite.message_types.'
76                       'DateTimeMessage'),
77            variant=messages.Variant.MESSAGE),
78    }
79
80    def __init__(self, client_info, names, description, root_package_dir,
81                 base_files_package, protorpc_package):
82        self.__names = names
83        self.__client_info = client_info
84        self.__package = client_info.package
85        self.__description = util.CleanDescription(description)
86        self.__root_package_dir = root_package_dir
87        self.__base_files_package = base_files_package
88        self.__protorpc_package = protorpc_package
89        self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
90            package=self.__package, description=self.__description)
91        # Add required imports
92        self.__file_descriptor.additional_imports = [
93            'from %s import messages as _messages' % self.__protorpc_package,
94        ]
95        # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
96        self.__message_registry = collections.OrderedDict()
97        # A set of types that we're currently adding (for cycle detection).
98        self.__nascent_types = set()
99        # A set of types for which we've seen a reference but no
100        # definition; if this set is nonempty, validation fails.
101        self.__unknown_types = set()
102        # Used for tracking paths during message creation
103        self.__current_path = []
104        # Where to register created messages
105        self.__current_env = self.__file_descriptor
106        # TODO(craigcitro): Add a `Finalize` method.
107
108    @property
109    def file_descriptor(self):
110        self.Validate()
111        return self.__file_descriptor
112
113    def WriteProtoFile(self, printer):
114        """Write the messages file to out as proto."""
115        self.Validate()
116        extended_descriptor.WriteMessagesFile(
117            self.__file_descriptor, self.__package, self.__client_info.version,
118            printer)
119
120    def WriteFile(self, printer):
121        """Write the messages file to out."""
122        self.Validate()
123        extended_descriptor.WritePythonFile(
124            self.__file_descriptor, self.__package, self.__client_info.version,
125            printer)
126
127    def Validate(self):
128        mysteries = self.__nascent_types or self.__unknown_types
129        if mysteries:
130            raise ValueError('Malformed MessageRegistry: %s' % mysteries)
131
132    def __ComputeFullName(self, name):
133        return '.'.join(map(six.text_type, self.__current_path[:] + [name]))
134
135    def __AddImport(self, new_import):
136        if new_import not in self.__file_descriptor.additional_imports:
137            self.__file_descriptor.additional_imports.append(new_import)
138
139    def __DeclareDescriptor(self, name):
140        self.__nascent_types.add(self.__ComputeFullName(name))
141
142    def __RegisterDescriptor(self, new_descriptor):
143        """Register the given descriptor in this registry."""
144        if not isinstance(new_descriptor, (
145                extended_descriptor.ExtendedMessageDescriptor,
146                extended_descriptor.ExtendedEnumDescriptor)):
147            raise ValueError('Cannot add descriptor of type %s' % (
148                type(new_descriptor),))
149        full_name = self.__ComputeFullName(new_descriptor.name)
150        if full_name in self.__message_registry:
151            raise ValueError(
152                'Attempt to re-register descriptor %s' % full_name)
153        if full_name not in self.__nascent_types:
154            raise ValueError('Directly adding types is not supported')
155        new_descriptor.full_name = full_name
156        self.__message_registry[full_name] = new_descriptor
157        if isinstance(new_descriptor,
158                      extended_descriptor.ExtendedMessageDescriptor):
159            self.__current_env.message_types.append(new_descriptor)
160        elif isinstance(new_descriptor,
161                        extended_descriptor.ExtendedEnumDescriptor):
162            self.__current_env.enum_types.append(new_descriptor)
163        self.__unknown_types.discard(full_name)
164        self.__nascent_types.remove(full_name)
165
166    def LookupDescriptor(self, name):
167        return self.__GetDescriptorByName(name)
168
169    def LookupDescriptorOrDie(self, name):
170        message_descriptor = self.LookupDescriptor(name)
171        if message_descriptor is None:
172            raise ValueError('No message descriptor named "%s"', name)
173        return message_descriptor
174
175    def __GetDescriptor(self, name):
176        return self.__GetDescriptorByName(self.__ComputeFullName(name))
177
178    def __GetDescriptorByName(self, name):
179        if name in self.__message_registry:
180            return self.__message_registry[name]
181        if name in self.__nascent_types:
182            raise ValueError(
183                'Cannot retrieve type currently being created: %s' % name)
184        return None
185
186    @contextlib.contextmanager
187    def __DescriptorEnv(self, message_descriptor):
188        # TODO(craigcitro): Typecheck?
189        previous_env = self.__current_env
190        self.__current_path.append(message_descriptor.name)
191        self.__current_env = message_descriptor
192        yield
193        self.__current_path.pop()
194        self.__current_env = previous_env
195
196    def AddEnumDescriptor(self, name, description,
197                          enum_values, enum_descriptions):
198        """Add a new EnumDescriptor named name with the given enum values."""
199        message = extended_descriptor.ExtendedEnumDescriptor()
200        message.name = self.__names.ClassName(name)
201        message.description = util.CleanDescription(description)
202        self.__DeclareDescriptor(message.name)
203        for index, (enum_name, enum_description) in enumerate(
204                zip(enum_values, enum_descriptions)):
205            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
206            enum_value.name = self.__names.NormalizeEnumName(enum_name)
207            if enum_value.name != enum_name:
208                message.enum_mappings.append(
209                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
210                        python_name=enum_value.name, json_name=enum_name))
211                self.__AddImport('from %s import encoding' %
212                                 self.__base_files_package)
213            enum_value.number = index
214            enum_value.description = util.CleanDescription(
215                enum_description or '<no description>')
216            message.values.append(enum_value)
217        self.__RegisterDescriptor(message)
218
219    def __DeclareMessageAlias(self, schema, alias_for):
220        """Declare schema as an alias for alias_for."""
221        # TODO(craigcitro): This is a hack. Remove it.
222        message = extended_descriptor.ExtendedMessageDescriptor()
223        message.name = self.__names.ClassName(schema['id'])
224        message.alias_for = alias_for
225        self.__DeclareDescriptor(message.name)
226        self.__AddImport('from %s import extra_types' %
227                         self.__base_files_package)
228        self.__RegisterDescriptor(message)
229
230    def __AddAdditionalProperties(self, message, schema, properties):
231        """Add an additionalProperties field to message."""
232        additional_properties_info = schema['additionalProperties']
233        entries_type_name = self.__AddAdditionalPropertyType(
234            message.name, additional_properties_info)
235        description = util.CleanDescription(
236            additional_properties_info.get('description'))
237        if description is None:
238            description = 'Additional properties of type %s' % message.name
239        attrs = {
240            'items': {
241                '$ref': entries_type_name,
242            },
243            'description': description,
244            'type': 'array',
245        }
246        field_name = 'additionalProperties'
247        message.fields.append(self.__FieldDescriptorFromProperties(
248            field_name, len(properties) + 1, attrs))
249        self.__AddImport('from %s import encoding' % self.__base_files_package)
250        message.decorators.append(
251            'encoding.MapUnrecognizedFields(%r)' % field_name)
252
253    def AddDescriptorFromSchema(self, schema_name, schema):
254        """Add a new MessageDescriptor named schema_name based on schema."""
255        # TODO(craigcitro): Is schema_name redundant?
256        if self.__GetDescriptor(schema_name):
257            return
258        if schema.get('enum'):
259            self.__DeclareEnum(schema_name, schema)
260            return
261        if schema.get('type') == 'any':
262            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
263            return
264        if schema.get('type') != 'object':
265            raise ValueError('Cannot create message descriptors for type %s',
266                             schema.get('type'))
267        message = extended_descriptor.ExtendedMessageDescriptor()
268        message.name = self.__names.ClassName(schema['id'])
269        message.description = util.CleanDescription(schema.get(
270            'description', 'A %s object.' % message.name))
271        self.__DeclareDescriptor(message.name)
272        with self.__DescriptorEnv(message):
273            properties = schema.get('properties', {})
274            for index, (name, attrs) in enumerate(sorted(properties.items())):
275                field = self.__FieldDescriptorFromProperties(
276                    name, index + 1, attrs)
277                message.fields.append(field)
278                if field.name != name:
279                    message.field_mappings.append(
280                        type(message).JsonFieldMapping(
281                            python_name=field.name, json_name=name))
282                    self.__AddImport(
283                        'from %s import encoding' % self.__base_files_package)
284            if 'additionalProperties' in schema:
285                self.__AddAdditionalProperties(message, schema, properties)
286        self.__RegisterDescriptor(message)
287
288    def __AddAdditionalPropertyType(self, name, property_schema):
289        """Add a new nested AdditionalProperty message."""
290        new_type_name = 'AdditionalProperty'
291        property_schema = dict(property_schema)
292        # We drop the description here on purpose, so the resulting
293        # messages are less repetitive.
294        property_schema.pop('description', None)
295        description = 'An additional property for a %s object.' % name
296        schema = {
297            'id': new_type_name,
298            'type': 'object',
299            'description': description,
300            'properties': {
301                'key': {
302                    'type': 'string',
303                    'description': 'Name of the additional property.',
304                },
305                'value': property_schema,
306            },
307        }
308        self.AddDescriptorFromSchema(new_type_name, schema)
309        return new_type_name
310
311    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
312        """Add a type for a list entry."""
313        entry_schema.pop('description', None)
314        description = 'Single entry in a %s.' % parent_name
315        schema = {
316            'id': entry_type_name,
317            'type': 'object',
318            'description': description,
319            'properties': {
320                'entry': {
321                    'type': 'array',
322                    'items': entry_schema,
323                },
324            },
325        }
326        self.AddDescriptorFromSchema(entry_type_name, schema)
327        return entry_type_name
328
329    def __FieldDescriptorFromProperties(self, name, index, attrs):
330        """Create a field descriptor for these attrs."""
331        field = descriptor.FieldDescriptor()
332        field.name = self.__names.CleanName(name)
333        field.number = index
334        field.label = self.__ComputeLabel(attrs)
335        new_type_name_hint = self.__names.ClassName(
336            '%sValue' % self.__names.ClassName(name))
337        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
338        field.type_name = type_info.type_name
339        field.variant = type_info.variant
340        if 'default' in attrs:
341            # TODO(craigcitro): Correctly handle non-primitive default values.
342            default = attrs['default']
343            if not (field.type_name == 'string' or
344                    field.variant == messages.Variant.ENUM):
345                default = str(json.loads(default))
346            if field.variant == messages.Variant.ENUM:
347                default = self.__names.NormalizeEnumName(default)
348            field.default_value = default
349        extended_field = extended_descriptor.ExtendedFieldDescriptor()
350        extended_field.name = field.name
351        extended_field.description = util.CleanDescription(
352            attrs.get('description', 'A %s attribute.' % field.type_name))
353        extended_field.field_descriptor = field
354        return extended_field
355
356    @staticmethod
357    def __ComputeLabel(attrs):
358        if attrs.get('required', False):
359            return descriptor.FieldDescriptor.Label.REQUIRED
360        elif attrs.get('type') == 'array':
361            return descriptor.FieldDescriptor.Label.REPEATED
362        elif attrs.get('repeated'):
363            return descriptor.FieldDescriptor.Label.REPEATED
364        return descriptor.FieldDescriptor.Label.OPTIONAL
365
366    def __DeclareEnum(self, enum_name, attrs):
367        description = util.CleanDescription(attrs.get('description', ''))
368        enum_values = attrs['enum']
369        enum_descriptions = attrs.get(
370            'enumDescriptions', [''] * len(enum_values))
371        self.AddEnumDescriptor(enum_name, description,
372                               enum_values, enum_descriptions)
373        self.__AddIfUnknown(enum_name)
374        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)
375
376    def __AddIfUnknown(self, type_name):
377        type_name = self.__names.ClassName(type_name)
378        full_type_name = self.__ComputeFullName(type_name)
379        if (full_type_name not in self.__message_registry.keys() and
380                type_name not in self.__message_registry.keys()):
381            self.__unknown_types.add(type_name)
382
383    def __GetTypeInfo(self, attrs, name_hint):
384        """Return a TypeInfo object for attrs, creating one if needed."""
385
386        type_ref = self.__names.ClassName(attrs.get('$ref'))
387        type_name = attrs.get('type')
388        if not (type_ref or type_name):
389            raise ValueError('No type found for %s' % attrs)
390
391        if type_ref:
392            self.__AddIfUnknown(type_ref)
393            # We don't actually know this is a message -- it might be an
394            # enum. However, we can't check that until we've created all the
395            # types, so we come back and fix this up later.
396            return TypeInfo(
397                type_name=type_ref, variant=messages.Variant.MESSAGE)
398
399        if 'enum' in attrs:
400            enum_name = '%sValuesEnum' % name_hint
401            return self.__DeclareEnum(enum_name, attrs)
402
403        if 'format' in attrs:
404            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
405            if type_info is None:
406                # If we don't recognize the format, the spec says we fall back
407                # to just using the type name.
408                if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
409                    return self.PRIMITIVE_TYPE_INFO_MAP[type_name]
410                raise ValueError('Unknown type/format "%s"/"%s"' % (
411                    attrs['format'], type_name))
412            if type_info.type_name.startswith((
413                    'apitools.base.protorpclite.message_types.',
414                    'message_types.')):
415                self.__AddImport(
416                    'from %s import message_types as _message_types' %
417                    self.__protorpc_package)
418            if type_info.type_name.startswith('extra_types.'):
419                self.__AddImport(
420                    'from %s import extra_types' % self.__base_files_package)
421            return type_info
422
423        if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
424            type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
425            if type_info.type_name.startswith('extra_types.'):
426                self.__AddImport(
427                    'from %s import extra_types' % self.__base_files_package)
428            return type_info
429
430        if type_name == 'array':
431            items = attrs.get('items')
432            if not items:
433                raise ValueError('Array type with no item type: %s' % attrs)
434            entry_name_hint = self.__names.ClassName(
435                items.get('title') or '%sListEntry' % name_hint)
436            entry_label = self.__ComputeLabel(items)
437            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
438                parent_name = self.__names.ClassName(
439                    items.get('title') or name_hint)
440                entry_type_name = self.__AddEntryType(
441                    entry_name_hint, items.get('items'), parent_name)
442                return TypeInfo(type_name=entry_type_name,
443                                variant=messages.Variant.MESSAGE)
444            return self.__GetTypeInfo(items, entry_name_hint)
445        elif type_name == 'any':
446            self.__AddImport('from %s import extra_types' %
447                             self.__base_files_package)
448            return self.PRIMITIVE_TYPE_INFO_MAP['any']
449        elif type_name == 'object':
450            # TODO(craigcitro): Think of a better way to come up with names.
451            if not name_hint:
452                raise ValueError(
453                    'Cannot create subtype without some name hint')
454            schema = dict(attrs)
455            schema['id'] = name_hint
456            self.AddDescriptorFromSchema(name_hint, schema)
457            self.__AddIfUnknown(name_hint)
458            return TypeInfo(
459                type_name=name_hint, variant=messages.Variant.MESSAGE)
460
461        raise ValueError('Unknown type: %s' % type_name)
462
463    def FixupMessageFields(self):
464        for message_type in self.file_descriptor.message_types:
465            self._FixupMessage(message_type)
466
467    def _FixupMessage(self, message_type):
468        with self.__DescriptorEnv(message_type):
469            for field in message_type.fields:
470                if field.field_descriptor.variant == messages.Variant.MESSAGE:
471                    field_type_name = field.field_descriptor.type_name
472                    field_type = self.LookupDescriptor(field_type_name)
473                    if isinstance(field_type,
474                                  extended_descriptor.ExtendedEnumDescriptor):
475                        field.field_descriptor.variant = messages.Variant.ENUM
476            for submessage_type in message_type.message_types:
477                self._FixupMessage(submessage_type)
478