#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Message registry for apitools."""

import collections
import contextlib
import json

import six

from apitools.base.protorpclite import descriptor
from apitools.base.protorpclite import messages
from apitools.gen import extended_descriptor
from apitools.gen import util

TypeInfo = collections.namedtuple('TypeInfo', ('type_name', 'variant'))


class MessageRegistry(object):

    """Registry for message types.

    This closely mirrors a messages.FileDescriptor, but adds additional
    attributes (such as message and field descriptions) and some extra
    code for validation and cycle detection.
    """

    # Type information from these two maps comes from here:
    #  https://developers.google.com/discovery/v1/type-format
    PRIMITIVE_TYPE_INFO_MAP = {
        'string': TypeInfo(type_name='string',
                           variant=messages.StringField.DEFAULT_VARIANT),
        'integer': TypeInfo(type_name='integer',
                            variant=messages.IntegerField.DEFAULT_VARIANT),
        'boolean': TypeInfo(type_name='boolean',
                            variant=messages.BooleanField.DEFAULT_VARIANT),
        'number': TypeInfo(type_name='number',
                           variant=messages.FloatField.DEFAULT_VARIANT),
        'any': TypeInfo(type_name='extra_types.JsonValue',
                        variant=messages.Variant.MESSAGE),
    }

    PRIMITIVE_FORMAT_MAP = {
        'int32': TypeInfo(type_name='integer',
                          variant=messages.Variant.INT32),
        'uint32': TypeInfo(type_name='integer',
                           variant=messages.Variant.UINT32),
        'int64': TypeInfo(type_name='string',
                          variant=messages.Variant.INT64),
        'uint64': TypeInfo(type_name='string',
                           variant=messages.Variant.UINT64),
        'double': TypeInfo(type_name='number',
                           variant=messages.Variant.DOUBLE),
        'float': TypeInfo(type_name='number',
                          variant=messages.Variant.FLOAT),
        'byte': TypeInfo(type_name='byte',
                         variant=messages.BytesField.DEFAULT_VARIANT),
        'date': TypeInfo(type_name='extra_types.DateField',
                         variant=messages.Variant.STRING),
        'date-time': TypeInfo(
            type_name=('apitools.base.protorpclite.message_types.'
                       'DateTimeMessage'),
            variant=messages.Variant.MESSAGE),
    }

    def __init__(self, client_info, names, description, root_package_dir,
                 base_files_package, protorpc_package):
        """Create an empty registry.

        Args:
          client_info: Object whose `package` attribute names the file
              descriptor's package and whose `version` attribute is passed
              to the file writers.
          names: Naming helper; this class calls its ClassName, CleanName
              and NormalizeEnumName methods.
          description: Description of the messages file; passed through
              util.CleanDescription.
          root_package_dir: Stored on the registry; not otherwise used by
              the code in this class.
          base_files_package: Package name interpolated into the
              `encoding` and `extra_types` import lines added to the file
              descriptor.
          protorpc_package: Package name interpolated into the `messages`
              and `message_types` import lines added to the file
              descriptor.
        """
        self.__names = names
        self.__client_info = client_info
        self.__package = client_info.package
        self.__description = util.CleanDescription(description)
        self.__root_package_dir = root_package_dir
        self.__base_files_package = base_files_package
        self.__protorpc_package = protorpc_package
        self.__file_descriptor = extended_descriptor.ExtendedFileDescriptor(
            package=self.__package, description=self.__description)
        # Add required imports
        self.__file_descriptor.additional_imports = [
            'from %s import messages as _messages' % self.__protorpc_package,
        ]
        # Map from scoped names (i.e. Foo.Bar) to MessageDescriptors.
        self.__message_registry = collections.OrderedDict()
        # A set of types that we're currently adding (for cycle detection).
        self.__nascent_types = set()
        # A set of types for which we've seen a reference but no
        # definition; if this set is nonempty, validation fails.
        self.__unknown_types = set()
        # Used for tracking paths during message creation
        self.__current_path = []
        # Where to register created messages
        self.__current_env = self.__file_descriptor
        # TODO(craigcitro): Add a `Finalize` method.

    @property
    def file_descriptor(self):
        """The underlying file descriptor; validates the registry first.

        Raises:
          ValueError: if Validate() fails.
        """
        self.Validate()
        return self.__file_descriptor

    def WriteProtoFile(self, printer):
        """Write the messages file to out as proto.

        Raises:
          ValueError: if Validate() fails.
        """
        self.Validate()
        extended_descriptor.WriteMessagesFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def WriteFile(self, printer):
        """Write the messages file to out.

        Raises:
          ValueError: if Validate() fails.
        """
        self.Validate()
        extended_descriptor.WritePythonFile(
            self.__file_descriptor, self.__package, self.__client_info.version,
            printer)

    def Validate(self):
        """Check that no type is half-created or referenced but undefined.

        Raises:
          ValueError: if either the set of types currently being created
              or the set of unknown types is nonempty.
        """
        mysteries = self.__nascent_types or self.__unknown_types
        if mysteries:
            raise ValueError('Malformed MessageRegistry: %s' % mysteries)

    def __ComputeFullName(self, name):
        """Return name qualified by the current path, joined with dots."""
        # List concatenation already builds a new list, so the current
        # path is never modified here.
        return '.'.join(map(six.text_type, self.__current_path + [name]))

    def __AddImport(self, new_import):
        """Append new_import to the file's imports unless already present."""
        if new_import not in self.__file_descriptor.additional_imports:
            self.__file_descriptor.additional_imports.append(new_import)

    def __DeclareDescriptor(self, name):
        """Mark name (scoped to the current path) as being created."""
        self.__nascent_types.add(self.__ComputeFullName(name))

    def __RegisterDescriptor(self, new_descriptor):
        """Register the given descriptor in this registry.

        The descriptor must have been declared with __DeclareDescriptor
        first. On success it is stored under its full name, appended to
        the current environment's message_types or enum_types, and removed
        from the nascent and unknown sets.

        Raises:
          ValueError: if new_descriptor is not an extended message or enum
              descriptor, is already registered, or was never declared.
        """
        if not isinstance(new_descriptor, (
                extended_descriptor.ExtendedMessageDescriptor,
                extended_descriptor.ExtendedEnumDescriptor)):
            raise ValueError('Cannot add descriptor of type %s' % (
                type(new_descriptor),))
        full_name = self.__ComputeFullName(new_descriptor.name)
        if full_name in self.__message_registry:
            raise ValueError(
                'Attempt to re-register descriptor %s' % full_name)
        if full_name not in self.__nascent_types:
            raise ValueError('Directly adding types is not supported')
        new_descriptor.full_name = full_name
        self.__message_registry[full_name] = new_descriptor
        if isinstance(new_descriptor,
                      extended_descriptor.ExtendedMessageDescriptor):
            self.__current_env.message_types.append(new_descriptor)
        elif isinstance(new_descriptor,
                        extended_descriptor.ExtendedEnumDescriptor):
            self.__current_env.enum_types.append(new_descriptor)
        self.__unknown_types.discard(full_name)
        self.__nascent_types.remove(full_name)

    def LookupDescriptor(self, name):
        """Return the descriptor registered under name, or None.

        Raises:
          ValueError: if name refers to a type that is currently being
              created.
        """
        return self.__GetDescriptorByName(name)

    def LookupDescriptorOrDie(self, name):
        """Return the descriptor registered under name.

        Raises:
          ValueError: if no descriptor is registered under name, or if the
              type is currently being created.
        """
        message_descriptor = self.LookupDescriptor(name)
        if message_descriptor is None:
            raise ValueError('No message descriptor named "%s"' % name)
        return message_descriptor

    def __GetDescriptor(self, name):
        """Look up name after qualifying it with the current path."""
        return self.__GetDescriptorByName(self.__ComputeFullName(name))

    def __GetDescriptorByName(self, name):
        """Look up an already-qualified name; None if it is not known.

        Raises:
          ValueError: if name is in the set of types being created.
        """
        if name in self.__message_registry:
            return self.__message_registry[name]
        if name in self.__nascent_types:
            raise ValueError(
                'Cannot retrieve type currently being created: %s' % name)
        return None

    @contextlib.contextmanager
    def __DescriptorEnv(self, message_descriptor):
        """Context manager that makes message_descriptor the current scope.

        Inside the `with` body, new descriptors are registered on
        message_descriptor and names are qualified with its name.
        """
        # TODO(craigcitro): Typecheck?
        previous_env = self.__current_env
        self.__current_path.append(message_descriptor.name)
        self.__current_env = message_descriptor
        try:
            yield
        finally:
            # Restore the enclosing scope even when the body raises, so a
            # failed schema does not leave the registry pointing inside a
            # half-built message.
            self.__current_path.pop()
            self.__current_env = previous_env

    def AddEnumDescriptor(self, name, description,
                          enum_values, enum_descriptions):
        """Add a new EnumDescriptor named name with the given enum values.

        Args:
          name: Enum name; converted with names.ClassName.
          description: Description of the enum.
          enum_values: Iterable of enum value names. Each value is numbered
              by its position.
          enum_descriptions: Iterable of descriptions, matched to
              enum_values by position. Missing or empty entries are
              rendered as '<no description>'.
        """
        message = extended_descriptor.ExtendedEnumDescriptor()
        message.name = self.__names.ClassName(name)
        message.description = util.CleanDescription(description)
        self.__DeclareDescriptor(message.name)
        enum_values = list(enum_values)
        enum_descriptions = list(enum_descriptions)
        # zip() stops at the shorter input; pad the descriptions so that a
        # short description list never silently drops enum values.
        missing = len(enum_values) - len(enum_descriptions)
        if missing > 0:
            enum_descriptions.extend([''] * missing)
        for index, (enum_name, enum_description) in enumerate(
                zip(enum_values, enum_descriptions)):
            enum_value = extended_descriptor.ExtendedEnumValueDescriptor()
            enum_value.name = self.__names.NormalizeEnumName(enum_name)
            if enum_value.name != enum_name:
                # The Python name differs from the JSON name, so record a
                # mapping between the two.
                message.enum_mappings.append(
                    extended_descriptor.ExtendedEnumDescriptor.JsonEnumMapping(
                        python_name=enum_value.name, json_name=enum_name))
                self.__AddImport('from %s import encoding' %
                                 self.__base_files_package)
            enum_value.number = index
            enum_value.description = util.CleanDescription(
                enum_description or '<no description>')
            message.values.append(enum_value)
        self.__RegisterDescriptor(message)

    def __DeclareMessageAlias(self, schema, alias_for):
        """Declare schema as an alias for alias_for."""
        # TODO(craigcitro): This is a hack. Remove it.
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.alias_for = alias_for
        self.__DeclareDescriptor(message.name)
        self.__AddImport('from %s import extra_types' %
                         self.__base_files_package)
        self.__RegisterDescriptor(message)

    def __AddAdditionalProperties(self, message, schema, properties):
        """Add an additionalProperties field to message.

        Creates a nested AdditionalProperty message for the entries and a
        repeated `additionalProperties` field referring to it, numbered
        one past the declared properties.
        """
        additional_properties_info = schema['additionalProperties']
        entries_type_name = self.__AddAdditionalPropertyType(
            message.name, additional_properties_info)
        description = util.CleanDescription(
            additional_properties_info.get('description'))
        if description is None:
            description = 'Additional properties of type %s' % message.name
        attrs = {
            'items': {
                '$ref': entries_type_name,
            },
            'description': description,
            'type': 'array',
        }
        field_name = 'additionalProperties'
        message.fields.append(self.__FieldDescriptorFromProperties(
            field_name, len(properties) + 1, attrs))
        self.__AddImport('from %s import encoding' % self.__base_files_package)
        message.decorators.append(
            'encoding.MapUnrecognizedFields(%r)' % field_name)

    def AddDescriptorFromSchema(self, schema_name, schema):
        """Add a new MessageDescriptor named schema_name based on schema.

        Does nothing if a descriptor for schema_name already exists in the
        current scope. Schemas with an `enum` become enums, schemas of
        type `any` become aliases for extra_types.JsonValue, and schemas
        of type `object` become messages with one field per property
        (numbered from 1 in sorted property-name order).

        Raises:
          ValueError: if the schema type is none of the above.
        """
        # TODO(craigcitro): Is schema_name redundant?
        if self.__GetDescriptor(schema_name):
            return
        if schema.get('enum'):
            self.__DeclareEnum(schema_name, schema)
            return
        if schema.get('type') == 'any':
            self.__DeclareMessageAlias(schema, 'extra_types.JsonValue')
            return
        if schema.get('type') != 'object':
            raise ValueError('Cannot create message descriptors for type %s' %
                             schema.get('type'))
        message = extended_descriptor.ExtendedMessageDescriptor()
        message.name = self.__names.ClassName(schema['id'])
        message.description = util.CleanDescription(schema.get(
            'description', 'A %s object.' % message.name))
        self.__DeclareDescriptor(message.name)
        with self.__DescriptorEnv(message):
            properties = schema.get('properties', {})
            for index, (name, attrs) in enumerate(sorted(properties.items())):
                field = self.__FieldDescriptorFromProperties(
                    name, index + 1, attrs)
                message.fields.append(field)
                if field.name != name:
                    # The cleaned field name differs from the JSON name, so
                    # record a mapping between the two.
                    message.field_mappings.append(
                        type(message).JsonFieldMapping(
                            python_name=field.name, json_name=name))
                    self.__AddImport(
                        'from %s import encoding' % self.__base_files_package)
            if 'additionalProperties' in schema:
                self.__AddAdditionalProperties(message, schema, properties)
        self.__RegisterDescriptor(message)

    def __AddAdditionalPropertyType(self, name, property_schema):
        """Add a new nested AdditionalProperty message.

        Returns:
          The name of the new type ('AdditionalProperty').
        """
        new_type_name = 'AdditionalProperty'
        # Copy so that the caller's schema is left untouched.
        property_schema = dict(property_schema)
        # We drop the description here on purpose, so the resulting
        # messages are less repetitive.
        property_schema.pop('description', None)
        description = 'An additional property for a %s object.' % name
        schema = {
            'id': new_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'key': {
                    'type': 'string',
                    'description': 'Name of the additional property.',
                },
                'value': property_schema,
            },
        }
        self.AddDescriptorFromSchema(new_type_name, schema)
        return new_type_name

    def __AddEntryType(self, entry_type_name, entry_schema, parent_name):
        """Add a type for a list entry.

        Returns:
          entry_type_name.

        Raises:
          ValueError: if entry_schema is None.
        """
        if entry_schema is None:
            raise ValueError(
                'Repeated list entry %s has no item type' % entry_type_name)
        # Copy before dropping the description so that the caller's schema
        # is left untouched (same approach as __AddAdditionalPropertyType).
        entry_schema = dict(entry_schema)
        entry_schema.pop('description', None)
        description = 'Single entry in a %s.' % parent_name
        schema = {
            'id': entry_type_name,
            'type': 'object',
            'description': description,
            'properties': {
                'entry': {
                    'type': 'array',
                    'items': entry_schema,
                },
            },
        }
        self.AddDescriptorFromSchema(entry_type_name, schema)
        return entry_type_name

    def __FieldDescriptorFromProperties(self, name, index, attrs):
        """Create a field descriptor for these attrs.

        Args:
          name: JSON name of the property; cleaned with names.CleanName.
          index: Field number to assign.
          attrs: Schema dict for the property.

        Returns:
          An ExtendedFieldDescriptor wrapping the new FieldDescriptor.
        """
        field = descriptor.FieldDescriptor()
        field.name = self.__names.CleanName(name)
        field.number = index
        field.label = self.__ComputeLabel(attrs)
        new_type_name_hint = self.__names.ClassName(
            '%sValue' % self.__names.ClassName(name))
        type_info = self.__GetTypeInfo(attrs, new_type_name_hint)
        field.type_name = type_info.type_name
        field.variant = type_info.variant
        if 'default' in attrs:
            # TODO(craigcitro): Correctly handle non-primitive default values.
            default = attrs['default']
            # String and enum defaults are used verbatim; anything else is
            # parsed as JSON and converted back to its Python str() form.
            if not (field.type_name == 'string' or
                    field.variant == messages.Variant.ENUM):
                default = str(json.loads(default))
            if field.variant == messages.Variant.ENUM:
                default = self.__names.NormalizeEnumName(default)
            field.default_value = default
        extended_field = extended_descriptor.ExtendedFieldDescriptor()
        extended_field.name = field.name
        extended_field.description = util.CleanDescription(
            attrs.get('description', 'A %s attribute.' % field.type_name))
        extended_field.field_descriptor = field
        return extended_field

    @staticmethod
    def __ComputeLabel(attrs):
        """Return the field label for attrs.

        `required` takes precedence; otherwise arrays and attrs marked
        `repeated` are REPEATED; everything else is OPTIONAL.
        """
        if attrs.get('required', False):
            return descriptor.FieldDescriptor.Label.REQUIRED
        elif attrs.get('type') == 'array':
            return descriptor.FieldDescriptor.Label.REPEATED
        elif attrs.get('repeated'):
            return descriptor.FieldDescriptor.Label.REPEATED
        return descriptor.FieldDescriptor.Label.OPTIONAL

    def __DeclareEnum(self, enum_name, attrs):
        """Add an enum built from attrs and return its TypeInfo.

        Uses attrs['enum'] for the values and attrs['enumDescriptions']
        (defaulting to empty strings) for their descriptions.
        """
        description = util.CleanDescription(attrs.get('description', ''))
        enum_values = attrs['enum']
        enum_descriptions = attrs.get(
            'enumDescriptions', [''] * len(enum_values))
        self.AddEnumDescriptor(enum_name, description,
                               enum_values, enum_descriptions)
        self.__AddIfUnknown(enum_name)
        return TypeInfo(type_name=enum_name, variant=messages.Variant.ENUM)

    def __AddIfUnknown(self, type_name):
        """Record type_name as unknown unless it is already registered.

        A type counts as registered if either its class name or its name
        qualified by the current path is a key of the registry.
        """
        type_name = self.__names.ClassName(type_name)
        full_type_name = self.__ComputeFullName(type_name)
        if (full_type_name not in self.__message_registry and
                type_name not in self.__message_registry):
            self.__unknown_types.add(type_name)

    def __GetTypeInfo(self, attrs, name_hint):
        """Return a TypeInfo object for attrs, creating one if needed.

        Args:
          attrs: Schema dict describing the type.
          name_hint: Name to use for any enum, list-entry or nested message
              type that has to be created.

        Raises:
          ValueError: if attrs has neither `$ref` nor `type`, has an
              unknown format together with a non-primitive type, is an
              array without items, is an object without a name hint, or
              has an unrecognized type.
        """

        type_ref = self.__names.ClassName(attrs.get('$ref'))
        type_name = attrs.get('type')
        if not (type_ref or type_name):
            raise ValueError('No type found for %s' % attrs)

        if type_ref:
            self.__AddIfUnknown(type_ref)
            # We don't actually know this is a message -- it might be an
            # enum. However, we can't check that until we've created all the
            # types, so we come back and fix this up later.
            return TypeInfo(
                type_name=type_ref, variant=messages.Variant.MESSAGE)

        if 'enum' in attrs:
            enum_name = '%sValuesEnum' % name_hint
            return self.__DeclareEnum(enum_name, attrs)

        if 'format' in attrs:
            type_info = self.PRIMITIVE_FORMAT_MAP.get(attrs['format'])
            if type_info is None:
                # If we don't recognize the format, the spec says we fall back
                # to just using the type name.
                if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
                    return self.PRIMITIVE_TYPE_INFO_MAP[type_name]
                # Arguments are ordered to match the "type/format" label.
                raise ValueError('Unknown type/format "%s"/"%s"' % (
                    type_name, attrs['format']))
            if type_info.type_name.startswith((
                    'apitools.base.protorpclite.message_types.',
                    'message_types.')):
                self.__AddImport(
                    'from %s import message_types as _message_types' %
                    self.__protorpc_package)
            if type_info.type_name.startswith('extra_types.'):
                self.__AddImport(
                    'from %s import extra_types' % self.__base_files_package)
            return type_info

        # Note that this also covers type 'any', which is an entry of
        # PRIMITIVE_TYPE_INFO_MAP whose type name starts with 'extra_types.'.
        if type_name in self.PRIMITIVE_TYPE_INFO_MAP:
            type_info = self.PRIMITIVE_TYPE_INFO_MAP[type_name]
            if type_info.type_name.startswith('extra_types.'):
                self.__AddImport(
                    'from %s import extra_types' % self.__base_files_package)
            return type_info

        if type_name == 'array':
            items = attrs.get('items')
            if not items:
                raise ValueError('Array type with no item type: %s' % attrs)
            entry_name_hint = self.__names.ClassName(
                items.get('title') or '%sListEntry' % name_hint)
            entry_label = self.__ComputeLabel(items)
            if entry_label == descriptor.FieldDescriptor.Label.REPEATED:
                # The items are themselves repeated, so wrap the inner
                # list in its own entry message.
                parent_name = self.__names.ClassName(
                    items.get('title') or name_hint)
                entry_type_name = self.__AddEntryType(
                    entry_name_hint, items.get('items'), parent_name)
                return TypeInfo(type_name=entry_type_name,
                                variant=messages.Variant.MESSAGE)
            return self.__GetTypeInfo(items, entry_name_hint)

        if type_name == 'object':
            # TODO(craigcitro): Think of a better way to come up with names.
            if not name_hint:
                raise ValueError(
                    'Cannot create subtype without some name hint')
            schema = dict(attrs)
            schema['id'] = name_hint
            self.AddDescriptorFromSchema(name_hint, schema)
            self.__AddIfUnknown(name_hint)
            return TypeInfo(
                type_name=name_hint, variant=messages.Variant.MESSAGE)

        raise ValueError('Unknown type: %s' % type_name)

    def FixupMessageFields(self):
        """Fix up field variants on every top-level message type.

        Raises:
          ValueError: if Validate() fails (via the file_descriptor
              property).
        """
        for message_type in self.file_descriptor.message_types:
            self._FixupMessage(message_type)

    def _FixupMessage(self, message_type):
        """Switch MESSAGE-variant fields that refer to enums to ENUM.

        Fields created from a `$ref` are assumed to be messages at
        creation time; once all types exist, any such field whose type
        name resolves to an enum descriptor is corrected here. Nested
        message types are processed recursively.
        """
        with self.__DescriptorEnv(message_type):
            for field in message_type.fields:
                if field.field_descriptor.variant == messages.Variant.MESSAGE:
                    field_type_name = field.field_descriptor.type_name
                    field_type = self.LookupDescriptor(field_type_name)
                    if isinstance(field_type,
                                  extended_descriptor.ExtendedEnumDescriptor):
                        field.field_descriptor.variant = messages.Variant.ENUM
            for submessage_type in message_type.message_types:
                self._FixupMessage(submessage_type)