#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Service registry for apitools."""

import collections
import logging
import re
import textwrap

from apitools.base.py import base_api
from apitools.gen import util

# This module is a code generator, so long statement-heavy methods are
# expected.
# pylint:disable=too-many-statements

_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')

# A discovery maxSize value: a decimal integer optionally followed by a unit
# of B, KB, MB, GB or TB (case-insensitive).
_MAX_SIZE_RE = re.compile(r'(?i)(?P<size>\d+)(?P<unit>[KMGT]?B)?$')


class ServiceRegistry(object):

    """Registry for service types.

    Services are added with AddServiceFromResource, which also registers any
    synthesized request/response message schemas with the message registry.
    The registered services can then be emitted either as a Python client
    module (WriteFile) or as proto service declarations (WriteProtoFile).
    """

    def __init__(self, client_info, message_registry,
                 names, root_package, base_files_package,
                 unelidable_request_methods):
        """Create an empty registry.

        Args:
          client_info: Client metadata; its package, scopes, version,
            base_path and related attributes are read during generation.
          message_registry: Registry that receives the request and response
            message schemas synthesized by this registry.
          names: Name utility used to produce class, method and field names.
          root_package: Package prefix for the generated messages import, or
            a false value to import the messages module without a prefix.
          base_files_package: Package from which generated code imports
            base_api.
          unelidable_request_methods: Method ids for which a dedicated
            request type is always generated.
        """
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        # Service name -> OrderedDict(method name -> ApiMethodInfo), kept in
        # registration order so generated output is deterministic.
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # Seeded from client_info; grows with per-method scopes as services
        # are added.
        self.__all_scopes = set(self.__client_info.scopes)

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """Sorted list of the client scopes plus every method's scopes."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name, e.g. 'FooService', for a service."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method.

        Args:
          printer: The code printer, called as printer(fmt, *args).
          method_info: (base_api.ApiMethodInfo) The method being documented.
          method_name: (str) The generated Python method name.
          name: (str) The service the method belongs to.
        """
        if method_info.description:
            # Work from the cleaned text so it is safe to embed in the
            # generated docstring; the first line is forced to end in '.'.
            description = util.CleanDescription(method_info.description)
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        with printer.CommentContext():
            # A raw docstring, so backslashes in descriptions are kept as-is.
            printer('r"""%s' % description)
            printer()
            printer('Args:')
            printer('  request: (%s) input message', method_info.request_type_name)
            printer('  global_params: (StandardQueryParameters, default: None) '
                    'global arguments')
            if method_info.upload_config:
                printer('  upload: (Upload, default: None) If present, upload')
                printer('      this stream with the request.')
            if method_info.supports_download:
                printer(
                    '  download: (Download, default: None) If present, download')
                printer('      data from the request via this stream.')
            printer('Returns:')
            printer('  (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Write one generated service class nested in the client class.

        Args:
          printer: The code printer, called as printer(fmt, *args).
          name: (str) The service name.
          method_info_map: Mapping of method name -> base_api.ApiMethodInfo.
          client_class_name: (str) The enclosing generated client class name,
            used in the generated super() call.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._upload_configs = {')
                with printer.Indent(indent='    '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent='    '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                    printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent='    '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])
                printer()
                printer('{0}.method_config = lambda: base_api.ApiMethodInfo('
                        .format(method_name))
                with printer.Indent(indent='    '):
                    attrs = sorted(
                        x.name for x in method_info.all_fields())
                    for attr in attrs:
                        # upload_config is emitted in __init__ above and the
                        # description goes into the docstring instead.
                        if attr in ('upload_config', 'description'):
                            continue
                        value = getattr(method_info, attr)
                        if value is not None:
                            printer('%s=%r,', attr, value)
                printer(')')

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file."""
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                # Leave room for the '// ' comment prefix.
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to out as proto (proto2)."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the services in this registry to out as a Python client."""
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer()
        printer('from __future__ import absolute_import')
        printer()
        printer('from %s import base_api', self.__base_files_package)
        if self.__root_package:
            import_prefix = 'from {0} '.format(self.__root_package)
        else:
            import_prefix = ''
        printer('%simport %s as messages', import_prefix,
                client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            printer('MESSAGES_MODULE = messages')
            printer('BASE_URL = {0!r}'.format(client_info.base_url))
            printer('MTLS_BASE_URL = {0!r}'.format(client_info.mtls_base_url))
            printer()
            printer('_PACKAGE = {0!r}'.format(client_info.package))
            printer('_SCOPES = {0!r}'.format(
                client_info.scopes or
                ['https://www.googleapis.com/auth/userinfo.email']))
            printer('_VERSION = {0!r}'.format(client_info.version))
            printer('_CLIENT_ID = {0!r}'.format(client_info.client_id))
            printer('_CLIENT_SECRET = {0!r}'.format(client_info.client_secret))
            printer('_USER_AGENT = {0!r}'.format(client_info.user_agent))
            printer('_CLIENT_CLASS_NAME = {0!r}'.format(
                client_info.client_class_name))
            printer('_URL_VERSION = {0!r}'.format(client_info.url_version))
            printer('_API_KEY = {0!r}'.format(client_info.api_key))
            printer()
            printer("def __init__(self, url='', credentials=None,")
            # Align continuation lines just past 'def __init__('.
            with printer.Indent(indent=' ' * len('def __init__(')):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None, '
                        'response_encoding=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or self.BASE_URL')
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer('    url, credentials=credentials,')
                printer('    get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer('    log_request=log_request, '
                        'log_response=log_response,')
                printer('    credentials_args=credentials_args,')
                printer('    default_global_params=default_global_params,')
                printer('    additional_http_headers=additional_http_headers,')
                printer('    response_encoding=response_encoding)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info in self.__service_method_info_map.items():
                self.__WriteSingleService(
                    printer, name, method_info, client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record a service's methods.

        Raises:
          ValueError: If service_name is already registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        The synthesized message has one field per method parameter (in
        parameterOrder first, then any remaining parameters), plus a field
        holding the request body when body_type is given.

        Args:
          method_description: (dict) The discovery description of the method.
          body_type: (dict or None) The method's 'request' entry, if any.

        Returns:
          (str) The id of the schema registered with the message registry.

        Raises:
          ValueError: If a parameter has no type, or if the body field name
            collides with a parameter field.
        """
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        if 'parameterOrder' not in method_description:
            ordered_parameters = list(method_description.get('parameters', []))
        else:
            ordered_parameters = list(method_description['parameterOrder'])
            already_ordered = set(ordered_parameters)
            ordered_parameters.extend(
                k for k in method_description['parameters']
                if k not in already_ordered)
        for parameter_name in ordered_parameters:
            field = dict(method_description['parameters'][parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][parameter_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            # Copy so that adding a description does not mutate the
            # caller's discovery document.
            body_type = dict(body_type)
            if 'description' not in body_type:
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create and register an empty response type; return its id."""
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        The body type can be used directly as the request only if the method
        has a body type, is not listed as unelidable, and every parameter is
        a path parameter already present as a field on the body message.

        Returns:
          (bool) True if a dedicated request type must be synthesized.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = set(x.name for x in message.fields)
        parameters = method_description.get('parameters', {})
        all_covered = all(
            param_info.get('location') == 'path' and
            self.__names.CleanName(param_name) in field_names
            for param_name, param_info in parameters.items())
        return not all_covered

    def __MaxSizeToInt(self, max_size):
        """Convert a discovery maxSize string (e.g. '10MB') to a byte count.

        Args:
          max_size: (str) A decimal integer optionally followed by B, KB, MB,
            GB or TB (case-insensitive). Units are powers of 1024.

        Returns:
          (int) The size in bytes.

        Raises:
          ValueError: If max_size cannot be parsed.
        """
        size_groups = _MAX_SIZE_RE.match(max_size)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        unit_shifts = {'B': 0, 'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
        shift = unit_shifts[unit.upper()] if unit else 0
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: (dict or None) The method's 'mediaUpload'
            section from the discovery document.
          method_id: (str) The method id; used only in log messages.

        Returns:
          (base_api.ApiUploadInfo) The populated upload configuration.
        """
        # The caller passes .get('mediaUpload'), which is None when the
        # section is absent; treat that as an empty section.
        media_upload_config = media_upload_config or {}
        config = base_api.ApiUploadInfo()
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # yield the characters '*', '/', '*'.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Side effect: adds the method's scopes to the registry's scope set.

        Raises:
          ValueError: If a parameter has a location other than query/path.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__client_info.base_path,
                     method_description['path'])))
        method_id = method_description['id']
        # Only required parameters are kept, in parameterOrder order.
        ordered_params = []
        for param_name in method_description.get('parameterOrder', []):
            param_info = method_description['parameters'][param_name]
            if param_info.get('required', False):
                ordered_params.append(param_name)
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        flat_path = method_description.get('flatPath', None)
        if flat_path is not None:
            flat_path = self.__names.NormalizeRelativePath(
                self.__client_info.base_path + flat_path)
            # Only record flat_path when it differs from the relative path.
            if flat_path != relative_path:
                method_info.flat_path = flat_path
        if method_description.get('supportsMediaUpload', False):
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload'), method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in method_description.get('parameters', {}).items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Return the field name derived from body_type's $ref, or ''."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Return the class name referenced by body_type's $ref."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name. If that collides with a
        parameter name, '_resource' is appended, then '_body' repeatedly
        until the name is unique.
        """
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in method_description.get('parameters', {}):
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested 'resources' are registered recursively as services named
        '<service_name>_<subresource>'; they are registered before the
        parent service itself.

        Args:
          service_name: (str) The resource name from the discovery document.
          methods: (dict) The resource description, with optional 'methods'
            and 'resources' entries.
        """
        service_name = self.__names.CleanName(service_name)
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                # The body message itself serves as the request.
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)